CoCa Image Captioning Nets Trained on LAION-2B Data

Represent words and images as vectors

Released in 2022, the CoCa (Contrastive Captioner) family of transformer-based neural nets is a collection of image-text encoder-decoder models trained jointly to minimize a contrastive loss and an image captioning loss. In contrast to the standard CLIP (Contrastive Language–Image Pre-training) models, CoCa appends a class token to the text encoder input and uses attention pooling in the image encoder to produce a sequence of hidden representations of the image. The output of the attention pooler is then fed to the text decoder, which generates the caption one token at a time.
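For reference, the joint objective described in the CoCa paper is a weighted sum of the two losses. In the notation below (ours, not taken from this page), $x_i$ and $y_i$ are the normalized image and text embeddings of the $i$-th of $N$ image-text pairs in a batch, $\sigma$ is a temperature, $\lambda_{\text{Con}}$ and $\lambda_{\text{Cap}}$ are weighting hyperparameters, and $w_1, \dots, w_T$ are the caption tokens of an image $x$:

$$
\mathcal{L}_{\text{CoCa}} = \lambda_{\text{Con}}\,\mathcal{L}_{\text{Con}} + \lambda_{\text{Cap}}\,\mathcal{L}_{\text{Cap}}
$$

$$
\mathcal{L}_{\text{Con}} = -\frac{1}{N}\left(\sum_{i=1}^{N}\log\frac{\exp\left(x_i^{\top}y_i/\sigma\right)}{\sum_{j=1}^{N}\exp\left(x_i^{\top}y_j/\sigma\right)} + \sum_{i=1}^{N}\log\frac{\exp\left(y_i^{\top}x_i/\sigma\right)}{\sum_{j=1}^{N}\exp\left(y_i^{\top}x_j/\sigma\right)}\right)
$$

$$
\mathcal{L}_{\text{Cap}} = -\sum_{t=1}^{T}\log P_{\theta}\left(w_t \mid w_{<t}, x\right)
$$

The contrastive term aligns the class embeddings used in the retrieval and zero-shot examples below, while the captioning term trains the text decoder used by the caption generation functions.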

Training Set Information

Model Information

Examples

Resource retrieval

Get the pre-trained net:

In[1]:=
NetModel["CoCa Image Captioning Nets Trained on LAION-2B Data"]
Out[1]=

NetModel parameters

This model consists of a family of individual nets, each identified by a specific parameter combination. Inspect the available parameters:

In[2]:=
NetModel["CoCa Image Captioning Nets Trained on LAION-2B Data", "ParametersInformation"]
Out[2]=

Pick a non-default net by specifying the parameters:

In[3]:=
NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Encoder"}]
Out[3]=

Pick a non-default uninitialized net:

In[4]:=
NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Decoder"}, "UninitializedEvaluationNet"]
Out[4]=

Evaluation function

Define an evaluation function that uses all model parts to obtain the image features and automate the caption generation:

In[5]:=
decodeStep[textDecoder_, textEncoderBare_, imageEmbeddings_, temperature_, tokenDict_, targetDevice_, topK_][textTokens_] := Append[textTokens,
  (* sample the next token, conditioned on the image and on the caption generated so far *)
  textDecoder[
   <|
    "ImageEmbeddings" -> imageEmbeddings,
    (* the whole caption prefix is re-encoded at every step *)
    "TextEmbeddings" -> textEncoderBare[Lookup[tokenDict, textTokens], NetPort["TextEmbeddings"]]
    |>,
   "RandomSample" -> {"Temperature" -> temperature, "TopProbabilities" -> topK},
   TargetDevice -> targetDevice
   ]
  ]
Options[netevaluate] = {
   "Architecture" -> "ViT-B/32",
   "Finetuned" -> False,
   "Temperature" -> 0,
   "TopProbabilities" -> 100,
   "NumberOfFrames" -> 16,
   MaxIterations -> 25,
   TargetDevice -> "CPU"
   };
netevaluate[input : (_?ImageQ | _?VideoQ), opts : OptionsPattern[]] := Block[
  {imageEmbeddings, textTokens, tokenDict, textEncoderBare, decode, imageEncoder, textEncoder, textDecoder, images},
  (* load the three model parts for the requested architecture *)
  {imageEncoder, textEncoder, textDecoder} = NetModel[{
       "CoCa Image Captioning Nets Trained on LAION-2B Data",
       "Architecture" -> OptionValue["Architecture"], "Finetuned" -> OptionValue["Finetuned"], "Part" -> #
       }] & /@ {"Image-Encoder", "Text-Encoder", "Text-Decoder"};
  (* a video is represented by a fixed number of uniformly sampled frames *)
  images = Switch[input,
    _?VideoQ, VideoFrameList[input, OptionValue["NumberOfFrames"]],
    _?ImageQ, input
    ];
  imageEmbeddings = imageEncoder[images, NetPort["ImageEmbeddings"], TargetDevice -> OptionValue[TargetDevice]];
  (* average the frame embeddings into a single video embedding *)
  If[MatchQ[input, _?VideoQ], imageEmbeddings = Mean[imageEmbeddings]];
  (* map tokens to integer codes, as expected by the text encoder once its NetEncoder is removed *)
  tokenDict = AssociationThread[# -> Range@Length@#] &@NetExtract[textEncoder, {"Input", "Tokens"}];
  textEncoderBare = NetReplacePart[textEncoder, "Input" -> None];
  decode = decodeStep[textDecoder, textEncoderBare, imageEmbeddings,
    OptionValue["Temperature"], tokenDict, OptionValue[TargetDevice],
    OptionValue["TopProbabilities"]
    ];
  (* generate tokens until EndOfString is produced or MaxIterations is reached *)
  textTokens = NestWhile[
    decode, {StartOfString},
    If[Length[#] > 0, Last[#] =!= EndOfString, True] &,
    1, OptionValue[MaxIterations]
    ];
  StringJoin@ReplaceAll[textTokens, StartOfString | EndOfString -> ""]
  ]
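The "Temperature" and "TopProbabilities" options control how each next token is drawn. As a rough illustration, here is a hypothetical helper sampleTopK (not part of the model, and not necessarily identical to the built-in "RandomSample" property) that works on raw logits: a temperature of 0 is treated as greedy decoding, otherwise only the k largest logits are kept and sampled with probabilities proportional to Exp[logit/temperature]:

```wolfram
(* Hypothetical helper, not part of the model: pick the position of the next token from raw logits *)
sampleTopK[logits_List, temperature_, k_Integer] := Module[{top, weights},
  If[temperature == 0,
   (* greedy decoding: position of the largest logit *)
   First@Ordering[logits, -1],
   (* positions of the k largest logits *)
   top = Ordering[logits, -Min[k, Length[logits]]];
   (* shift by the maximum for numerical stability, then rescale by the temperature *)
   weights = Exp[(logits[[top]] - Max[logits])/temperature];
   RandomChoice[weights -> top]
   ]
  ]

sampleTopK[{0.1, 2.0, 0.5}, 0, 3]    (* greedy: returns 2 *)
sampleTopK[{0.1, 2.0, 0.5}, 0.7, 2]  (* returns 2 or 3 *)
```

Lower temperatures concentrate the weights on the most likely token, which is why the default of 0 gives deterministic captions.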

Basic usage

Define a test image:

In[6]:=
(* Evaluate this cell to get the example input *) CloudGet["https://www.wolframcloud.com/obj/2c00a3de-a8fe-428e-8ffd-941c0e17c4b4"]

Generate an image caption:

In[7]:=
netevaluate[img]
Out[7]=

Define a test video and visualize five frames sampled uniformly from it:

In[8]:=
video = ResourceData["Sample Video: City Street Activity"];
VideoFrameList[video, 5]
Out[9]=

Generate a caption for the video. Following the approach used for zero-shot video-text retrieval, the image embeddings of uniformly sampled frames are averaged into a single embedding:

In[10]:=
netevaluate[video, "Temperature" -> 0.05]
Out[10]=
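The averaging step in netevaluate is a plain Mean over the list of frame embeddings. The sketch below uses random numbers and an assumed small shape (16 frames, each represented by 4 vectors of size 8) in place of the real image encoder output, to show that Mean reduces only the frame axis:

```wolfram
(* Random stand-in for per-frame image embeddings, with the assumed layout
   {frames, embedding vectors per frame, embedding size} = {16, 4, 8} *)
frameEmbeddings = RandomReal[{-1, 1}, {16, 4, 8}];
(* Mean averages over the first level (the frames), keeping the per-frame layout *)
videoEmbedding = Mean[frameEmbeddings];
Dimensions[videoEmbedding]  (* {4, 8} *)
```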

Efficient generation

The evaluation function defined in the previous section is inefficient: each time a new token of the caption is generated, the entire caption is reprocessed from the start. NetUnfold avoids this by exposing the hidden attention states as ports, so that each step only needs to process the latest token. Reimplement the evaluation function using nets prepared by NetUnfold:

In[11]:=
Options[prepareNets] = {
   "Architecture" -> "ViT-B/32",
   "Finetuned" -> False
   };
prepareNets[opts : OptionsPattern[]] := Block[
   {imageEncoder, textEncoder, textDecoder},
   {imageEncoder, textEncoder, textDecoder} = NetModel[{
        "CoCa Image Captioning Nets Trained on LAION-2B Data",
        "Architecture" -> OptionValue["Architecture"], "Finetuned" -> OptionValue["Finetuned"], "Part" -> #
        }] & /@ {"Image-Encoder", "Text-Encoder", "Text-Decoder"};
   (* chain the text encoder into the text decoder and unfold the result,
      so that the attention states are exposed as input and output ports *)
   textDecoder = NetUnfold@NetFlatten@NetGraph[{
        "text_decoder" -> textDecoder,
        (* drop the class token and the SequenceMostLayer *)
        "text_encoder" -> NetChain[{
           NetExtract[textEncoder, {"input_embeddings", "token_embeddings"}],
           NetTake[NetExtract[textEncoder, "input_embeddings"], {NetPort[{"append", "Output"}], "add"}],
           NetExtract[textEncoder, "transformer"]
           }]
        },
       {NetPort["Input"] -> "text_encoder", "text_encoder" -> NetPort[{"text_decoder", "TextEmbeddings"}]}
       ];
   {imageEncoder, textDecoder}
   ];
In[12]:=
decodeStepEfficient[textDecoder_, imageEmbeddings_, tokenDict_, temperature_, targetDevice_, topK_][{generatedTokens_, prevState_, index_}] := Block[{decoded},
  (* feed only the last generated token, together with the cached attention states *)
  decoded = textDecoder[
    Join[
     <|
      "ImageEmbeddings" -> imageEmbeddings,
      "Index" -> index,
      "Input" -> Lookup[tokenDict, Last[generatedTokens]],
      "Output" -> {"RandomSample", "Temperature" -> temperature, "TopProbabilities" -> topK}
      |>,
     prevState
     ],
    TargetDevice -> targetDevice
    ];
  {
   Append[generatedTokens, decoded["Output"]],
   (* the "OutState..." outputs become the "State..." inputs of the next step *)
   KeyMap[
    StringReplace["OutState" -> "State"],
    KeySelect[decoded, StringStartsQ["OutState"]]
    ],
   index + 1
   }
  ]
Options[netevaluateEfficient] = {
   "Temperature" -> 0,
   "TopProbabilities" -> 100,
   "NumberOfFrames" -> 16,
   MaxIterations -> 25,
   TargetDevice -> "CPU"
   };
netevaluateEfficient[
  input : (_?ImageQ | _?VideoQ), {imageEncoder_, textDecoder_}, opts : OptionsPattern[]] := Block[
  {imageEmbeddings, textTokens, decode, images, initStates, tokenDict},
  images = Switch[input,
    _?VideoQ, VideoFrameList[input, OptionValue["NumberOfFrames"]],
    _?ImageQ, input
    ];
  imageEmbeddings = imageEncoder[images, NetPort["ImageEmbeddings"], TargetDevice -> OptionValue[TargetDevice]];
  If[MatchQ[input, _?VideoQ], imageEmbeddings = Mean[imageEmbeddings]];
  (* all attention states start out as empty sequences *)
  initStates = AssociationMap[
    Function[x, {}],
    Select[Information[textDecoder, "InputPortNames"], StringStartsQ["State"]]
    ];
  tokenDict = AssociationThread[# -> Range@Length@#] &@NetExtract[textDecoder, {"Output", "Labels"}];
  decode = decodeStepEfficient[textDecoder, imageEmbeddings, tokenDict, OptionValue["Temperature"], OptionValue[TargetDevice], OptionValue["TopProbabilities"]];
  (* iterate on {tokens, states, index} until EndOfString or MaxIterations *)
  textTokens = First@NestWhile[
     decode, {{StartOfString}, initStates, 1},
     Last[First@#] =!= EndOfString &,
     1, OptionValue[MaxIterations]
     ];
  StringJoin@ReplaceAll[textTokens, StartOfString | EndOfString -> ""]
  ]
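The loop in netevaluateEfficient threads a triple {tokens, states, index} through NestWhile. The sketch below isolates that control flow with a hypothetical toyStep that reads the next word from a fixed caption and increments a counter in place of the attention states; no net is called:

```wolfram
(* Hypothetical stand-in for decodeStepEfficient: a fixed caption replaces the net,
   and an integer counter replaces the attention states *)
toyCaption = {"a", " dog", " on", " the", " grass", EndOfString};
toyStep[{tokens_, state_, index_}] := {Append[tokens, toyCaption[[index]]], state + 1, index + 1};
(* same loop structure as netevaluateEfficient: stop at EndOfString or after 25 steps *)
toyResult = NestWhile[toyStep, {{StartOfString}, 0, 1}, Last[First[#]] =!= EndOfString &, 1, 25];
StringJoin@ReplaceAll[First[toyResult], StartOfString | EndOfString -> ""]  (* "a dog on the grass" *)
```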

Prepare the unfolded net beforehand for the efficient image caption generation:

In[13]:=
{imageEncoder, textDecoder} = prepareNets["Architecture" -> "ViT-B/32", "Finetuned" -> False];

The new textDecoder combines the original text encoder and text decoder into a single unfolded net, which exposes ports for the hidden attention states:

In[14]:=
Information[textDecoder, "InputPortNames"]
Out[14]=
In[15]:=
Information[textDecoder, "OutputPortNames"]
Out[15]=

Compare the evaluation timings of the two methods:

In[16]:=
(* Evaluate this cell to get the example input *) CloudGet["https://www.wolframcloud.com/obj/28d63446-d2cc-4863-87a1-40cb3856b012"]
In[17]:=
AbsoluteTiming@netevaluateEfficient[img, {imageEncoder, textDecoder}]
Out[17]=
In[18]:=
AbsoluteTiming@netevaluate[img]
Out[18]=

The speedup of the efficient method grows with the length of the generated sequence. Since image captioning models generally produce short texts, the gain here is modest.
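A back-of-the-envelope count makes the dependence on caption length concrete. Under a simplified cost model (an assumption: only token positions pushed through the text layers are counted, ignoring the image encoder, attention over cached states and per-call overhead), the naive loop processes 1 + 2 + ... + n positions for an n-token caption, while the unfolded net processes one position per step:

```wolfram
(* Simplified cost model (assumption): count token positions processed by the text layers *)
naiveTokenCount[n_Integer] := Sum[t, {t, 1, n}];  (* step t re-processes all t tokens *)
cachedTokenCount[n_Integer] := n;                 (* one new token per step *)
Table[{n, naiveTokenCount[n]/cachedTokenCount[n]}, {n, {5, 10, 25}}]
(* {{5, 3}, {10, 11/2}, {25, 13}} *)
```

In this model the ratio is (n + 1)/2, so it only becomes large for long outputs; for captions of about ten tokens it stays small, and fixed costs such as the image encoder narrow the measured gap further.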

Feature space visualization

Get a set of images:

In[19]:=
(* Evaluate this cell to get the example input *) CloudGet["https://www.wolframcloud.com/obj/5821605f-9b91-4970-a591-72582cb25fc4"]
In[20]:=
FeatureSpacePlot[
 Thread[NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Image-Encoder"}][imgs, NetPort[{"embed", "Output"}]] -> imgs],
 LabelingSize -> 70,
 LabelingFunction -> Callout,
 ImageSize -> 700,
 Method -> "TSNE",
 AspectRatio -> 0.9
 ]
Out[20]=

Define a list of sentences in two categories:

In[21]:=
sentences = {
   "The Empire State Building's observation deck in New York is a must-visit for its iconic skyline views.",
   "The Charging Bull in the financial district of New York has become a symbol of market optimism.",
   "Times Square in New York is best known for its bright billboards and bustling atmosphere.",
   "The Statue of Liberty in New York stands as a universal symbol of freedom and opportunity.",
   "Central Park in New York is an urban oasis, providing a natural escape amidst the city's skyscrapers.",
   "Sacré-Cœur in Paris offers both spiritual solace and panoramic views from its hilltop location.",
   "The Eiffel Tower's light show in Paris adds a romantic touch to the city's engineering marvel.",
   "Bridges over the Seine in Paris are scenic spots that often host art and book vendors.",
   "The Louvre's glass pyramid in Paris modernizes the entrance to a museum filled with historical art.",
   "The Panthéon in Paris serves as a tribute to national heroes, complete with educational exhibits."
   };
In[22]:=
FeatureSpacePlot[
 Thread[NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Encoder"}][sentences, NetPort[{"normalize", "Output"}]] -> (Tooltip[Style[Text@#, Medium]] & /@ sentences)],
 LabelingSize -> {90, 60}, RandomSeeding -> 23,
 LabelingFunction -> Callout,
 ImageSize -> 700,
 AspectRatio -> 0.9
 ]
Out[22]=

Connecting text and images

Define a test image:

In[23]:=
(* Evaluate this cell to get the example input *) CloudGet["https://www.wolframcloud.com/obj/1d8bfadb-2795-4025-8fc4-74a41137d0e6"]

Define a list of text descriptions:

In[24]:=
descriptions = {
   "Blossoming rose on textbook among delicate petals",
   "Photo of foggy forest",
   "A portrait of a man with red hair taking a picture with an analog camera",
   "Yellow flower in tilt shift lens",
   "Woman in black leather jacket and white pants",
   "A portrait of a woman with red hair taking a picture with an analog camera",
   "Calm body of lake between mountains",
   "Close up shot of a baby girl",
   "Smiling man surfing on wave in ocean",
   "A portrait of a woman with red hair taking a picture with a phone",
   "A portrait of a woman with red hair taking a picture with a digital camera",
   "A woman with eyeglasses smiling",
   "Elderly woman carrying a baby",
   "A portrait of a woman with blue hair taking a picture with an analog camera"
   };

Embed the test image and text descriptions into the same feature space:

In[25]:=
textFeatures = NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Encoder"}][descriptions, NetPort["ClassEmbedding"]];
imgFeatures = NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Image-Encoder"}][img, NetPort["ClassEmbedding"]];

Rank the text descriptions by their CosineDistance to the input image. Smaller distances indicate a closer correspondence between the text and the image:

In[26]:=
Dataset@SortBy[
  Thread[{descriptions, First@DistanceMatrix[{imgFeatures}, textFeatures, DistanceFunction -> CosineDistance]}], Last]
Out[26]=
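CosineDistance[u, v] equals 1 - u.v/(Norm[u] Norm[v]), so it only depends on the angle between the embeddings: 0 for parallel vectors, 1 for orthogonal ones and 2 for opposite ones. A minimal sketch of the same ranking with hypothetical 3-dimensional vectors standing in for the real embeddings:

```wolfram
(* Toy 3-dimensional stand-ins for an image embedding and three text embeddings (hypothetical values) *)
toyImage = {1., 0., 1.};
toyTexts = <|"close match" -> {0.9, 0.1, 1.}, "unrelated" -> {0., 1., 0.}, "opposite" -> {-1., 0., -1.}|>;
(* compute the distance of each text to the image and sort from closest to farthest *)
toyRanking = SortBy[Normal@Map[CosineDistance[toyImage, #] &, toyTexts], Last]
(* keys in order: "close match", "unrelated", "opposite" *)
```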

Zero-shot image classification

By using the text and image feature extractors together, it's possible to perform generic image classification between any set of classes without having to explicitly train any model for those particular classes (zero-shot classification). Obtain the FashionMNIST test data, which contains ten thousand test images and 10 classes:

In[27]:=
testData = ResourceData["FashionMNIST", "TestData"];

Display a few random examples from the set:

In[28]:=
RandomChoice[testData, 5]
Out[28]=

Get a mapping between class IDs and labels:

In[29]:=
idToLabel = ResourceData["FashionMNIST", "ClassLabels"]
Out[29]=

Generate the text templates for the FashionMNIST labels and embed them. The text templates will effectively act as classification labels:

In[30]:=
labelTemplates = "This is a photo of a " <> # & /@ ToLowerCase[Values@idToLabel]
Out[30]=
In[31]:=
textEmbeddings = NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Encoder"}][labelTemplates, NetPort["ClassEmbedding"]];
In[32]:=
Dimensions[textEmbeddings]
Out[32]=

Classify an image from the test set. Obtain its embedding:

In[33]:=
img = testData[[5634, 1]]
Out[33]=
In[34]:=
imgFeatures = NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Image-Encoder"}][img, NetPort["ClassEmbedding"]];
In[35]:=
Dimensions[imgFeatures]
Out[35]=

The result of the classification is the description of the embedding that is closest to the image embedding:

In[36]:=
Nearest[textEmbeddings -> labelTemplates, imgFeatures, DistanceFunction -> CosineDistance]
Out[36]=

Find the top 10 descriptions nearest to the image embedding:

In[37]:=
SortBy[Rule @@@ MapAt[labelTemplates[[#]] &, Nearest[textEmbeddings -> {"Index", "Distance"}, imgFeatures, 10, DistanceFunction -> CosineDistance], {All, 1}], Last]
Out[37]=

Obtain the accuracy of this procedure on the entire test set. Extract the features for all the images (if a GPU is available, setting TargetDevice -> "GPU" is recommended as the computation will take several minutes on CPU):

In[38]:=
imageEmbeddings = NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Image-Encoder"}][testData[[All, 1]], NetPort["ClassEmbedding"], TargetDevice -> "CPU"];
In[39]:=
Dimensions[imageEmbeddings]
Out[39]=

Calculate the distance matrix between the computed text and image embeddings:

In[40]:=
distanceMatrix = DistanceMatrix[imageEmbeddings, textEmbeddings, DistanceFunction -> CosineDistance];

Obtain the top-1 prediction:

In[41]:=
predictedClassIDs = Flatten[Ordering[#, 1] & /@ distanceMatrix] - 1;

Obtain the final classification results:

In[42]:=
ClassifierMeasurements[idToLabel /@ predictedClassIDs, idToLabel /@ testData[[All, 2]]]
Out[42]=
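The top-1 step picks, for each row of the distance matrix, the column with the smallest distance and shifts it to a 0-based class ID. A minimal sketch with a hypothetical 3 x 2 distance matrix and made-up ground-truth IDs, computing the accuracy directly rather than through ClassifierMeasurements:

```wolfram
(* Hypothetical distance matrix: rows are images, columns are label templates *)
toyDistances = {{0.2, 0.7}, {0.9, 0.1}, {0.4, 0.3}};
toyTrueIDs = {0, 1, 0};  (* hypothetical ground-truth IDs, 0-based as in FashionMNIST *)
(* Ordering[row, 1] gives the position of the smallest distance; subtract 1 for 0-based IDs *)
toyPredIDs = Flatten[Ordering[#, 1] & /@ toyDistances] - 1;
(* fraction of images whose predicted ID matches the true ID *)
toyAccuracy = Mean[Boole[MapThread[Equal, {toyPredIDs, toyTrueIDs}]]]  (* 2/3 *)
```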

Attention visualization for images

Just like the original Vision Transformer (see the model "Vision Transformer Trained on ImageNet Competition Data"), the image feature extractor divides the input image into a 7x7 grid of patches and performs self-attention on a set of 50 vectors: 49 vectors, or "tokens," representing the 7x7 patches and an additional one, a "feature extraction token," that is eventually used to produce the final feature representation of the image. Thus, the attention procedure for this model can be visualized by inspecting the attention weights between the feature extraction token and the patch tokens. Define a test image:

In[43]:=
(* Evaluate this cell to get the example input *) CloudGet["https://www.wolframcloud.com/obj/2f8980cc-77f4-41c8-8bc5-f54ba4b64817"]
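The token count quoted above follows from the patch size. A quick check, assuming a 224 x 224 input and 32 x 32 patches (the "/32" in "ViT-B/32"); neither value is stated on this page:

```wolfram
(* Assumed values: a 224 x 224 input and 32 x 32 patches *)
inputSide = 224; patchSide = 32;
gridSide = inputSide/patchSide  (* 7 patches per side *)
numTokens = gridSide^2 + 1      (* 49 patch tokens + 1 feature extraction token = 50 *)
```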

Extract the attention weights used for the last block of self-attention:

In[44]:=
attentionMatrix = Transpose@
   NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Image-Encoder"}][testImage, NetPort[{"transformer", -1, "self-attention", "attention", "AttentionWeights"}]];
In[45]:=
Dimensions[attentionMatrix]
Out[45]=

Extract the attention weights between the feature extraction token and the input patches. These weights can be interpreted as which patches in the original image the net is "looking at" in order to perform the feature extraction:

In[46]:=
featureAttention = attentionMatrix[[All, 1, 2 ;;]];
{numHeads, numPatches} = Dimensions[featureAttention]
Out[47]=

Reshape the weights as a 3D array of 12 7x7 matrices. Each matrix corresponds to an attention head, while each element of the matrices corresponds to a patch in the original image:

In[48]:=
featureAttention = ArrayReshape[
   featureAttention, {numHeads, Sqrt[numPatches], Sqrt[numPatches]}];
In[49]:=
Dimensions[featureAttention]
Out[49]=
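The slicing and reshaping above can be checked on a random array. This sketch assumes the transposed attention weights have the layout {heads, queries, keys} = {12, 50, 50}, consistent with the dimensions reported above; the values are random, not real attention weights:

```wolfram
(* Random array with the assumed layout {heads, queries, keys} = {12, 50, 50} *)
toyAttention = RandomReal[1, {12, 50, 50}];
(* row 1 is the feature extraction token; columns 2 ;; drop its attention to itself *)
toyFeature = toyAttention[[All, 1, 2 ;;]];
(* ArrayReshape fills each 7x7 grid row by row, matching the patch order *)
toyGrid = ArrayReshape[toyFeature, {12, 7, 7}];
Dimensions[toyGrid]  (* {12, 7, 7} *)
```

For example, toyGrid[[3, 2, 1]] (head 3, grid row 2, column 1) is the 8th patch, which is column 9 of the original attention row.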

Visualize the attention weight matrices. Patches with higher values (red) are what is mostly being "looked at" for each attention head:

In[50]:=
GraphicsRow[MatrixPlot /@ featureAttention, ImageSize -> Full]
Out[50]=

Define a function to visualize the attention matrix on an image:

In[51]:=
visualizeAttention[img_Image, attentionMatrix_] := Block[{heatmap, wh},
  wh = ImageDimensions[img];
  heatmap = ImageApply[{#, 1 - #, 1 - #} &, ImageAdjust@Image[attentionMatrix]];
  heatmap = ImageResize[heatmap, wh*256/Min[wh]];
  ImageCrop[ImageCompose[img, {ColorConvert[heatmap, "RGB"], 0.4}], ImageDimensions[heatmap]]
  ]

Visualize the mean attention across all the attention heads:

In[52]:=
visualizeAttention[testImage, Mean[featureAttention]]
Out[52]=

Visualize each attention head separately:

In[53]:=
visualizeAttention[testImage, #] & /@ featureAttention
Out[53]=

Attention visualization for text

The text feature extractor tokenizes the input string, prepending and appending the special tokens StartOfString and EndOfString, appends a special "class" token and then performs causal self-attention on the token embedding vectors. After the self-attention stack, the last vector (corresponding to the "class" token) is used to obtain the final feature representation of the text. Thus, the attention procedure for this model can be visualized by inspecting the attention weights between the last vector and the previous ones. Define a test string:

In[54]:=
text = "A portrait of a woman with red hair taking a picture with an analog camera";

Extract the NetEncoder of the net to encode the string:

In[55]:=
netEnc = NetExtract[
  NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Encoder"}], "Input"]
Out[55]=
In[56]:=
codes = netEnc[text]
Out[56]=

Extract the list of available tokens and inspect how the input string was tokenized. Even though the BPE tokenization generally segments the input into subwords, it's common to observe that all tokens correspond to full words. Also observe that the StartOfString and EndOfString tokens are added automatically:

In[57]:=
allTokens = NetExtract[netEnc, "Tokens"];
In[58]:=
tokens = allTokens[[codes]]
Out[58]=
In[59]:=
Length[tokens]
Out[59]=

Feed the string to the net and extract the attention weights used for the last block of self-attention. Note that the original implementation uses a slightly different attention mask:

In[60]:=
attentionMatrix = Transpose@
   NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Encoder"}][text, NetPort[{"transformer", -1, "self-attention", "attention", "AttentionWeights"}]];
Dimensions[attentionMatrix]
Out[61]=

Extract the attention weights between the last vector and the previous ones, leaving out the vectors corresponding to StartOfString, EndOfString and the "class" tokens. These weights can be interpreted as which tokens in the original sentence the net is "looking at" in order to perform the feature extraction:

In[62]:=
featureAttention = attentionMatrix[[All, -1, 2 ;; -3]];
Dimensions[featureAttention]
Out[63]=
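The indices in attentionMatrix[[All, -1, 2 ;; -3]] follow from the token layout described above. With a hypothetical three-word input, the positions work out as follows (the strings are placeholders for illustration, not actual tokens of the net):

```wolfram
(* Hypothetical token layout: StartOfString, the words, EndOfString, then the class token *)
layout = {"<StartOfString>", "a", "red", "camera", "<EndOfString>", "<class>"};
layout[[-1]]       (* query row used above: "<class>" *)
layout[[2 ;; -3]]  (* key columns kept: {"a", "red", "camera"} *)
```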

Inspect the average attention weights for each token across the attention heads. Observe that the tokens the net is mostly focused on are "hair," "camera" and "picture":

In[64]:=
BarChart[
 Reverse@MapThread[
   Labeled, {Mean[featureAttention], tokens[[2 ;; -2]]}], BarOrigin -> Left]
Out[64]=

Visualize each head separately:

In[65]:=
BarChart[
 Reverse@MapThread[
   Labeled, {Transpose[featureAttention], tokens[[2 ;; -2]]}], BarOrigin -> Left]
Out[65]=

Extract the attention weights for all 12 attention layers:

In[66]:=
spec = Table[
   NetPort[{"transformer", i, "self-attention", "attention", "AttentionWeights"}], {i, 1, 12}];
In[67]:=
allAttentionWeights = Transpose[
   Values@NetModel[{"CoCa Image Captioning Nets Trained on LAION-2B Data", "Part" -> "Text-Encoder"}][text, spec], 2 <-> 3];
Dimensions[allAttentionWeights]
Out[68]=

Compute the average across all heads, leaving the StartOfString and EndOfString tokens out:

In[69]:=
pos = Append[Range[2, 16], 18];
avgAttentionWeights = ArrayReduce[Mean, allAttentionWeights, 2][[All, pos, pos]];
Dimensions[avgAttentionWeights]
Out[70]=
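The hard-coded positions follow from the 15-word test string: StartOfString at 1, the words at 2 to 16, EndOfString at 17 and the class token at 18. This sketch reproduces the averaging on random numbers, assuming the layout {layers, heads, queries, keys} = {12, 12, 18, 18} after the transposition above:

```wolfram
(* Random array with the assumed layout {layers, heads, queries, keys} = {12, 12, 18, 18} *)
toyAll = RandomReal[1, {12, 12, 18, 18}];
(* keep the word positions 2-16 and the class token at 18, dropping 1 and 17 *)
keep = Append[Range[2, 16], 18];
(* average over the heads (dimension 2), then select the kept rows and columns *)
toyAvg = ArrayReduce[Mean, toyAll, 2][[All, keep, keep]];
Dimensions[toyAvg]  (* {12, 16, 16} *)
```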

Define a function to visualize the attention weights:

In[71]:=
visualizeTokenAttention[attnMatrix_] := Block[{g, style},
  g = WeightedAdjacencyGraph[attnMatrix];
  style = Thread@Directive[
     Arrowheads[.02],
     Thickness /@ (Rescale[AnnotationValue[g, EdgeWeight]]/200),
     Opacity /@ Rescale@AnnotationValue[g, EdgeWeight]
     ];
  Graph[g, GraphLayout -> "LinearEmbedding", EdgeStyle -> Thread[EdgeList[g] -> style], VertexLabels -> Thread[Range[16] -> Map[Rotate[Style[Text[#], 12, Bold], 60 Degree] &, Append[tokens[[2 ;; -2]], "ClassToken"]]], VertexCoordinates -> Table[{i, 0}, {i, Length[attnMatrix]}], ImageSize -> Large]
  ]

Explore the attention weights for every layer. A thicker arrow pointing from token A to token B indicates that the layer is paying attention to token B when generating the vector corresponding to token A:

In[72]:=
Manipulate[
 visualizeTokenAttention@
  avgAttentionWeights[[i]], {{i, 12, "AttentionLayer #"}, Range[12]}, ControlType -> SetterBar]
Out[72]=

Net information

Inspect the number of parameters of all arrays in the net:

In[73]:=
Information[
 NetModel[
  "CoCa Image Captioning Nets Trained on LAION-2B Data"], "ArraysElementCounts"]
Out[73]=

Obtain the total number of parameters:

In[74]:=
Information[
 NetModel[
  "CoCa Image Captioning Nets Trained on LAION-2B Data"], "ArraysTotalElementCount"]
Out[74]=

Obtain the layer type counts:

In[75]:=
Information[
 NetModel[
  "CoCa Image Captioning Nets Trained on LAION-2B Data"], "LayerTypeCounts"]
Out[75]=
