In[]:=
$MachineName
Out[]=
threadripper2
Pipeline
Pipeline
Data generation
Data generation
[Generate evolution, then train by masking the second half of the evolution]
<Alternative approach not taken: predict a smaller slice, then iterate that>
<Could mask different regions at random>
<Alternative approach not taken: predict a smaller slice, then iterate that>
<Could mask different regions at random>
[[ Make a version where the line is cyclic, and where it wraps around the region ]]
In[]:=
(* Generate one training sample: a white line with random endpoints on a black
   background, rasterized to exactly 32x64 pixels (width 32, height 64). *)
lineData[] :=
  Module[{topX = RandomInteger[32], bottomX = RandomInteger[32]},
    Rasterize[
      Graphics[
        {White, Thick, Line[{{topX, 0}, {bottomX, 64}}]},
        Background -> Black,
        ImageSize -> {32, 64}],
      RasterSize -> {32, 64}]]
In[]:=
lineData[]
Out[]=
Network
Network
In[]:=
(* Building blocks for the VAE. `input` arguments are dimension lists
   {channels, height, width}; First[input] is the channel count. *)

(* Size-preserving convolution: default 3x3 kernel with padding 1. *)
conv[channels_, size_ : {3, 3}, pad_ : 1] :=
 ConvolutionLayer[channels, size, PaddingSize -> pad]

(* Group normalization: split the channel axis into `groups` groups, normalize
   over each group's channels and all spatial positions, reshape back.
   Requires First[input] divisible by `groups`. *)
groupNorm[input_, groups_ : 32] :=
 NetChain[{
   ReshapeLayer[MapAt[Splice[{groups, #/groups}] &, input, {1}]],
   NormalizationLayer[2 ;;, ;; 2, "Epsilon" -> 1*^-6],
   ReshapeLayer[input]},
  "Input" -> input]

(* Pre-activation residual block: (norm -> Swish -> conv) twice, then add the
   block input back (identity skip; channel count unchanged). *)
convBlock[input_, groups_ : 32] :=
 NetFlatten[
  NetGraph[{
    {groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]],
     groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]]},
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> 1, {NetPort["Input"], 1} -> 2}], 1]

(* Residual block that doubles the channel count; the skip path uses a 1x1
   conv so the two branches agree in shape before the add. *)
downBlock[input_, groups_ : 32] :=
 NetFlatten[
  NetGraph[{
    {groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]],
     groupNorm[input, groups], ElementwiseLayer["Swish"], conv[2 First[input]]},
    conv[2 First[input], {1, 1}, 0],
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> {1, 2} -> 3}], 1]

(* Residual block that halves the channel count; 1x1 conv on the skip path. *)
upBlock[input_, groups_ : 32] :=
 NetFlatten[
  NetGraph[{
    {groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]],
     groupNorm[input, groups], ElementwiseLayer["Swish"], conv[First[input]/2]},
    conv[First[input]/2, {1, 1}, 0],
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> {1, 2} -> 3}], 1]

(* Single-head dot-product self-attention over a varying-length sequence of
   dim-vectors, with learned key/value/query/output projections.
   FIX: the original `"ScoreRescaling""DimensionSqrt"` parsed as Times of two
   strings (the Rule arrow was lost in the notebook export); restored to the
   option "ScoreRescaling" -> "DimensionSqrt". *)
attention[dim_] :=
 NetGraph[{
   "key" -> NetMapOperator[LinearLayer[dim]],
   "value" -> NetMapOperator[LinearLayer[dim]],
   "query" -> NetMapOperator[LinearLayer[dim]],
   "attention" -> AttentionLayer["Dot", "ScoreRescaling" -> "DimensionSqrt"],
   "output" -> NetMapOperator[LinearLayer[dim]]},
  {NetPort["Input"] -> {"key", "value", "query"} -> "attention" -> "output"},
  "Input" -> {"Varying", dim}]

(* Residual spatial self-attention: flatten the spatial axes to a sequence,
   transpose to sequence-major, attend, transpose/reshape back, add input. *)
attentionBlock[input_, groups_ : 32] :=
 NetFlatten[
  NetGraph[{
    {groupNorm[input, groups], FlattenLayer[-1], TransposeLayer[],
     attention[First@input], TransposeLayer[], ReshapeLayer[input]},
    ThreadingLayer[Plus]},
   {NetPort["Input"] -> 1, {NetPort["Input"], 1} -> 2}], 1]

(* Downsample: stride-2 3x3 conv with asymmetric padding (pad on the
   bottom/right only), halving the spatial size. *)
downsample[input_] :=
 ConvolutionLayer[First@input, {3, 3}, "Stride" -> 2,
  PaddingSize -> {{0, 1}, {0, 1}}]

(* Upsample: nearest-neighbor 2x resize followed by a 3x3 conv.
   NOTE(review): "Scheme" -> "Bin" is an undocumented ResizeLayer option in
   current references — confirm it is still supported. *)
upsample[input_] :=
 NetChain[{
   ResizeLayer[{Scaled[2], Scaled[2]}, Resampling -> "Nearest", "Scheme" -> "Bin"],
   conv[First@input]},
  "Input" -> input]
In[]:=
(* encoder: builds a convolutional encoder NetChain for input dimensions
   {channels, h, w}. Stem conv to convChannels, then `blocks` stages of
   residual blocks (channel-doubling downBlock in the middle stages, plain
   convBlock at the first and last two), self-attention at the final stage,
   and stride-2 downsampling between the first blocks-2 stages. Head emits
   2*outChannels channels split into "Mean" and "LogVar" ports.
   defGroups: group count for groupNorm; Automatic -> convChannels/4. *)
encoder[input_,outChannels_:4,convChannels_:128,blocks_:5,defGroups_:Automatic]:=
 Enclose@Block[{net,groups=Replace[defGroups,Automatic:>convChannels/4]},
  (* stem: lift to convChannels *)
  net=NetChain[{conv[convChannels]},"Input"->input];
  Do[
   (* downBlock doubles channels except on stage 1 and the last two stages *)
   net=NetAppend[net,If[i==1||i>=blocks-1,convBlock,downBlock][NetExtract[net,"Output"],groups]];
   (* self-attention only at the deepest stage *)
   If[i==blocks,net=NetAppend[net,attentionBlock[NetExtract[net,"Output"],groups]]];
   net=NetAppend[net,convBlock[NetExtract[net,"Output"],groups]];
   (* halve spatial size between the first blocks-2 stages *)
   If[i<blocks-1,net=NetAppend[net,downsample[NetExtract[net,"Output"]]]],
   {i,blocks}];
  (* head: norm, project to 2*outChannels, 1x1 mix, split into Mean/LogVar *)
  net=NetAppend[net,{groupNorm[NetExtract[net,"Output"],groups],conv[2outChannels],conv[2outChannels,1,0],
    NetGraph[{PartLayer[;;outChannels],PartLayer[outChannels+1;;]},
     {NetPort["Input"]->{1,2},1->NetPort["Mean"],2->NetPort["LogVar"]}]}];
  net]

(* decoder: mirror of the encoder. 1x1 conv then stem conv to convChannels;
   `blocks` stages of residual blocks (channel-halving upBlock after stage 2),
   self-attention at stage 1, 2x upsampling between the middle stages; head is
   norm -> Swish -> conv to outChannels.
   defGroups: Automatic -> convChannels/16 (smaller groups than the encoder,
   since channels shrink as the decoder upsamples). *)
decoder[input_,outChannels_:1,convChannels_:128,blocks_:5,defGroups_:Automatic]:=
 Enclose@Block[{net,groups=Replace[defGroups,Automatic:>convChannels/16]},
  net=NetChain[{conv[First[input],{1,1},0],conv[convChannels]},"Input"->input];
  Do[
   (* upBlock halves channels from stage 3 onward *)
   net=NetAppend[net,If[i>2,upBlock,convBlock][NetExtract[net,"Output"],groups]];
   (* self-attention only at the first (deepest) stage *)
   If[i==1,net=NetAppend[net,attentionBlock[NetExtract[net,"Output"],groups]]];
   net=NetAppend[net,{convBlock[NetExtract[net,"Output"],groups],convBlock[NetExtract[net,"Output"],groups]}];
   (* double spatial size between the middle stages *)
   If[1<i<blocks,net=NetAppend[net,upsample[NetExtract[net,"Output"]]]];,
   {i,blocks}];
  net=NetAppend[net,{groupNorm[NetExtract[net,"Output"],groups],ElementwiseLayer["Swish"],conv[outChannels]}];
  net]
[ Loss function is in fact fixed, not variational ]
In[]:=
(* VAE[input, encoderArgs, decoderArgs]: autoencoder NetGraph for inputs of
   dimensions `input`. Despite the name, the variational machinery (the "z"
   sampling node and the KL term) is commented out — the decoder consumes the
   encoder "Mean" directly and the loss is a plain squared-error
   reconstruction loss (see the "[ Loss function is in fact fixed ]" note).
   Exposes ports "Latent" (encoder mean) and "Loss".
   FIX: the loss body was garbled to `Total[]` in the notebook export (the
   superscript 2 and the argument spilled onto separate lines); reconstructed
   as Total[Unevaluated[Flatten][#Input-#Output]^2]. *)
VAE[input_, encoderArgs_List : {}, decoderArgs_List : {}] :=
 With[{enc = encoder[input, Sequence @@ encoderArgs]},
  NetGraph[<|
    "encoder" -> enc,
    (* LogVar -> Var; node kept although its edges are currently disabled *)
    "exp" -> ElementwiseLayer[Exp],
    (*"z"->RandomArrayLayer[NormalDistribution[#Mean,#Var]&],*)
    "decoder" -> decoder[NetExtract[enc, "Mean"], First[input], Sequence @@ decoderArgs],
    (*"loss"->With[{n=Times@@NetExtract[enc,"Mean"]},FunctionLayer[(Total[Unevaluated[Flatten][#Input-#Output]^2]-Total[Unevaluated[Flatten][1+#LogVar-#Mean^2-#Var]])&]],*)
    (* sum of squared pixel errors; `n` (total latent size) is currently
       unused — presumably intended as a normalizer. TODO confirm *)
    "loss" -> With[{n = Times @@ NetExtract[enc, "Mean"]},
      FunctionLayer[Total[Unevaluated[Flatten][#Input - #Output]^2] &]]|>,
   {(*NetPort[{"encoder","Mean"}]->NetPort[{"z","Mean"}],*)
    (*NetPort[{"encoder","LogVar"}]->"exp"->NetPort[{"z","Var"}],*)
    (*"z"->"decoder",*)
    NetPort[{"encoder", "Mean"}] -> "decoder",
    NetPort["Input"] -> NetPort[{"loss", "Input"}],
    "decoder" -> NetPort[{"loss", "Output"}],
    (*NetPort[{"encoder","Mean"}]->NetPort[{"loss","Mean"}],*)
    (*NetPort[{"encoder","LogVar"}]->NetPort[{"loss","LogVar"}],*)
    (*"exp"->NetPort[{"loss","Var"}],*)
    NetPort[{"encoder", "Mean"}] -> NetPort["Latent"],
    "loss" -> NetPort["Loss"]}]]
[ Displayed loss body (garbled in export; the superscript 2 and the Total argument spilled onto separate lines — shown twice in the notebook): Total[Unevaluated[Flatten][#Input - #Output]^2], i.e. the squared-error reconstruction loss ]
In[]:=
(* Instantiate the autoencoder for single-channel 32x32 frames (one half of
   the 32x64 training image). *)
vae=VAE[{1,32,32}];
In[]:=
(* predictionNet: given a pair of latent tensors (first axis of length 2 —
   latent of the first frame-half and latent of the second), runs an MLP of
   `layers` Linear+ReLU stages (plus a final Linear) on the first latent and
   scores it against the second with mean-squared error. Exposes port "Loss".
   `input` is the dimension list of a single latent; size = Times@@input. *)
predictionNet[input_,layers_:3]:=
 With[{size=Times@@input},
  NetGraph[<|
    "x"->PartLayer[1],(* latent to predict FROM *)
    "y"->PartLayer[2],(* latent used as the prediction target *)
    "predict"->NetChain[{FlattenLayer[],
       Splice@Table[Splice@{LinearLayer[size],ElementwiseLayer["ReLU"]},layers],
       LinearLayer[size],ReshapeLayer[input]}],
    "loss"->MeanSquaredLossLayer[]|>,
   {NetPort["Input"]->{"x","y"},"x"->"predict",{"predict","y"}->"loss"->NetPort["Loss"]}]]
In[]:=
(* Full training network: a 32x64 grayscale image is split into two 32x32
   halves ("prep" reshapes to {2,1,32,32} frame pairs), each half is encoded/
   decoded by the shared VAE (NetMapThreadOperator), the two per-half VAE
   losses are summed, the prediction net maps the first half's latent to the
   second's, and the VAE and prediction losses are summed into port "Loss". *)
net=NetInitialize@NetGraph[<|
   "prep"->NetChain[{ReshapeLayer[{1,2,32,32}],TransposeLayer[1<->2]}],
   "vae"->NetMapThreadOperator[vae],
   "VAETotalLoss"->AggregationLayer[Total,1],
   "TotalLoss"->TotalLayer[],
   "prediction"->predictionNet[NetExtract[vae,"Latent"]]|>,
  {NetPort["Input"]->"prep"->"vae",
   NetPort[{"vae","Latent"}]->"prediction",
   NetPort[{"vae","Loss"}]->"VAETotalLoss",
   {"VAETotalLoss",NetPort[{"prediction","Loss"}]}->"TotalLoss"(*"VAETotalLoss"*)->NetPort["Loss"]},
  "Input"->NetEncoder[{"Image",{32,64},ColorSpace->"Grayscale"}]];
In[]:=
net
Extracted pieces
Extracted pieces
Encoded version is a 4×4×4 tensor [ same channel count as Stable Diffusion's latent ]
Trained net
Trained net
Training
Training
Inference
Inference
Plan
Single lines [with wraparound]
Single lines [with wraparound]
Multiple non-interacting lines
Multiple non-interacting lines
Bouncing lines
Bouncing lines
CAs
CAs
5 color, nearest neighbor, symmetric, quiescent
Case 1: one network per rule (“specialist”)
Case 1: one network per rule (“specialist”)
[ Train and run a bunch of cases; rank them by how well they work ]
[ Train and run a bunch of cases; rank them by how well they work ]
Case 2: a “generalist” network
Case 2: a “generalist” network