Function Repository Resource:

NetContract


Contract a subset of vertices in a NetGraph or NetChain into a single vertex

Contributed by: Maria Sargsyan

ResourceFunction["NetContract"][net, {start, end}]

contracts all the vertices from the start to the end in a NetGraph or NetChain into a single vertex.

ResourceFunction["NetContract"][net, {start, end} -> name]

contracts all the vertices from the start to the end in a NetGraph or NetChain into a single named vertex.

ResourceFunction["NetContract"][net, string]

contracts all the vertices whose names start with string in a NetGraph or NetChain into a single vertex.

ResourceFunction["NetContract"][net, string -> name]

contracts all the vertices whose names start with string in a NetGraph or NetChain into a single named vertex.

ResourceFunction["NetContract"][net, pattern]

contracts all the vertices matched by a pattern in a NetGraph or NetChain into a single vertex.

ResourceFunction["NetContract"][net, pattern -> name]

contracts all the vertices matched by a pattern in a NetGraph or NetChain into a single named vertex.

ResourceFunction["NetContract"][net, rule1, rule2, …]

performs all the contractions specified by rule1, rule2, …; each can be any of the specifications above, with or without a name.

Details

Patterns can only be string patterns, such as RegularExpression[regex] or s1 ~~ s2 ~~ …, or pure functions that operate on strings.
In ResourceFunction["NetContract"][net, {start, end}, …], start and end can be any specifications accepted by NetTake[net, {start, end}].
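The pattern forms above can be illustrated with standard string functions applied to vertex names. This is a sketch under assumptions: it assumes NetContract treats a plain string as a name prefix, matches a string pattern against the whole vertex name (as StringMatchQ does), and applies a pure function to each name. The list names is a hypothetical stand-in for the vertex names of dummyNet from the Basic Examples.

```wolfram
(* Hypothetical list: the vertex names of dummyNet in the Basic Examples *)
names = {"1/a", "1/b", "1/c", "d", "e", "f", "k", "l"};

(* A plain string is assumed to select the names that start with it *)
Select[names, StringStartsQ["1/"]]

(* A string pattern is assumed to match the whole name, as in StringMatchQ *)
Select[names, StringMatchQ[RegularExpression["1/.*"]]]
Select[names, StringMatchQ["1/" ~~ _]]

(* A pure function is applied to each name; names for which it returns True are selected *)
Select[names, StringStartsQ[#, "1/"] &]
```

Each of the four selections returns {"1/a", "1/b", "1/c"}. Note that "1/" ~~ _ matches exactly one character after "1/", so a vertex named "1/ab" would not be selected by that pattern, while RegularExpression["1/.*"] would select it.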

Examples

Basic Examples (9) 

Define a net:

In[1]:=
dummyNet = NetGraph[{"1/a" -> Ramp, "1/b" -> Tanh, "1/c" -> BatchNormalizationLayer[], "d" -> ThreadingLayer[Plus], "e" -> Ramp, "f" -> Tanh, "k" -> BatchNormalizationLayer[], "l" -> ThreadingLayer[Plus]}, {NetPort["Input"] -> "1/a" -> "1/b" -> "1/c", {"1/b", "1/c"} -> "d" -> "e" -> "f", "e" -> "k", {"k", "f"} -> "l"}]
Out[1]=

Obtain a new net by contracting everything from the starting to the ending node:

In[2]:=
ResourceFunction["NetContract"][dummyNet, {"1/a", "1/c"}]
Out[2]=
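Because {start, end} follows the conventions of NetTake[net, {start, end}], NetTake can presumably be used to preview which vertices a range covers before contracting it. This sketch reuses dummyNet from In[1] and assumes, without confirmation from the source, that the contracted vertex holds the same layers that NetTake extracts.

```wolfram
(* Extract the subnet spanning "1/a" to "1/c"; assumed to contain the vertices NetContract groups *)
sub = NetTake[dummyNet, {"1/a", "1/c"}]
```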

Give a new name to the contracted group of vertices:

In[3]:=
ResourceFunction["NetContract"][dummyNet, {"1/a", "1/c"} -> "a"]
Out[3]=

Contract only the vertices strictly between "1/a" and "1/c" by specifying their ports:

In[4]:=
ResourceFunction["NetContract"][dummyNet, {NetPort[{"1/a", "Output"}], NetPort[{"1/c", "Input"}]}]
Out[4]=

Contract all the vertices whose names start with "1/":

In[5]:=
ResourceFunction["NetContract"][dummyNet, "1/" -> "a"]
Out[5]=

Contract all the vertices which match a RegularExpression:

In[6]:=
ResourceFunction["NetContract"][dummyNet, RegularExpression["1/.*"] -> "a"]
Out[6]=

Contract all the vertices which match a StringExpression:

In[7]:=
ResourceFunction["NetContract"][dummyNet, "1/" ~~ _ -> "a"]
Out[7]=

Contract all the vertices which match a pure function:

In[8]:=
ResourceFunction["NetContract"][dummyNet, StringStartsQ[#, "1/"] & -> "1"]
Out[8]=

Contract several sets of vertices:

In[9]:=
ResourceFunction["NetContract"][dummyNet, {"d", "e"} -> "d", "1/"]
Out[9]=

Scope (3) 

Simplify a net by merging the LogisticSigmoid and ThreadingLayer vertices, which together compute a swish activation, into a single layer. First define the net:

In[10]:=
net = NetInitialize@NetGraph[
   {
    "conv" -> ConvolutionLayer[64, {3, 3}],
    "bn" -> BatchNormalizationLayer[],
    "sigmoid" -> LogisticSigmoid,
    "swish" -> ThreadingLayer[Times]
    },
   {"conv" -> "bn" -> "sigmoid", {"bn", "sigmoid"} -> "swish"},
   "Input" -> {3, 100, 100}
   ]
Out[10]=

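Contract everything after the "bn" output through "swish" into a single vertex named "swish", then replace that vertex with one ThreadingLayer that computes x*LogisticSigmoid[x]: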

In[11]:=
newNet = NetReplacePart[
  ResourceFunction["NetContract"][
   net, {NetPort["bn", "Output"], "swish"} -> "swish"],
  "swish" -> ThreadingLayer[#*LogisticSigmoid[#] &]
  ]
Out[11]=

The new net evaluates slightly faster:

In[12]:=
input = RandomReal[1, {3, 100, 100}];
In[13]:=
RepeatedTiming[net@input]
Out[13]=
In[14]:=
RepeatedTiming[newNet@input]
Out[14]=
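Contracting and replacing the vertices should not change what the net computes, since x*LogisticSigmoid[x] is the same product that the "sigmoid" and "swish" vertices formed from the "bn" output. A quick sanity check, reusing net, newNet and input from the cells above, compares the two outputs; the difference is expected to be zero or at the level of single-precision rounding, though this is an expectation rather than a documented guarantee.

```wolfram
(* Both nets should produce arrays of the same shape *)
Dimensions[net[input]] === Dimensions[newNet[input]]

(* Largest elementwise difference between the original and the fused net *)
Max[Abs[net[input] - newNet[input]]]
```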

Version History

  • 1.0.0 – 01 November 2022
