ESMFold-V1 Trained on UR50 Data

Predict a protein's 3D structure from its amino acid sequence

The ESMFold model family by Meta AI uses a transformer architecture trained on millions of protein sequences to predict atomic-resolution 3D structures directly from single sequences. By eliminating the need for multiple sequence alignments (MSAs), it offers a fast, scalable alternative to traditional methods, delivering performance comparable to AlphaFold2 and enabling rapid analysis in genomics, drug discovery and molecular biology.

Training Set Information

Model Information

Examples

Resource retrieval

Get the pre-trained net:

In[1]:=
NetModel["ESMFold-V1 Trained on UR50 Data"]
Out[1]=

NetModel parameters

This model consists of a family of individual nets, each identified by a specific parameter combination. Inspect the available parameters:

In[2]:=
NetModel["ESMFold-V1 Trained on UR50 Data", "ParametersInformation"]
Out[2]=

Pick a non-default net by specifying the parameters:

In[3]:=
NetModel[{"ESMFold-V1 Trained on UR50 Data", "Size" -> "150M"}]
Out[3]=

Evaluation function

Write an evaluation function that encodes the input sequence, runs the ESM-2 language model, the folding trunk with recycling and the post-processing network, and returns the predicted structure together with its confidence score:

In[4]:=
encodeSequence[bs_BioSequence, rest___] := encodeSequence[bs["SequenceString"], rest];
encodeSequence[bs_BioMolecule, rest___] := encodeSequence[Flatten@bs["BioSequence"]["SequenceString"], rest];
encodeSequence[seq_, residueIndexOffset_ : 512, chainLinker_ : StringRepeat["G", 25]] := Module[{chains, joined, encoded, residx, linkerMask, chainIndex, offset = 0, restypesWithX = {"A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
       "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "X"}, restypeOrderWithX, chainCount, linkerLen, end},
   restypeOrderWithX = AssociationThread[restypesWithX -> Range[1, Length[restypesWithX]]];
   (*Setup*)
   chains = If[ListQ@seq, seq, {seq}];
   chainCount = Length[chains];
   linkerLen = StringLength[chainLinker];
   joined = StringJoin@Riffle[chains, If[chainCount > 1, chainLinker, ""]];
   (*Encode amino acids*)
   encoded = Lookup[restypeOrderWithX, Characters[joined], restypeOrderWithX["X"]] - 1;
   residx = Range[Length[encoded]] - 1;
   (*Apply offsets only if more than 1 chain*)
   If[residueIndexOffset > 0 && chainCount > 1, offset = 0;
    Do[Module[{start = offset + 1, len = StringLength[chains[[i]]]}, end = start + len - 1;
      residx[[start ;; end]] += (i - 1)*residueIndexOffset;
      offset = end;
      If[i < chainCount,(*Advance offset over linker*)
       offset += linkerLen;];], {i, chainCount}];];
   (*Linker mask:0 where linker,1 elsewhere*)
   linkerMask = ConstantArray[1.0, Length[encoded]];
   If[chainCount > 1, offset = 0;
    Do[offset += StringLength[chains[[i]]];
     If[i < chainCount,
      linkerMask[[offset + 1 ;; offset + linkerLen]] = 0;
      offset += linkerLen;];, {i, chainCount}];
    ];
   (*Chain index*)
   chainIndex = Flatten[Table[
      Join[ConstantArray[i - 1, StringLength[chains[[i]]]], If[i < chainCount, ConstantArray[i - 1, linkerLen], {}]], {i, chainCount}]];
   Transpose[{{encoded, residx, linkerMask, chainIndex}}]
   ];
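As a quick shape check (an added illustration, not part of the original resource), encoding a single 10-residue chain yields four batched lists of equal length, with residue indices counting from 0 and an all-ones linker mask:

{aaEnc, residxEnc, maskEnc, chainEnc} = encodeSequence["MKTAYIAKQR"];
Dimensions /@ {aaEnc, residxEnc, maskEnc, chainEnc} (* each should be {1, 10} *)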
In[5]:=
(* Evaluate this cell to get the example input *) CloudGet["https://www.wolframcloud.com/obj/44c824c3-61b3-4f77-93d6-134cd3877208"]
In[6]:=
trunkModel[sS0_, sZ0_, aa_, Optional[noRecycles_, 3], Optional[targetDevice_, "CPU"]] := Module[{recycleS, recycleZ, recycleBins, sS, sZ, i, structure, b, l, mask, residx},
  (*Step 1:Initialize the mask and residx*)
  {b, l} = Dimensions@aa;
  {mask, residx} = {ConstantArray[1, Dimensions@aa], {Range[l] - 1}};
  (*Step 2:Initialize recycling tensors*)
  {recycleS, recycleZ, recycleBins} = {ConstantArray[0, Dimensions@sS0], ConstantArray[0, Dimensions@sZ0], ConstantArray[0, {1, l, l}]};
  (*Step 3:Recycling loop*)
  {sS, sZ} = {sS0, sZ0};
  For[i = 0, LessEqual[i, noRecycles], Increment@i, structure = NetModel[{"ESMFold-V1 Trained on UR50 Data", "Part" -> "FoldingTrunk"}][<|"aa" -> aa, "s_s_0" -> sS0, "s_z_0" -> sZ0, "recycle_s" -> recycleS, "recycle_z" -> recycleZ, "recycle_bins" -> recycleBins, "residx" -> residx, "mask" -> mask|>, TargetDevice -> targetDevice];
   {sS, sZ, recycleS, recycleZ, recycleBins} = Lookup[structure, {"s_s", "s_z", "updated_recycle_s", "updated_recycle_z", "updated_recycle_bins"}];
   (*Step 4:Update recycle tensors*)
   {recycleS, recycleZ} = {sS, sZ};];
  structure]
In[7]:=
Options[esmFoldInference] = {"Size" -> "V1", "numRecycles" -> 0, "Confidence" -> True, "chainLinker" -> StringRepeat["G", 25], "residueIndexOffset" -> 512, TargetDevice -> "CPU"};
esmFoldInference[sequence_, OptionsPattern[]] := Module[{seq, aatype, pdbStr, residx, chainIndex, linkerMask, seqs, esm, structure, output, atom37AtomExists, meanPlddt},(*Handle single sequence input*)
  (*Batch encode sequences*)
  {aatype, residx, linkerMask, chainIndex} = encodeSequence[sequence, OptionValue["residueIndexOffset"], OptionValue@"chainLinker"];
  (*Step 1:Run the ESM Language Model Wrapper*)
  esm = NetModel[{"ESMFold-V1 Trained on UR50 Data", "Part" -> "LanguageModel", "Size" -> OptionValue@"Size"}][
    aatype, TargetDevice -> OptionValue@TargetDevice];
  (*Step 2:Run the Trunk Model*)
  structure = trunkModel[esm@"s_s_0", esm@"s_z_0", aatype, OptionValue@"numRecycles", OptionValue@TargetDevice];
  (*Step 3:Post-process the outputs*)
  output = NetModel[{"ESMFold-V1 Trained on UR50 Data", "Part" -> "PostProcessing"}][<|"aa" -> aatype, "frames_0" -> structure["frames"], "sidechain_frames_0" -> structure["sidechain_frames"], "unnormalized_angles_0" -> structure["unnormalized_angles"], "angles_0" -> structure["angles"], "positions_0" -> structure["positions"], "single_0" -> structure["single"], "states_0" -> structure["states"], "s_s_0" -> structure["s_s"], "s_z_0" -> structure["s_z"]|>, TargetDevice -> OptionValue@TargetDevice];
  (*Adjust atom37_atom_exists*)
  atom37AtomExists = output["atom37_atom_exists"]*linkerMask;
  (*Compute mean_plddt*)
  meanPlddt = Total[output["plddt"]*atom37AtomExists, All]/
    Total[atom37AtomExists, All];
  (*Process results*)
  output = Association["positions" -> output["positions"], "aatype" -> output["aatype"], "atom14_atom_exists" -> output["atom14_atom_exists"], "residx_atom14_to_atom37" -> output["residx_atom14_to_atom37"], "residx_atom37_to_atom14" -> output["residx_atom37_to_atom14"], "atom37_atom_exists" -> output["atom37_atom_exists"], "residue_index" -> output["residue_index"], "plddt" -> output["plddt"]];
  (*Add chain_index to output*)
  output["chain_index"] = chainIndex;
  output["mean_plddt"] = meanPlddt;
  pdbStr = outputToPDB@output;
  If[OptionValue["Confidence"], {pdbStr, output["mean_plddt"]},
   pdbStr]
  ]
In[8]:=
Options[netevaluate] = {"Size" -> "V1", "numRecycles" -> 3, TargetDevice -> "CPU", "Confidence" -> True};
netevaluate[sequence_, OptionsPattern[]] := Enclose@Block[{pdb},
   pdb = esmFoldInference[sequence, "Size" -> OptionValue["Size"], "numRecycles" -> If[OptionValue["Size"] == "V1", OptionValue["numRecycles"], 1], "Confidence" -> OptionValue["Confidence"], TargetDevice -> OptionValue[TargetDevice]];
   If[OptionValue["Confidence"], Association@{"Structure" -> ImportString[pdb[[1]], {"PDB", "BioMolecule"}, "PredictedStructure" -> True], "Confidence" -> pdb[[2]]}, ImportString[pdb, {"PDB", "BioMolecule"}, "PredictedStructure" -> True]]
   ]
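Setting "Confidence" -> False returns only the predicted BioMolecule. A minimal usage sketch with an arbitrary short peptide (illustrative only; any valid single-letter amino acid string works):

netevaluate["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "Confidence" -> False]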

Basic usage

Get a sequence:

In[9]:=
testSequence = BioSequence[
  "Peptide", "SLIVTTILEEPYVLFKKSDKPLYGNDRFEGYCIDLLRELSTILGFTYEIRLVEDGKYGAQDNGQWNGMVRELIDHKADLAVAPLAITYVREKVIDFSAPFMTLGISILYRKGTPIDSADDLAKQTKIEYGAVEDGATMTFFKKSKISTYDKMWAFMSSRRQSVLVKSNEEGIQRVLTSDYAFLMESTTIEFVTQRNCNLTQIGGLIDSKGYGVGTPMGSPYRDKITIAILQLQEEGKLHMMKEKWW", {}];

Obtain the sequence’s predicted structure:

In[10]:=
detection = netevaluate[testSequence];

The result includes the predicted structure and its confidence score (mean pLDDT):

In[11]:=
detection
Out[11]=

Plot the result:

In[12]:=
BioMoleculePlot3D[detection["Structure"]]
Out[12]=

Feature extraction

Define two sets of amino acid sequences originating from different protein families, namely enzymes and structural proteins:

In[13]:=
enzymes = Map[Style[#, Red] &, {"ACGYLKTPKLADPPVLRGDSSVTKAICKPDPVLEK", "GVALDECKALDYLPGKPLPMDGKVCQCGSKTPLRP", "VLPGYTCGELDCKPGKPLPKCGADKTQVATPFLRG", "TCGALVQYPSCADPPVLRGSDSSVKACKKLDPQDK", "GALCEECKLCPGADYKPMDGDRLPAAATSKTRPVG", "PAVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYG", "VLGYTCGALDCKPGKPLPKCGADKTQVATPFLRGA", "CGALVQYPSCADPPVLRGSDSSVKACKKLDPQDKT", "ALCEECKLCPGADYKPMDGDRLPAAATSKTRPVGK", "AVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYGR"}];
structuralProteins = Map[Style[#, Green] &, {"VGKGFRYGSSQKRYLHCQKSALPPSCRRGKGQGSAT", "KDPTVMTVGTYSCQCPKQDSRGSVQPTSRVKTSRSK", "PLVGKACGRSSDYKCPGQMVSGGSKQTPASQRPSYD", "CGKKLVGYPSSKADVPLQGRSSFSPKACKKDPQMTS", "RKGVASLYCSSKLSCKAQYSKGMSDGRSPKASSTTS", "RPKSAASCEQAKSYRSLSLPSMKGKVPSKCSRSKRP", "RSDVSYTSCSQSKDCKPSKPPKMSGSKDSSTVATPS", "LSTCSKKVAYPSSKADPPSSGRSSFSMKACKKQDPPV", "RVGSASSEPKSSCSVQSYSKPSMSGDSSPKASSTSK", "QPSASNCEKMSSYRPSLPSMSKGVPSSRSKSSPPYQ"}];
In[14]:=
sequences = Join[enzymes, structuralProteins];

Define a feature extractor that max-pools the per-residue single representation (s_s_0) from the 150M language model into a fixed-length vector per sequence:

In[15]:=
extractor = Map[Max, Transpose@
     NetModel[{"ESMFold-V1 Trained on UR50 Data", "Size" -> "150M"}][
        encodeSequence[#][[1]]]["s_s_0"][[1]]] &;
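As an optional check (not part of the original example), the extractor maps a plain sequence string to a fixed-length vector whose length equals the channel dimension of the single representation:

Length[extractor["ACGYLKTPKLADPPVLRGDSSVTKAICKPDPVLEK"]]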

Visualize the features of the protein sequences:

In[16]:=
FeatureSpacePlot[sequences, FeatureExtractor -> extractor, LabelingSize -> 90, LabelingFunction -> Callout]
Out[16]=

Advanced usage

When working with proteins composed of multiple chains (multimers), a glycine linker (commonly 25 residues long) is automatically inserted between chains to create a single continuous sequence suitable for structure prediction. Glycine is typically chosen for its small size and minimal structural impact. To demonstrate this, use the antibody–antigen complex 3HFM, which consists of three chains:

In[17]:=
(* Evaluate this cell to get the example input *) CloudGet["https://www.wolframcloud.com/obj/737d995e-aa59-4903-b542-f2862eac724e"]
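To inspect the multichain encoding directly (an illustrative check, independent of the 3HFM example), encode a two-chain input: the 25-residue linker contributes masked positions and the second chain's residue indices are shifted by the residueIndexOffset of 512:

{aa2, residx2, mask2, chainIdx2} = encodeSequence[{"MKTAYIAK", "GSHMENLY"}];
Length[aa2[[1]]] (* 8 + 25 + 8 = 41 encoded positions *)
Count[mask2[[1]], 0 | 0.] (* the 25 linker positions are masked out *)
residx2[[1, -1]] (* last residue index: 40 + 512 = 552 *)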

Get the predicted structure and confidence score:

In[18]:=
detection = netevaluate[testMolecule];
In[19]:=
detection
Out[19]=

Visualize the structure:

In[20]:=
BioMoleculePlot3D[detection["Structure"]]
Out[20]=

Network result

The following diagram shows the modular flow of the ESMFold inference process, which takes a protein sequence and outputs its 3D structure:

Out[50]=

The process starts with a raw amino acid sequence with two chains:

In[51]:=
testSequence = {"MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHF", "VHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFG"};

The encodeSequence function converts the input protein sequence into numerical lists representing amino acid types, residue indices, a linker mask and chain identifiers for each residue:

In[52]:=
{aatype, residx, linkerMask, chainIndex} = encodeSequence[testSequence];
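As an added check, all four arrays share one batched length covering both chains plus the 25-residue glycine linker, and the residue indices of the second chain are offset by 512:

Dimensions /@ {aatype, residx, linkerMask, chainIndex} (* each should be {1, 118}: 47 + 25 + 46 *)
residx[[1, -1]] (* last residue index: 117 + 512 = 629 *)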

The pretrained ESM-2 language model extracts contextual embeddings from the input sequence:

In[53]:=
esm = NetModel[{"ESMFold-V1 Trained on UR50 Data", "Part" -> "LanguageModel"}][aatype];

The result includes initial single (s_s_0) and pair (s_z_0) representations. The single representation encodes features for each residue, while the pair representation encodes features for every residue pair, so that both local and pairwise relational information is available during folding:

In[54]:=
Dimensions /@ esm
Out[54]=

The Folding Trunk is a deep neural network composed of 48 Folding Blocks followed by an eight-block Structure Module; together, they refine the internal representations through multiple recycling steps, progressively improving the predicted structure. Initialize the input metadata (mask and residue indices) and set up tensors for recycling intermediate representations:

In[55]:=
(*Initialize the mask and residx*)
{sS0, sZ0} = {esm@"s_s_0", esm@"s_z_0"};
{b, l} = Dimensions@aatype;
{mask, residx} = {ConstantArray[1, Dimensions@aatype], {Range[l] - 1}};
(*Initialize recycling tensors*)
{recycleS, recycleZ, recycleBins} = {ConstantArray[0, Dimensions@sS0],
    ConstantArray[0, Dimensions@sZ0], ConstantArray[0, {1, l, l}]};

The recycling loop runs multiple times, each time passing the current representations and recycled values into the Folding Trunk model, which updates the representations and predicts the structure. Since the Folding Trunk model already includes the Structure Module internally, it returns both the refined internal features and the predicted atomic structure in a single call:

In[56]:=
(*Recycling loop*)
{sS, sZ} = {sS0, sZ0};
For[i = 0, LessEqual[i, 3], Increment@i, structure = NetModel[{"ESMFold-V1 Trained on UR50 Data", "Part" -> "FoldingTrunk"}][<|"aa" -> aatype, "s_s_0" -> sS0, "s_z_0" -> sZ0, "recycle_s" -> recycleS, "recycle_z" -> recycleZ,
      "recycle_bins" -> recycleBins, "residx" -> residx, "mask" -> mask|>
    ];
  {sS, sZ, recycleS, recycleZ, recycleBins} = Lookup[structure, {"s_s", "s_z", "updated_recycle_s", "updated_recycle_z", "updated_recycle_bins"}];
  (*Update recycle tensors*)
  {recycleS, recycleZ} = {sS, sZ};
  ];
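As an added inspection step, list the Folding Trunk outputs to confirm that a single call returns both the refined representations and the predicted structure tensors:

Keys[structure]
Dimensions[structure["positions"]]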

Visualize the refined single and pair representations after the recycling loop:

In[57]:=
GraphicsColumn[{MatrixPlot[sS[[1]], FrameTicks -> {Automatic, None}, PlotLabel -> "Single Representation (s_s)", ImageSize -> Large, ColorFunction -> "Rainbow"], ArrayPlot[sZ[[1, All, All, 1]], ColorFunction -> "Rainbow", FrameTicks -> Automatic, PlotLabel -> "Pair Representation (s_z, Channel 1)", ImageSize -> Large]}]
Out[57]=

The post-processing part predicts the final atomic coordinates, frames, angles and side chains. It also outputs the per-residue confidence (pLDDT) and assembles the arrays needed to build the PDB string for visualization:

In[58]:=
output = NetModel[{"ESMFold-V1 Trained on UR50 Data", "Part" -> "PostProcessing"}][<|"aa" -> aatype, "frames_0" -> structure["frames"], "sidechain_frames_0" -> structure["sidechain_frames"], "unnormalized_angles_0" -> structure["unnormalized_angles"], "angles_0" -> structure["angles"], "positions_0" -> structure["positions"], "single_0" -> structure["single"], "states_0" -> structure["states"], "s_s_0" -> structure["s_s"], "s_z_0" -> structure["s_z"]|>];
(*Adjust atom37_atom_exists*)
atom37AtomExists = output@"atom37_atom_exists";
(*Compute mean_plddt*)
meanPlddt = Total[output["plddt"]*atom37AtomExists, All]/
   Total[atom37AtomExists, All];
(*Process results*)
output = Association["positions" -> output["positions"], "aatype" -> output["aatype"], "atom14_atom_exists" -> output["atom14_atom_exists"], "residx_atom14_to_atom37" -> output["residx_atom14_to_atom37"], "residx_atom37_to_atom14" -> output["residx_atom37_to_atom14"], "atom37_atom_exists" -> output["atom37_atom_exists"], "residue_index" -> output["residue_index"], "plddt" -> output["plddt"]];
(*Add chain_index to output*)
output["chain_index"] = chainIndex;
output["mean_plddt"] = meanPlddt;

The final output includes atomic positions and confidence scores, used to render the 3D protein structure. Convert the output to PDB format:

In[59]:=
pdbStr = outputToPDB@output
Out[59]=

Get the confidence score:

In[60]:=
output["mean_plddt"]
Out[60]=

Plot the result:

In[61]:=
BioMoleculePlot3D@
 ImportString[pdbStr, {"PDB", "BioMolecule"}, "PredictedStructure" -> True]
Out[61]=

Resource History

Reference

  • Z. Lin, H. Akin, R. Rao, B. Hie, Z. Zhu, W. Lu, N. Smetanin, R. Verkuil, O. Kabeli, Y. Shmueli, A. Dos Santos Costa, M. Fazel-Zarandi, T. Sercu, S. Candido, A. Rives, "Evolutionary-Scale Prediction of Atomic-Level Protein Structure with a Language Model," DOI: 10.1126/science.ade2574 (2023)
  • Available from: https://github.com/facebookresearch/esm
  • Rights: MIT License