CapsNet Trained on MNIST Data

Identify the handwritten digit in an image

Released in 2017, this model uses capsules rather than individual neurons as its fundamental building blocks. Whereas the scalar activation of a single neuron only signals the presence of a particular feature or object, a capsule is a group of neurons reacting to the same entity, and its vector activation can also encode the properties of the detected instance. In addition, activations are passed from one capsule layer to the next using a novel dynamic routing technique.

Number of layers: 52 | Parameter count: 8,141,840 | Trained size: 33 MB
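
As a side note on the capsule idea described above (not part of this resource's code), the referenced paper turns each capsule's raw vector s into an activation v = (|s|^2/(1 + |s|^2)) (s/|s|), so that the vector's length can be read as the probability that the entity is present while its direction encodes the instance's properties. A minimal Wolfram Language sketch of this "squash" nonlinearity:

squash[s_?VectorQ] := With[{n = Norm[s]}, (n^2/(1 + n^2))*(s/n)]  (* shrink the vector so its length lies in [0, 1) while keeping its direction *)

Norm[squash[{3., 4.}]]   (* long input: length close to 1 *)
Norm[squash[{0.1, 0.2}]] (* short input: length close to 0 *)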

Training Set Information

Training Set Data

Performance

Examples

Resource retrieval

Retrieve the pre-trained net:

In[1]:=
NetModel["CapsNet Trained on MNIST Data"]
Out[1]=

Basic usage

Apply the trained net to a set of inputs:

In[2]:=
NetModel["CapsNet Trained on MNIST Data"][{\!\(\*
GraphicsBox[
TagBox[RasterBox[CompressedData["
1:eJxTTMoPSmNiYGAo5gASQYnljkVFiZXBAkBOaF5xZnpeaopnXklqemqRRRJI
mQwU/6cK2MHQhEtqiTTTJVxyQUy633BIneZSeolLWwSTBS6pP1ZM+bjkzjEJ
4ZL6X81kjkvqKydTHS65L0y45c4wMS3FJdfOxHRtcW3tkY/Y5cBA+T5WOaPk
5GgmJkNMOQWwW/6FMMliyqWB5X46MlliNTPo//+JTExbMOU+Ad0h683ElP8P
U+73fCGQM4XuYPHD//97LZmYLPdhlRrUAABgHMjK
"], {{0, 28}, {28, 0}}, {0, 255},
ColorFunction->GrayLevel],
BoxForm`ImageTag[
      "Byte", ColorSpace -> Automatic, Interleaving -> None],
Selectable->False],
DefaultBaseStyle->"ImageGraphics",
ImageSizeRaw->{28, 28},
PlotRange->{{0, 28}, {0, 28}}]\), \!\(\*
GraphicsBox[
TagBox[RasterBox[CompressedData["
1:eJxTTMoPSmNiYGAo5gASQYnljkVFiZXBAkBOaF5xZnpeaopnXklqemqRRRJI
mQwU/x9IsJWBKew2dql3okzMzH6/sMr5M7tPNGWeglXOQPf//yPiZdjlgMI7
mbHKVTOVQQkM8FCaef+v3dIih7HIfQphnnKCmdntBzYzdzKLizEzM+/DJvdD
hxnoP+xueafNxCvJwWSG3S0SV/+7M/tikbvAzHzyP1AOi5lvNJj9fwPlVJdj
ytUzMx8HUu7xWIw0ZdJ49f/PJUWTL5hyTMwT//+/xMxcgkUfE/Ou/92KzKGf
sMqpGLBzTb6MRer/Sh1gmGDzG/0AACEauS8=
"], {{0, 28}, {28, 0}}, {0, 255},
       
ColorFunction->GrayLevel],
BoxForm`ImageTag[
      "Byte", ColorSpace -> Automatic, Interleaving -> None],
Selectable->False],
DefaultBaseStyle->"ImageGraphics",
ImageSizeRaw->{28, 28},
PlotRange->{{0, 28}, {0, 28}}]\), \!\(\*
GraphicsBox[
TagBox[RasterBox[CompressedData["
1:eJxTTMoPSmNiYGAo5gASQYnljkVFiZXBAkBOaF5xZnpeaopnXklqemqRRRJI
mQwU/x9IsN/BdW1Vstx8LFK7eZgYmYBAcepvNJkvG/mZIHJMTLfR5OaBBAtX
r44CKqn4hSK1TRgoZQNipQEZp5ClPpoCRSr+gJiXgarykeWWAqUqoSalocrd
52ViCnwJMx5VLpeJyQVufyOK3Fo2JqYWOM+WiWkFkjZGRjs4p4mRkXE5nPcp
GegQGOeXM9BZ1vth3EAgT3I+lHMOyLE7Adf3y4uJqeY1hP3MFShXhbDuTwgT
0ywIa5s9Wnh+AnIFErYBgSVIRrz4G0LumxoTFIDiQaztPzLIQ5JzPogi9f+Z
NkzOof77fzTwY1EcSKps5y90GToCAMITbxU=
"], {{0, 28}, {28, 0}}, {0, 255},
       
ColorFunction->GrayLevel],
BoxForm`ImageTag[
      "Byte", ColorSpace -> Automatic, Interleaving -> None],
Selectable->False],
DefaultBaseStyle->"ImageGraphics",
ImageSizeRaw->{28, 28},
PlotRange->{{0, 28}, {0, 28}}]\)}]
Out[2]=

Give class probabilities for a single input:

In[3]:=
NetModel["CapsNet Trained on MNIST Data"][\!\(\*
GraphicsBox[
TagBox[RasterBox[CompressedData["
1:eJxTTMoPSmNiYGAo5gASQYnljkVFiZXBAkBOaF5xZnpeaopnXklqemqRRRJI
mQwU/x9IsJWBKew2dql3okzMzH6/sMr5M7tPNGWeglXOQPf//yPiZdjlgMI7
mbHKVTOVQQkM8FCaef+v3dIih7HIfQphnnKCmdntBzYzdzKLizEzM+/DJvdD
hxnoP+xueafNxCvJwWSG3S0SV/+7M/tikbvAzHzyP1AOi5lvNJj9fwPlVJdj
ytUzMx8HUu7xWIw0ZdJ49f/PJUWTL5hyTMwT//+/xMxcgkUfE/Ou/92KzKGf
sMqpGLBzTb6MRer/Sh1gmGDzG/0AACEauS8=
"], {{0, 28}, {28, 0}}, {0, 255},
      
ColorFunction->GrayLevel],
BoxForm`ImageTag[
     "Byte", ColorSpace -> Automatic, Interleaving -> None],
Selectable->False],
DefaultBaseStyle->"ImageGraphics",
ImageSizeRaw->{28, 28},
PlotRange->{{0, 28}, {0, 28}}]\), "Classification" -> "Probabilities"]
Out[3]=

Feature extraction

Create a subset of the MNIST dataset:

In[4]:=
sample = Keys[RandomSample[ResourceData["MNIST"], 100]]
Out[4]=

Remove the final layers of the net, keeping the part from "ReLUConv1" through "Pick", so that it can be used as a feature extractor:

In[5]:=
extractor = NetTake[NetModel["CapsNet Trained on MNIST Data"], {"ReLUConv1", "Pick"}]
Out[5]=

Visualize the features of a subset of the MNIST dataset:

In[6]:=
FeatureSpacePlot[sample, FeatureExtractor -> extractor]
Out[6]=
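
The extractor can also be plugged into other feature-based functions. For example, a sketch (not from the original page) that clusters the same sample into ten groups:

FindClusters[sample, 10, FeatureExtractor -> extractor]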

Image generation

Extract the image reconstruction part of the net and attach an image decoder to its output:

In[7]:=
reconstructor = NetReplacePart[
  NetExtract[NetModel["CapsNet Trained on MNIST Data"], "Reconstruct"], "Output" -> NetDecoder["Image"]]
Out[7]=

Extract the DigitCaps feature vector for a given digit image:

In[8]:=
(* img stands for a 28×28 grayscale handwritten-digit image, embedded as inline raster data in the original notebook *)
featureVect = NetModel["CapsNet Trained on MNIST Data"][img, NetPort["Pick", "Output"]]
Out[8]=

Reconstruct the image from the feature vector:

In[9]:=
reconstructor[featureVect]
Out[9]=

Experiment with changing the feature vector by adding a shift along a single coordinate at a time:

In[10]:=
reconstructor@
 Table[featureVect + c*UnitVector[16, 11], {c, -1, 1, 0.1}]
Out[10]=
In[11]:=
reconstructor@
 Table[featureVect + c*UnitVector[16, 14], {c, -1, 1, 0.1}]
Out[11]=
In[12]:=
reconstructor@
 Table[featureVect + c*UnitVector[16, 16], {c, -1, 1, 0.1}]
Out[12]=
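
The three sweeps above can be combined into a single perturbation grid that shifts each of the 16 DigitCaps coordinates in turn, one row per coordinate; a sketch along the lines of the reconstruction figure in the reference:

ImageAssemble[
 Table[reconstructor[featureVect + c*UnitVector[16, i]], {i, 16}, {c, -1, 1, 0.1}]
 ]  (* 16 rows, one per coordinate; 21 reconstructions per row *)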

Training the uninitialized architecture

Retrieve the uninitialized training architecture:

In[13]:=
trainingNet = NetModel["CapsNet Trained on MNIST Data", "TrainingNet"]
Out[13]=

Retrieve the MNIST dataset:

In[14]:=
mnist = ResourceObject["MNIST"]
Out[14]=

Use the training dataset provided:

In[15]:=
trainSet = ResourceData[mnist, "TrainingData"]
Out[15]=

Use the test dataset provided:

In[16]:=
valSet = ResourceData[mnist, "TestData"]
Out[16]=

Initialize the "W" transformation matrices of the "PrimaryPredVects" layer with small uniform random values:

In[17]:=
trainingCapsNetInitialized = NetReplacePart[
  trainingNet,
  {"CapsNet", "PrimaryPredVects", 2, "Array"} -> RandomVariate[UniformDistribution[{-1, 1}*0.005], {1152, 10, 16, 8}]
  ]
Out[17]=
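
To confirm the replacement, the array can be read back with the same part specification; its dimensions {1152, 10, 16, 8} correspond to one 16×8 transformation matrix for each of the 1152 primary capsules and each of the 10 digit capsules:

Dimensions[NetExtract[trainingCapsNetInitialized, {"CapsNet", "PrimaryPredVects", 2, "Array"}]]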

Train the net; the reconstruction loss is scaled by 0.392, which corresponds to the reference's per-pixel weight of 0.0005 applied over the 28×28 = 784 reconstructed pixels. If a GPU is available, setting TargetDevice -> "GPU" is recommended:

In[18]:=
trained = NetTrain[
  trainingCapsNetInitialized, trainSet,
  ValidationSet -> valSet,
  LossFunction -> {"ClassLoss" -> Scaled[1], "RecoLoss" -> Scaled[0.392]},
  MaxTrainingRounds -> 50,
  TargetDevice -> "CPU"
  ]
Out[18]=
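
After training, the classification subnetwork can be pulled out of the trained graph for standalone use; a sketch, assuming the trained graph keeps the "CapsNet" vertex referenced in the initialization step above:

trainedCapsNet = NetExtract[trained, "CapsNet"]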

Net information

Inspect the number of parameters of all arrays in the net:

In[19]:=
NetInformation[
 NetModel["CapsNet Trained on MNIST Data"], "ArraysElementCounts"]
Out[19]=

Obtain the total number of parameters:

In[20]:=
NetInformation[
 NetModel["CapsNet Trained on MNIST Data"], "ArraysTotalElementCount"]
Out[20]=
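
As a rough consistency check (not shown on the original page), the total parameter count times 4 bytes per single-precision value is close to the quoted trained size of 33 MB:

UnitConvert[Quantity[8141840*4, "Bytes"], "Megabytes"]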

Obtain the layer type counts:

In[21]:=
NetInformation[
 NetModel["CapsNet Trained on MNIST Data"], "LayerTypeCounts"]
Out[21]=

Display the summary graphic:

In[22]:=
NetInformation[
 NetModel["CapsNet Trained on MNIST Data"], "SummaryGraphic"]
Out[22]=

Export to MXNet

Export the net into a format that can be opened in MXNet:

In[23]:=
jsonPath = Export[FileNameJoin[{$TemporaryDirectory, "net.json"}], NetModel["CapsNet Trained on MNIST Data"], "MXNet"]
Out[23]=

Export also creates a net.params file containing parameters:

In[24]:=
paramPath = FileNameJoin[{DirectoryName[jsonPath], "net.params"}]
Out[24]=

Get the size of the parameter file:

In[25]:=
FileByteCount[paramPath]
Out[25]=

The size is similar to the byte count of the resource object:

In[26]:=
ResourceObject["CapsNet Trained on MNIST Data"]["ByteCount"]
Out[26]=

Represent the MXNet net as a graph:

In[27]:=
Import[jsonPath, {"MXNet", "NodeGraphPlot"}]
Out[27]=

Requirements

Wolfram Language 11.3 (March 2018) or above

Resource History

Reference

  • S. Sabour, N. Frosst, G. E. Hinton, "Dynamic Routing Between Capsules," arXiv:1710.09829 (2017)