Wolfram Research

Sketch-RNN Trained on QuickDraw Data

Generate hand-drawn sketches

Released in 2017, this collection features the Sketch-RNN models for unconditional generation, which produce simple hand-drawn sketches (or complete a partial input sketch) represented as a sequence of pen strokes. The nets predict the next pen movement given an input sequence (possibly empty) of movements and are trained with teacher forcing. Each pen movement is sampled from a mixture of normal distributions, while a categorical distribution determines whether the pen is drawing a line, is lifted from the paper or has finished the drawing. The nets produce the parameters of these distributions.

Examples

Resource retrieval

Get the pre-trained net:

In[1]:=
NetModel["Sketch-RNN Trained on QuickDraw Data"]
Out[1]=

NetModel parameters

This model consists of a family of individual nets, each identified by a specific parameter combination. Inspect the available parameters and their default values:

In[2]:=
NetModel["Sketch-RNN Trained on QuickDraw Data", "ParametersInformation"]
Out[2]=

Pick a non-default net by specifying the parameters:

In[3]:=
NetModel[{"Sketch-RNN Trained on QuickDraw Data", 
  "Object" -> "Chair"}]
Out[3]=

Pick a non-default uninitialized net:

In[4]:=
NetModel[{"Sketch-RNN Trained on QuickDraw Data", 
  "Object" -> "Flower"}, "UninitializedEvaluationNet"]
Out[4]=

Check the default parameter combination:

In[5]:=
NetModel["Sketch-RNN Trained on QuickDraw Data", "DefaultVariant"]
Out[5]=

Evaluation function

Define an evaluation function to generate a sketch from a fixed initial condition using temperature sampling:

In[6]:=
drawSketch[obj_, temp_ : 0.01, maxLen_ : 300] := 
  Block[{stateObject, lastPos, pos, stroke, segments, time, 
    lastAction, action, offset}, 
   (* NetStateObject remembers the recurrent state between evaluations *)
   stateObject = 
    NetStateObject@
     NetModel[{"Sketch-RNN Trained on QuickDraw Data", 
       "Object" -> obj}];
   lastPos = pos = {0, 0};
   (* fixed initial condition: a null stroke *)
   stroke = {0, 0, 0, 0, 0};
   segments = Table[{}, maxLen];
   time = 0; action = 1;
   While[time++ < maxLen,
    lastAction = action;
    (* feed the latest stroke and sample the next one *)
    stroke = 
     stateObject[{stroke}, {"RandomSample", 
       "Temperature" -> N@temp}];
    offset = stroke[[;; 2]];
    (* pen state: 1 = drawing, 2 = lifted, 3 = finished *)
    action = First@Ordering[stroke[[3 ;;]], -1];
    lastPos = pos;
    (* negate the y offset before moving the pen *)
    pos += offset*{1, -1};
    Switch[lastAction,
     1, segments[[time]] = {lastPos, pos},
     3, Break[]]
    ];
   segments
   ];
In[7]:=
netevaluate[obj_, temp_ : 0.01, maxLen_ : 300] := 
 Graphics@Line@drawSketch[obj, temp, maxLen]

Basic usage

Generate four sketches of a cat:

In[8]:=
Table[netevaluate["Cat"], 4]
Out[8]=

The optional second argument is a “temperature” parameter that regulates sampling. A higher temperature increases the variability of the output by raising the probability of sampling less likely strokes:

In[9]:=
Table[netevaluate["Cat", 0.4], 4]
Out[9]=

At very high temperature settings, sampling approaches drawing from a flat distribution:

In[10]:=
Table[netevaluate["Elephant", 3], 4]
Out[10]=

Very low temperature settings further increase the probability of sampling the more likely strokes. Sampling at zero temperature is equivalent to always picking the stroke with maximum probability, so the function produces the same sketch every time:

In[11]:=
Table[netevaluate["Chair", 0], 4]
Out[11]=
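
The effect of temperature on a categorical distribution, such as the pen-state probabilities, can be illustrated without the net. The sketch below assumes the common formulation in which each probability is raised to the power 1/T and the result is renormalized; whether the decoder uses exactly this rule is an assumption, and applyTemperature and the example probabilities are hypothetical:

```wolfram
(* Hypothetical helper: rescale a probability vector p with temperature t > 0 *)
applyTemperature[p_List, t_ /; t > 0] := Normalize[N[p]^(1/t), Total]

probs = {0.7, 0.2, 0.1};  (* made-up pen-state probabilities *)

sharp = applyTemperature[probs, 0.1]   (* mass concentrates on the most likely state *)
same = applyTemperature[probs, 1]      (* the distribution is unchanged *)
flat = applyTemperature[probs, 100]    (* close to uniform *)
```

The zero-temperature limit is excluded from this helper because 1/T is undefined at T = 0; it corresponds to always choosing the most likely state.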

Visualize the sequence of pen strokes

Create a function that displays an animation of the object being drawn from the sequence of pen strokes chosen by the network:

In[12]:=
animateSketch[obj_] := DynamicModule[{lines, plotRange},
  lines = DeleteCases[drawSketch[obj], {}];
  plotRange = 
   Map[MinMax, 1.05*{lines[[All, All, 1]], lines[[All, All, 2]]}];
  Animate[
   Graphics[Line@lines[[1 ;; i]], PlotRange -> plotRange],
   {i, 1, Length[lines], 1},
   AnimationRepetitions -> 1
   ]
  ]
In[13]:=
animateSketch["Bird"]
Out[13]=
In[14]:=
animateSketch["Flower"]
Out[14]=

Sketch representation

This model represents sketches as a sequence of pen strokes, each identified as a vector of five elements: {Δx, Δy, p1, p2, p3}. The elements {Δx, Δy} represent the pen movement in the sketch plane, while {p1, p2, p3} is a one-hot vector representing the state of the pen. The state {1, 0, 0} indicates that the pen is touching the paper and a segment will be drawn, {0, 1, 0} means the pen is not touching the paper and {0, 0, 1} indicates that the sketch is finished:

In[15]:=
inputSketch = {
   {0.4, 0.5, 1, 0, 0}, {1.2, 0.7, 0, 1, 0}, {1.2, 0.3, 1, 0, 0}, {2, 
    2, 1, 0, 0}
   };
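
A stroke sequence in this format can be turned into line segments without evaluating the net. The following is a minimal sketch that follows the conventions of drawSketch (start at the origin, negate Δy, draw a segment when the preceding pen state is 1) and assumes the pen starts on the paper; the variable names are illustrative:

```wolfram
strokes = {{0.4, 0.5, 1, 0, 0}, {1.2, 0.7, 0, 1, 0}, {1.2, 0.3, 1, 0, 0}, {2, 2, 1, 0, 0}};

(* Offsets with the y component negated, as in drawSketch *)
offsets = # {1, -1} & /@ strokes[[All, ;; 2]];

(* Absolute pen positions, starting from the origin *)
positions = FoldList[Plus, {0, 0}, offsets];

(* Index of the active pen state in each one-hot vector *)
penStates = First@Ordering[#, -1] & /@ strokes[[All, 3 ;;]];

(* A move is drawn when the state before it is 1; assume the pen starts down *)
drawFlags = Most[Prepend[penStates, 1]];
segments = Pick[Partition[positions, 2, 1], drawFlags, 1];

Graphics[Line[segments]]
```

For these four strokes, the third movement is skipped because the second stroke lifts the pen, leaving three segments.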

The model works like a language model, reading a sequence of pen strokes and predicting the next one:

In[16]:=
NetModel["Sketch-RNN Trained on QuickDraw Data"]@inputSketch
Out[16]=
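
Because the nets are trained with teacher forcing, each training target is the stroke that follows a ground-truth prefix. A minimal sketch of how such prefix/target pairs could be formed from one sketch is given here; it illustrates the idea only, is not the actual training pipeline, and uses hypothetical names and data:

```wolfram
(* Made-up sketch whose last stroke marks the end of the drawing *)
sketch = {{0.4, 0.5, 1, 0, 0}, {1.2, 0.7, 0, 1, 0}, {1.2, 0.3, 1, 0, 0}, {2, 2, 0, 0, 1}};

(* Pair every proper prefix of the sketch with the stroke that follows it *)
pairs = Table[sketch[[;; k]] -> sketch[[k + 1]], {k, Length[sketch] - 1}];

Length[pairs]   (* number of prefix/target pairs *)
Last[pairs]     (* longest prefix and the final, end-of-sketch stroke *)
```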

Decoder properties

Extract the “Function” net decoder:

In[17]:=
dec = NetExtract[NetModel["Sketch-RNN Trained on QuickDraw Data"], 
  "Output"]
Out[17]=

Inspect the associated Function:

In[18]:=
NetExtract[dec, "Function"]
Out[18]=

This decoder supports properties. The default one, “Decision”, returns the pen stroke with highest associated probability:

In[19]:=
inputSketch = {
   {0.4, 0.5, 1, 0, 0}, {1.2, 0.7, 0, 1, 0}, {1.2, 0.3, 1, 0, 0}, {2, 
    2, 1, 0, 0}
   };
In[20]:=
NetModel["Sketch-RNN Trained on QuickDraw Data"][inputSketch, "Decision"]
Out[20]=

It is possible to perform temperature sampling from the output distribution with the property “RandomSample”. In this case, the evaluation will return different results every time:

In[21]:=
NetModel["Sketch-RNN Trained on QuickDraw Data"][inputSketch, 
 {"RandomSample", "Temperature" -> 0.4}]
Out[21]=
In[22]:=
NetModel["Sketch-RNN Trained on QuickDraw Data"][inputSketch, 
 {"RandomSample", "Temperature" -> 0.4}]
Out[22]=

The property “Distribution” returns the full probability distribution for the pen movement and the pen state. The movement distribution is a mixture of two-dimensional normal distributions, while the pen state follows a categorical distribution with three probability values, one for each of the states 1 (drawing), 2 (lifted) and 3 (finished):

In[23]:=
NetModel["Sketch-RNN Trained on QuickDraw Data"][inputSketch, "Distribution"]
Out[23]=
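
A hand-built stand-in can help clarify the structure of such a distribution. All parameters below are made up for illustration (the nets predict the actual values, and the format returned by the “Distribution” property may differ). The sketch samples a movement from a two-component mixture of bivariate normals and a pen state from a categorical distribution, then assembles a five-element stroke:

```wolfram
(* Hypothetical mixture weights and bivariate normal components *)
weights = {0.7, 0.3};
components = {
   MultinormalDistribution[{0.5, 0.2}, {{0.04, 0.01}, {0.01, 0.09}}],
   MultinormalDistribution[{-0.3, 0.8}, {{0.09, 0.}, {0., 0.04}}]};

(* Hypothetical probabilities for pen states 1, 2 and 3 *)
penProbabilities = {0.85, 0.1, 0.05};

SeedRandom[1234];
offset = RandomVariate[MixtureDistribution[weights, components]];
penState = RandomChoice[penProbabilities -> {1, 2, 3}];

(* Movement followed by the one-hot pen state *)
stroke = Join[offset, UnitVector[3, penState]]
```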

Net information

Inspect the number of parameters of all arrays in the net:

In[24]:=
NetInformation[
 NetModel["Sketch-RNN Trained on QuickDraw Data"], "ArraysElementCounts"]
Out[24]=

Obtain the total number of parameters:

In[25]:=
NetInformation[
 NetModel["Sketch-RNN Trained on QuickDraw Data"], "ArraysTotalElementCount"]
Out[25]=

Obtain the layer type counts:

In[26]:=
NetInformation[
 NetModel["Sketch-RNN Trained on QuickDraw Data"], "LayerTypeCounts"]
Out[26]=

Display the summary graphic:

In[27]:=
NetInformation[
 NetModel["Sketch-RNN Trained on QuickDraw Data"], "SummaryGraphic"]
Out[27]=

Export to MXNet

Export the net into a format that can be opened in MXNet:

In[28]:=
jsonPath = 
 Export[FileNameJoin[{$TemporaryDirectory, "net.json"}], 
  NetReplacePart[NetModel["Sketch-RNN Trained on QuickDraw Data"], 
   "Input" -> {50, 5}], "MXNet"]
Out[28]=

Export also creates a net.params file containing the parameters:

In[29]:=
paramPath = FileNameJoin[{DirectoryName[jsonPath], "net.params"}]
Out[29]=

Get the size of the parameter file:

In[30]:=
FileByteCount[paramPath]
Out[30]=

Represent the MXNet net as a graph:

In[31]:=
Import[jsonPath, {"MXNet", "NodeGraphPlot"}]
Out[31]=

Requirements

Wolfram Language 12.0 (April 2019) or above
