Net construction
AutoEncoder
Based on these networks:
(*NetModel[{"Stable Diffusion V1","Part"->"Encoder"}];NetModel[{"Stable Diffusion V1","Part"->"Decoder"}];*)
Separate pieces:
In[]:=
conv[channels_, size_:{3,3}, pad_:1] := ConvolutionLayer[channels, size, PaddingSize -> pad]

(* GroupNorm: reshape the channels into groups, normalize within each group, reshape back *)
groupNorm[input_, groups_:32] := NetChain[{
   ReshapeLayer[MapAt[Splice[{groups, #/groups}] &, input, {1}]],
   NormalizationLayer[2 ;;, ;; 2, "Epsilon" -> 1*^-6],
   ReshapeLayer[input]}, "Input" -> input]

(* Residual block: norm -> Swish -> conv, twice, plus a skip connection *)
convBlock[input_, groups_:32] := NetFlatten[NetGraph[{
    {groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]],
     groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]]},
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> 1, {NetPort["Input"], 1} -> 2}], 1]

(* Residual block that doubles the channel count; the skip path uses a 1x1 conv *)
downBlock[input_, groups_:32] := NetFlatten[NetGraph[{
    {groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]],
     groupNorm[input, groups], ElementwiseLayer["Swish"], conv[2 First[input]]},
    conv[2 First[input], {1, 1}, 0],
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> {1, 2} -> 3}], 1]

(* Residual block that halves the channel count; the skip path uses a 1x1 conv *)
upBlock[input_, groups_:32] := NetFlatten[NetGraph[{
    {groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]],
     groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]/2]},
    conv[First[input]/2, {1, 1}, 0],
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> {1, 2} -> 3}], 1]

(* Single-head self-attention over a varying-length sequence of dim-vectors *)
attention[dim_] := NetGraph[{
   "key" -> NetMapOperator[LinearLayer[dim]],
   "value" -> NetMapOperator[LinearLayer[dim]],
   "query" -> NetMapOperator[LinearLayer[dim]],
   "attention" -> AttentionLayer["Dot", "ScoreRescaling" -> "DimensionSqrt"],
   "output" -> NetMapOperator[LinearLayer[dim]]},
  {NetPort["Input"] -> {"key", "value", "query"} -> "attention" -> "output"},
  "Input" -> {"Varying", dim}]

(* Attention over the flattened spatial positions, with a residual connection *)
attentionBlock[input_, groups_:32] := NetFlatten[NetGraph[{
    {groupNorm[input, groups], FlattenLayer[-1], TransposeLayer[],
     attention[First@input], TransposeLayer[], ReshapeLayer[input]},
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> 1, {NetPort["Input"], 1} -> 2}], 1]

(* A strided conv halves the spatial size; nearest-neighbor resizing doubles it *)
downsample[input_] := ConvolutionLayer[First@input, {3, 3}, "Stride" -> 2, PaddingSize -> {{0, 1}, {0, 1}}]
upsample[input_] := NetChain[{
   ResizeLayer[{Scaled[2], Scaled[2]}, Resampling -> "Nearest", "Scheme" -> "Bin"],
   conv[First@input]}, "Input" -> input]
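For a quick sanity check, the blocks can be built in isolation (the shapes below are illustrative choices, not taken from the original notebook):
In[]:=
convBlock[{128, 32, 32}]    (* residual block, channel count preserved *)
downBlock[{128, 32, 32}]    (* residual block that doubles the channels *)
attentionBlock[{512, 4, 4}] (* self-attention over the flattened positions *)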
Encoder and Decoder networks:
In[]:=
encoder[input_, outChannels_:4, convChannels_:128, blocks_:5, defGroups_:Automatic, simple_:False] :=
 Enclose@Block[{net, groups = Replace[defGroups, Automatic :> convChannels/4]},
  net = NetChain[{conv[convChannels]}, "Input" -> input];
  Do[
   If[!simple, net = NetAppend[net, If[i == 1 || i >= blocks - 1, convBlock, downBlock][NetExtract[net, "Output"], groups]]];
   If[i == blocks, net = NetAppend[net, attentionBlock[NetExtract[net, "Output"], groups]]];
   If[!simple, net = NetAppend[net, convBlock[NetExtract[net, "Output"], groups]]];
   If[i < blocks - 1,
    net = NetAppend[net, downsample[NetExtract[net, "Output"]]];
    If[simple, net = NetAppend[net, ElementwiseLayer["Swish"]]]],
   {i, blocks}];
  (* project to 2 outChannels and split into the "Mean" and "LogVar" ports *)
  net = NetAppend[net, {
     groupNorm[NetExtract[net, "Output"], groups],
     conv[2 outChannels], conv[2 outChannels, 1, 0],
     NetGraph[{PartLayer[;; outChannels], PartLayer[outChannels + 1 ;;]},
      {NetPort["Input"] -> {1, 2}, 1 -> NetPort["Mean"], 2 -> NetPort["LogVar"]}]}];
  net]

decoder[input_, outChannels_:1, convChannels_:128, blocks_:5, defGroups_:Automatic, simple_:False] :=
 Enclose@Block[{net, groups = Replace[defGroups, Automatic :> convChannels/16]},
  net = NetChain[{conv[First[input], {1, 1}, 0], conv[convChannels]}, "Input" -> input];
  Do[
   If[!simple, net = NetAppend[net, If[i > 2, upBlock, convBlock][NetExtract[net, "Output"], groups]]];
   If[i == 1, net = NetAppend[net, attentionBlock[NetExtract[net, "Output"], groups]]];
   If[!simple, net = NetAppend[net, {convBlock[NetExtract[net, "Output"], groups], convBlock[NetExtract[net, "Output"], groups]}]];
   If[1 < i < blocks,
    net = NetAppend[net, upsample[NetExtract[net, "Output"]]];
    If[simple, net = NetAppend[net, ElementwiseLayer["Swish"]]]],
   {i, blocks}];
  net = NetAppend[net, {groupNorm[NetExtract[net, "Output"], groups], ElementwiseLayer["Swish"], conv[outChannels]}];
  net]
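For a {1,32,32} input with the default five blocks, the encoder downsamples three times, so the "Mean" and "LogVar" ports should come out with shape {4,4,4} (matching the predictionNet[{4,4,4}] call below):
In[]:=
enc = encoder[{1, 32, 32}];
NetExtract[enc, "Mean"]  (* expected: {4, 4, 4} *)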
ELBO loss function for Variational Auto-Encoder:
In[]:=
import torch
from torch.nn import functional as F

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
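The same two terms as plain Wolfram Language functions, a sketch mirroring the Python above (the net below swaps BCE for a squared-error reconstruction term and currently disables the KL term):
In[]:=
reconLoss[x_, y_] := Total[Flatten[(x - y)^2]]  (* squared-error reconstruction, as in the "loss" layer below *)
klLoss[mu_, logvar_] := -0.5 Total[Flatten[1 + logvar - mu^2 - Exp[logvar]]]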
VAE network with ELBO loss function (the sampling layer and KL term are present but commented out below, so the active net is a deterministic autoencoder with a squared-error reconstruction loss):
In[]:=
VAE[input_, encoderArgs_List:{}, decoderArgs_List:{}] :=
 With[{enc = encoder[input, Sequence @@ encoderArgs]},
  NetGraph[<|
    "encoder" -> enc,
    (*"exp"->ElementwiseLayer[Exp],*)
    (*"z"->RandomArrayLayer[NormalDistribution[#Mean,#Var]&],*)
    "decoder" -> decoder[NetExtract[enc, "Mean"], First[input], Sequence @@ decoderArgs],
    (*"loss"->With[{n=Times@@NetExtract[enc,"Mean"]},FunctionLayer[(Total[Unevaluated[Flatten][#Input-#Output]^2]-Total[Unevaluated[Flatten][1+#LogVar-#Mean^2-#Var]])&]],*)
    "loss" -> With[{n = Times @@ NetExtract[enc, "Mean"]},
      FunctionLayer[(Total[Unevaluated[Flatten][#Input - #Output]^2](*-Total[Unevaluated[Flatten][1+#LogVar-#Mean^2-#Var]]*)) &]]
   |>,
   {
    (*NetPort[{"encoder","Mean"}]->NetPort[{"z","Mean"}],*)
    (*NetPort[{"encoder","LogVar"}]->"exp"->NetPort[{"z","Var"}],*)
    (*"z"->"decoder",*)
    NetPort[{"encoder", "Mean"}] -> "decoder",
    NetPort["Input"] -> NetPort[{"loss", "Input"}],
    "decoder" -> NetPort[{"loss", "Output"}],
    (*NetPort[{"encoder","Mean"}]->NetPort[{"loss","Mean"}],*)
    (*NetPort[{"encoder","LogVar"}]->NetPort[{"loss","LogVar"}],*)
    (*"exp"->NetPort[{"loss","Var"}],*)
    NetPort[{"encoder", "Mean"}] -> NetPort["Latent"],
    "loss" -> NetPort["Loss"]
   }]]
In[]:=
vae=VAE[{1,32,32}];
In[]:=
simpleVae=VAE[{1,32,32},{4,32,5,Automatic,True},{32,5,Automatic,True}];
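Either variant can then be trained against its "Loss" port; a minimal sketch, assuming trainingData is a placeholder for the frame arrays built in the Data section below:
In[]:=
(*trained=NetTrain[vae,<|"Input"->trainingData|>,LossFunction->"Loss",BatchSize->64];*)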
Prediction is a simple multi-layer feed-forward net (alternating linear and ReLU layers) applied to the flattened latent:
In[]:=
predictionNet[input_, layers_:3] := With[{size = Times @@ input},
  NetGraph[<|
    "x" -> PartLayer[1],  (* current latent frame *)
    "y" -> PartLayer[2],  (* next latent frame, the prediction target *)
    "predict" -> NetChain[{FlattenLayer[],
       Splice@Table[Splice@{LinearLayer[size], ElementwiseLayer["ReLU"]}, layers],
       LinearLayer[size], ReshapeLayer[input]}],
    "loss" -> MeanSquaredLossLayer[]|>,
   {NetPort["Input"] -> {"x", "y"}, "x" -> "predict", {"predict", "y"} -> "loss" -> NetPort["Loss"]}]]
In[]:=
predictionNet[{4,4,4}]
Out[]=
Final network
Put it all together:
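A sketch of the composition, assuming trainedVae and trainedPredict are the trained nets from above (the names are placeholders): encode the input frame, map its latent forward with the prediction chain, then decode.
In[]:=
finalNet[trainedVae_, trainedPredict_] := NetGraph[<|
   "encode" -> NetExtract[trainedVae, "encoder"],
   "predict" -> NetExtract[trainedPredict, "predict"],
   "decode" -> NetExtract[trainedVae, "decoder"]|>,
  {NetPort["Input"] -> "encode",
   NetPort[{"encode", "Mean"}] -> "predict" -> "decode"}]
(* the encoder's unused "LogVar" port simply becomes an extra output *)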
CA
Data
Training
Only keep samples whose 2nd frame is non-empty:
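One way to express that filter, assuming samples is a list of frame sequences over background color 0:
In[]:=
(*samples=Select[samples,Max[Abs[#[[2]]]]>0&];*)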
Prediction
If all goes well, these should look similar:
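A sketch of that comparison, assuming trained is the trained VAE: push a frame through the encoder's "Mean" port and back through the decoder, then plot both.
In[]:=
reconstruct[trained_, frame_] := NetExtract[trained, "decoder"][NetExtract[trained, "encoder"][frame]["Mean"]];
(*ArrayPlot/@{First[frame],First[reconstruct[trained,frame]]}*)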
Pretty bad generalization:
Lines
Wrap and bounce lines
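A minimal sketch of the two boundary behaviors, reduced to a single moving cell per 1D frame (the actual data generator may differ): "Wrap" moves modulo the width, "Bounce" reflects the velocity at the walls.
In[]:=
lineFrames[w_, steps_, mode_:"Wrap"] := Module[{x = RandomInteger[{1, w}], v = RandomChoice[{-1, 1}]},
  Table[
   If[mode === "Bounce" && (x + v < 1 || x + v > w), v = -v];
   x = If[mode === "Wrap", Mod[x + v - 1, w] + 1, x + v];
   ReplacePart[ConstantArray[0, w], x -> 1], {steps}]]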
CA (5 color, nearest neighbor, symmetric, quiescent)
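One way to sample that rule family (a sketch; it assumes "symmetric" means reflection symmetry in the left/right neighbors and "quiescent" means an all-0 neighborhood maps to 0):
In[]:=
randomSymmetricRule[] := Module[{f},
  f = Association[# -> RandomInteger[{0, 4}] & /@
      DeleteDuplicates[{#[[2]], Sort[{#[[1]], #[[3]]}]} & /@ Tuples[Range[0, 4], 3]]];
  f[{0, {0, 0}}] = 0;  (* quiescent: a blank neighborhood stays blank *)
  Function[{nbhd, t}, f[{nbhd[[2]], Sort[{nbhd[[1]], nbhd[[3]]}]}]]];
In[]:=
CellularAutomaton[{randomSymmetricRule[], {}, 1}, RandomInteger[{0, 4}, 32], 31]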
Samples
Transformer
[Mask] inpainting transformer
ARC
PDE
1D
TagSystem
HardSphere