🧬

AlphaFold

Title
Highly accurate protein structure prediction with AlphaFold
Authors
John Jumper, Richard Evans, Alexander Pritzel, Tim Green, Michael Figurnov, Olaf Ronneberger, Kathryn Tunyasuvunakool, Russ Bates, Augustin Žídek, Anna Potapenko, Alex Bridgland, Clemens Meyer, Simon A. A. Kohl, Andrew J. Ballard, Andrew Cowie, Bernardino Romera-Paredes, Stanislav Nikolov, Rishub Jain, Jonas Adler, Trevor Back, Stig Petersen, David Reiman, Ellen Clancy, Michal Zielinski, Martin Steinegger, Michalina Pacholska, Tamas Berghammer, Sebastian Bodenstein, David Silver, Oriol Vinyals, Andrew W. Senior, Koray Kavukcuoglu, Pushmeet Kohli, Demis Hassabis
Date
2021
Venue
Nature
DBLP
Keywords

Overview

Abstract

Motivation: Proteins are essential to life. Knowing their 3D structure helps us understand how they work.
Context:
  • The structures of ~100,000 unique proteins have been determined experimentally
  • This process takes months to years per protein
  • Billions of known protein sequences don't have their structure mapped
Protein folding problem: mapping 1D amino acid sequence to 3D protein structure
Method: novel ML approach leveraging
  1. Physical knowledge
  2. Biological knowledge
  3. Multiple sequence alignments (MSAs)
Experiments: entered into CASP14
Results:
  1. Greatly improved over alternative methods
  2. Accuracy competitive with experimental methods in a majority of cases

Innovations

  • new architecture to jointly embed multiple sequence alignments (MSAs) and pairwise features
  • new output representation and associated loss
  • new equivariant attention architecture
  • use of intermediate losses
  • masked MSA loss to jointly train with structure
  • learning from unlabelled protein sequences via self-distillation
  • self-estimates of accuracy

Network

Process: primary amino acid sequence + aligned sequences of homologues → 3-D coordinates of all heavy atoms
  1. Trunk:
    1. Repeated (new) Evoformer layers
    2. Outputs an array representing "a processed MSA"
    3. Outputs an array representing residue pairs
  2. Structure Module:
    1. predicts a rotation + translation per residue
    2. (simultaneous) local refinement across the structure
    3. (novel) equivariant transformer for reasoning about unrepresented side-chain atoms
    4. Initially: identity rotation & position at the origin
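A shape-level sketch of this two-stage flow (the dimension names n_seq, n_res, c_m, c_z are my own shorthand, not the paper's exact sizes):

```python
import numpy as np

# Hypothetical sizes: n_seq MSA rows, n_res residues, channel widths c_m / c_z.
n_seq, n_res, c_m, c_z = 128, 256, 256, 128

# Trunk outputs: a "processed MSA" array and a residue-pair array,
# refined repeatedly by the Evoformer blocks.
msa_repr = np.zeros((n_seq, n_res, c_m))
pair_repr = np.zeros((n_res, n_res, c_z))

# Structure module output: one rigid transform per residue,
# initialised to the identity rotation with every residue at the origin.
rotations = np.tile(np.eye(3), (n_res, 1, 1))  # (n_res, 3, 3)
translations = np.zeros((n_res, 3))            # (n_res, 3)
```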

Proteins

See 🧪 Proteins 101

Training

Overview

Inputs:
  1. Primary sequence
  2. MSA: sequences from related proteins
  3. Templates: 3D atom coordinates from related proteins
Data source: Protein Data Bank (PDB) + other databases for MSA & templates
Outputs:
  1. Atom coordinates
  2. Distogram (distance histogram)
  3. Per-residue confidence scores
Pre-processing steps (applied each time a protein is sampled during training):
  1. Self-distillation
  2. Filtering training pairs
  3. MSA block deletion
  4. MSA clustering
  5. Residue cropping

MSA Clustering

Problem: we want to leverage a large/variable number of MSA sequences, but in a smaller/fixed-size batch
Solution: MSA clustering
MSA clustering:
  1. Fixed number of sequences randomly added to batch (always inc. original seq.)
  2. Stochastic masking applied
  3. Remaining sequences matched with closest chosen sequence via Hamming distance ➡️ to create clusters
  4. Fixed-size cluster aggregate statistics appended to base sequence representation
  5. Set of random non-centre sequences appended to batch
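A minimal sketch of steps 1, 3 and 4 above (masking and the extra-sequence step are omitted; `cluster_msa` and its arguments are my own names, not the paper's):

```python
import numpy as np

def cluster_msa(msa, n_clusters, rng):
    """msa: (n_seq, n_res) integer tokens, row 0 = the original/target sequence."""
    n_seq, n_res = msa.shape
    # 1. Pick a fixed number of cluster centres, always including the original sequence.
    others = rng.permutation(np.arange(1, n_seq))[: n_clusters - 1]
    centre_idx = np.concatenate([[0], others])
    centres = msa[centre_idx]
    # 3. Assign every remaining sequence to its nearest centre by Hamming distance.
    rest = np.setdiff1d(np.arange(n_seq), centre_idx)
    hamming = (msa[rest, None, :] != centres[None, :, :]).sum(-1)  # (n_rest, n_clusters)
    assignment = hamming.argmin(axis=1)
    # 4. Per-cluster aggregate statistics, e.g. an amino-acid count profile per position.
    n_tokens = int(msa.max()) + 1
    profiles = np.zeros((n_clusters, n_res, n_tokens))
    for r, c in zip(rest, assignment):
        profiles[c, np.arange(n_res), msa[r]] += 1
    # 5. A separate random set of non-centre sequences would form the "extra MSA" (not shown).
    return centres, profiles

rng = np.random.default_rng(0)
msa = rng.integers(0, 21, size=(64, 50))
centres, profiles = cluster_msa(msa, n_clusters=8, rng=rng)
```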

Model Inputs

Main inputs to the model:
  1. Target features
  2. MSA features
  3. Extra MSA features
  4. Template pair features
  5. Template angle features
Target features:
  • Per:
    • amino acid in the target sequence
    • genetic amino acids (20 possible)
  • one-hot feature (across genetic amino acids)
MSA Features:
  • Per:
    • amino acid in the target sequence
    • cluster
    • genetic amino acids (20 possible)
  • one-hot representation (across genetic amino acids) for base sequence
  • + continuous representation for cluster histogram
Extra MSA Features: like MSA features but without histogram
Template pair features:
  • Per:
    • template
    • residue pair
  • Feature vector based on particular 3D structural properties
Template angle features:
  • Per:
    • template
    • residue
  • Feature vector based on angles
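A rough sketch of how the target and MSA features could be laid out as arrays (shapes and names are illustrative only, and unknown/gap tokens are ignored; these are not the paper's exact feature definitions):

```python
import numpy as np

N_AA = 20  # the 20 genetically encoded amino acid types

def one_hot(tokens, depth):
    return np.eye(depth)[tokens]

rng = np.random.default_rng(0)
n_res, n_clust = 50, 8
target_seq = rng.integers(0, N_AA, size=n_res)

# Target features: one-hot over amino acid type, per residue -> (n_res, 20)
target_feat = one_hot(target_seq, N_AA)

# MSA features: per cluster and residue, one-hot of the centre sequence
# plus a continuous histogram (profile) over the cluster members.
centre_seqs = rng.integers(0, N_AA, size=(n_clust, n_res))
cluster_profile = rng.dirichlet(np.ones(N_AA), size=(n_clust, n_res))
msa_feat = np.concatenate([one_hot(centre_seqs, N_AA), cluster_profile], axis=-1)
print(target_feat.shape, msa_feat.shape)  # (50, 20) (8, 50, 40)
```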

Self-Distillation

75% of training examples from self-distillation set, 25% from the Protein Data Bank (PDB)
Self-distillation outline (sketched below):
  1. Train the model on the basic (PDB) dataset
  2. Perform inference on an unlabelled dataset
  3. Measure the confidence of the predictions
  4. Add the most confident sequence-structure pairs to the training set
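A hedged sketch of that loop; `train`, `predict` and `train_mixture` are hypothetical callables supplied by the caller, not real AlphaFold APIs:

```python
# `predict(model, seq)` is assumed to return (structure, confidence) for one sequence.
def self_distill(train, predict, train_mixture, pdb_dataset, unlabelled_seqs,
                 confidence_threshold=0.5):
    model = train(pdb_dataset)                       # 1. train on the labelled (PDB) data
    distillation_set = []
    for seq in unlabelled_seqs:                      # 2. inference on unlabelled sequences
        structure, confidence = predict(model, seq)  # 3. self-estimate of accuracy
        if confidence >= confidence_threshold:       # 4. keep only confident predictions
            distillation_set.append((seq, structure))
    # Retrain, sampling ~75% from the distillation set and ~25% from PDB (as noted above).
    return train_mixture(distillation_set, pdb_dataset, mix=(0.75, 0.25))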

Inference

Inference-specific processes (both sketched below):
  1. Recycling
  2. Ensembling
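A rough sketch of both ideas, assuming a hypothetical `model(features, prev)` callable; ensembling is shown here as averaging final predictions over re-sampled MSA clusterings, which simplifies the paper's scheme:

```python
import numpy as np

def run_with_recycling(model, features, n_recycle=3):
    """Recycling: the previous pass's outputs are fed back in as extra inputs.

    `model(features, prev)` is a hypothetical callable returning a dict with
    'msa', 'pair' and 'coords' arrays; prev is None on the first pass.
    """
    prev = None
    for _ in range(n_recycle + 1):
        outputs = model(features, prev)
        prev = {k: outputs[k] for k in ("msa", "pair", "coords")}
    return outputs

def run_ensembled(model, sample_features, n_ensemble=4):
    """Ensembling (simplified): average over runs with re-sampled MSA clusters."""
    coords = [run_with_recycling(model, sample_features())["coords"]
              for _ in range(n_ensemble)]
    return np.mean(coords, axis=0)
```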

Input Feature Embeddings


Relpos

relpos is just an embedding of the (clipped) relative distance between two residue indices, used when initialising the pair representation
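A minimal sketch of such a relative-position embedding (the ±32 clipping range is what I recall from the supplement; treat it and the random weight stand-in as assumptions):

```python
import numpy as np

def relpos(n_res, c_z, max_offset=32, rng=np.random.default_rng(0)):
    """One-hot of the clipped relative index offset, linearly projected to pair channels."""
    idx = np.arange(n_res)
    offset = idx[:, None] - idx[None, :]                        # (n_res, n_res)
    clipped = np.clip(offset, -max_offset, max_offset) + max_offset
    one_hot = np.eye(2 * max_offset + 1)[clipped]               # (n_res, n_res, 2*max_offset+1)
    W = rng.standard_normal((2 * max_offset + 1, c_z)) * 0.02   # stand-in for learned weights
    return one_hot @ W                                          # added to the initial pair rep

print(relpos(n_res=10, c_z=8).shape)  # (10, 10, 8)
```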

Template embedding

  • Each template independently processed by bottom-half Evoformer blocks
  • Attention then brings the template outputs into the initial main pairwise representation (queries from the pair rep, keys/values from the template outputs)
  • ↪️ this is done independently for each pair position ("pointwise"), attending/summing over the stack of templates
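A hedged single-head sketch of that pointwise attention over templates (projection matrices and sizes are illustrative stand-ins):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def template_pointwise_attention(z, t, rng=np.random.default_rng(0)):
    """z: (n_res, n_res, c_z) main pair rep; t: (n_templ, n_res, n_res, c_t) template embeddings."""
    c_z, c_t = z.shape[-1], t.shape[-1]
    Wq = rng.standard_normal((c_z, c_t)) * 0.02    # stand-ins for learned projections
    Wv = rng.standard_normal((c_t, c_z)) * 0.02
    q = z @ Wq                                      # one query per pair position (i, j)
    logits = np.einsum("ijc,sijc->ijs", q, t) / np.sqrt(c_t)
    w = softmax(logits, axis=-1)                    # attend over the template dimension
    out = np.einsum("ijs,sijc->ijc", w, t) @ Wv     # weighted sum of templates, back to c_z
    return z + out                                  # added to the pair representation

z = np.zeros((6, 6, 16))
t = np.random.default_rng(1).standard_normal((4, 6, 6, 8))
print(template_pointwise_attention(z, t).shape)  # (6, 6, 16)
```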

Extra MSA stack

Modified evoformer blocks:
  • ⬇️ representation sizes as ⬆️ number of sequences
  • Uses a global variant of column-wise self-attention:
    • Keys and values use a single projection shared across attention heads
    • Queries are computed per head and then averaged over the (many) sequences
      • (I assume this is a compute-saving measure as we have ⬆️ sequences)
(only the pair representation output of this stack is used downstream)
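A single-column sketch of the global attention variant, assuming shared (head-less) key/value projections and a sequence-averaged query; sizes are illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def global_column_attention(col, n_head=4, rng=np.random.default_rng(0)):
    """col: (n_seq, c) - one residue column of the extra MSA representation."""
    n_seq, c = col.shape
    d = c // n_head
    Wq = rng.standard_normal((n_head, c, d)) * 0.02   # per-head query projection
    Wk = rng.standard_normal((c, d)) * 0.02           # shared (head-less) key projection
    Wv = rng.standard_normal((c, d)) * 0.02           # shared (head-less) value projection
    q = np.einsum("c,hcd->hd", col.mean(axis=0), Wq)  # queries averaged over sequences
    k, v = col @ Wk, col @ Wv                         # (n_seq, d), shared by all heads
    w = softmax(q @ k.T / np.sqrt(d), axis=-1)        # (n_head, n_seq)
    return w @ v                                      # (n_head, d): pooled output per head

out = global_column_attention(np.random.default_rng(1).standard_normal((1024, 64)))
print(out.shape)  # (4, 16)
```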

Evoformer

Overview

Key principle: view protein structure prediction as graph inference problem where the presence of an edge indicates nearby residues
48 evoformer blocks - the relationship between the two matrices is refined many times!
Done via the two representations outlined above:
Pair representations: relationship info
MSA representations: relationship with homologous sequences per-residue
All layers have residual connections & many have dropout
MSA uses axial/criss-cross attention: attention for 2D arrays where we compute attention independently across rows and then across columns

MSA row-wise gated self-attention with pair bias

What does row-wise mean? Attention expects a 2D array, with one sequence dimension and one channel dimension. The channel dimension is already fixed here, so we just need to select one of the other two dimensions as our "sequence"; the one we don't select is processed independently, slice by slice.
Row-wise attention gives weights for residue pairs ➡️ means we can fold the pair representation in here ➡️ done via an added bias term projected from the pair representation
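A single-head NumPy sketch of row-wise gated self-attention with pair bias (learned weights replaced by random stand-ins):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def msa_row_attention_with_pair_bias(msa, pair, rng=np.random.default_rng(0)):
    """msa: (n_seq, n_res, c_m); pair: (n_res, n_res, c_z). Single-head sketch."""
    c_m, c_z = msa.shape[-1], pair.shape[-1]
    Wq, Wk, Wv, Wg = (rng.standard_normal((c_m, c_m)) * 0.02 for _ in range(4))
    wb = rng.standard_normal(c_z) * 0.02              # projects pair rep to a scalar bias
    q, k, v = msa @ Wq, msa @ Wk, msa @ Wv            # each (n_seq, n_res, c_m)
    bias = pair @ wb                                  # (n_res, n_res): same bias for every row
    logits = np.einsum("sic,sjc->sij", q, k) / np.sqrt(c_m) + bias[None]
    attn = softmax(logits, axis=-1)                   # attention within each MSA row
    gate = sigmoid(msa @ Wg)                          # the "gated" part
    return gate * np.einsum("sij,sjc->sic", attn, v)

msa = np.random.default_rng(1).standard_normal((8, 20, 32))
pair = np.random.default_rng(2).standard_normal((20, 20, 16))
print(msa_row_attention_with_pair_bias(msa, pair).shape)  # (8, 20, 32)
```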

MSA column-wise gated self-attention

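Column-wise attention is the same computation with the two MSA dimensions swapped and no pair bias, so information flows between sequences at each residue position. A minimal sketch (gating omitted for brevity):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def msa_column_attention(msa, rng=np.random.default_rng(0)):
    """msa: (n_seq, n_res, c_m). Attention runs over the sequence dimension,
    independently for each residue column."""
    c_m = msa.shape[-1]
    Wq, Wk, Wv = (rng.standard_normal((c_m, c_m)) * 0.02 for _ in range(3))
    q, k, v = msa @ Wq, msa @ Wk, msa @ Wv
    logits = np.einsum("sic,tic->ist", q, k) / np.sqrt(c_m)  # per column i: seq-to-seq scores
    attn = softmax(logits, axis=-1)                          # (n_res, n_seq, n_seq)
    return np.einsum("ist,tic->sic", attn, v)                # (n_seq, n_res, c_m)

print(msa_column_attention(np.zeros((8, 20, 32))).shape)  # (8, 20, 32)
```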

Transition

Same operation for MSA & pairwise features
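A sketch of the transition block as a position-wise two-layer MLP (the 4× channel expansion and ReLU are what I recall from the supplement; the weights here are random stand-ins):

```python
import numpy as np

def transition(x, expansion=4, rng=np.random.default_rng(0)):
    """Position-wise 2-layer MLP; x is (..., c). Applied to both MSA and pair reps."""
    c = x.shape[-1]
    W1 = rng.standard_normal((c, expansion * c)) * 0.02
    W2 = rng.standard_normal((expansion * c, c)) * 0.02
    return np.maximum(x @ W1, 0.0) @ W2   # linear -> ReLU -> linear

print(transition(np.zeros((8, 20, 32))).shape)  # (8, 20, 32)
```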

Outer product mean

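The outer product mean turns the MSA representation into an update for the pair representation: each sequence's features at residues i and j are projected to small vectors, their outer product is taken, and the result is averaged over sequences and projected to pair channels. A minimal sketch with stand-in weights:

```python
import numpy as np

def outer_product_mean(msa, c_hidden=8, c_z=16, rng=np.random.default_rng(0)):
    """msa: (n_seq, n_res, c_m) -> pair update (n_res, n_res, c_z)."""
    c_m = msa.shape[-1]
    Wa = rng.standard_normal((c_m, c_hidden)) * 0.02
    Wb = rng.standard_normal((c_m, c_hidden)) * 0.02
    Wo = rng.standard_normal((c_hidden * c_hidden, c_z)) * 0.02
    a, b = msa @ Wa, msa @ Wb                                # (n_seq, n_res, c_hidden)
    outer = np.einsum("sia,sjb->ijab", a, b) / msa.shape[0]  # mean over sequences of a_i ⊗ b_j
    return outer.reshape(*outer.shape[:2], -1) @ Wo          # flatten and project to pair channels

print(outer_product_mean(np.zeros((8, 20, 32))).shape)  # (20, 20, 16)
```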

Triangular multiplicative update

The "incoming edges" version is identical but uses the columns of the matrix instead of the rows
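A simplified sketch of the outgoing-edge update: the update to edge (i, j) is built by combining projections of edges (i, k) and (j, k) over all k, then gating (per-projection gates and layer norm omitted here):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def triangle_multiply_outgoing(z, c_hidden=16, rng=np.random.default_rng(0)):
    """z: (n_res, n_res, c_z). The 'incoming' variant uses edges (k, i) and (k, j) instead."""
    c_z = z.shape[-1]
    Wa, Wb = (rng.standard_normal((c_z, c_hidden)) * 0.02 for _ in range(2))
    Wg = rng.standard_normal((c_z, c_z)) * 0.02
    Wo = rng.standard_normal((c_hidden, c_z)) * 0.02
    a, b = z @ Wa, z @ Wb                            # rows: edges (i, k) and (j, k)
    update = np.einsum("ikc,jkc->ijc", a, b) @ Wo    # combine the two known edges of each triangle
    return sigmoid(z @ Wg) * update                  # gated update to edge (i, j)

print(triangle_multiply_outgoing(np.zeros((20, 20, 32))).shape)  # (20, 20, 32)
```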

Triangular self-attention

  • Row-wise gated self-attention + triangle update
  • Row-wise means we calculate attention values between all pairs of column values for the given row ➡️ matrix of outgoing edge attn. values
  • To complete the triangle we add these attention values to (a projection of) the values in the original tensor
The "around ending node" version is the same but computed across columns
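A single-head sketch of the "starting node" triangle attention: edge (i, j) attends over edges (i, k) in the same row, with a bias derived from the third edge (j, k); weights are random stand-ins and gating is omitted:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def triangle_attention_starting(z, rng=np.random.default_rng(0)):
    """z: (n_res, n_res, c_z). Row-wise attention with a bias from the closing edge."""
    c = z.shape[-1]
    Wq, Wk, Wv = (rng.standard_normal((c, c)) * 0.02 for _ in range(3))
    wb = rng.standard_normal(c) * 0.02
    q, k, v = z @ Wq, z @ Wk, z @ Wv
    logits = np.einsum("ijc,ikc->ijk", q, k) / np.sqrt(c) + (z @ wb)[None]  # bias from z_jk
    attn = softmax(logits, axis=-1)
    return np.einsum("ijk,ikc->ijc", attn, v)

print(triangle_attention_starting(np.zeros((20, 20, 32))).shape)  # (20, 20, 32)
```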

Structure Module

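A minimal sketch of the per-residue rigid frames the structure module works with: each residue carries a rotation + translation, initialised to the identity at the origin and composed with a predicted update at every iteration (the update is shown here as a rotation matrix, whereas the paper predicts it as a quaternion; side chains and the equivariant attention itself are not shown):

```python
import numpy as np

def compose_frames(rotations, translations, delta_rot, delta_trans):
    """Compose per-residue rigid frames (R, t) with a predicted update of the same shapes.

    rotations: (n_res, 3, 3); translations: (n_res, 3).
    """
    new_rot = np.einsum("nij,njk->nik", rotations, delta_rot)
    new_trans = translations + np.einsum("nij,nj->ni", rotations, delta_trans)
    return new_rot, new_trans

n_res = 5
# Initialisation: identity rotation, every residue at the origin.
rot = np.tile(np.eye(3), (n_res, 1, 1))
trans = np.zeros((n_res, 3))
# One iteration with a dummy (identity) update leaves the frames unchanged.
rot, trans = compose_frames(rot, trans, np.tile(np.eye(3), (n_res, 1, 1)), np.zeros((n_res, 3)))
print(rot.shape, trans.shape)  # (5, 3, 3) (5, 3)
```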