Overview
Abstract
Motivation: Proteins are essential to life. Knowing their 3D structure helps us understand how they work.
Context:
- The structures of ~100,000 proteins have been determined experimentally
- This process takes months to years per protein
- Billions of known protein sequences don't have their structure mapped
Protein folding problem: mapping 1D amino acid sequence to 3D protein structure
Method: novel ML approach leveraging
- Physical knowledge
- Biological knowledge
- Multiple sequence alignments
Experiments: entered into CASP14
Results:
- Greatly improved over alternative methods
- Accuracy competitive with experimental methods in a majority of cases
Innovations
- new architecture to jointly embed multiple sequence alignments (MSAs) and pairwise features
- new output representation and associated loss
- new equivariant attention architecture
- use of intermediate losses
- masked MSA loss to jointly train with structure
- learning from unlabelled protein sequences via self-distillation
- self-estimates of accuracy
Network
Process: primary amino acid sequence + aligned sequences of homologues → 3-D coordinates of all heavy atoms
- Trunk:
- Repeated (new) Evoformer layers
- Outputs array representing "a processed MSA"
- Outputs array representing residue pairs
- Structure Module
- predicts a rotation + translation per residue
- (simultaneous) local refinement across structure
- (novel) equivariant transformer for reasoning about unrepresented side chain atoms
Initially: identity rotation & position at the origin for every residue
Proteins
See
Proteins 101
Training
Overview
Inputs:
- Primary sequence
- MSA: sequences from related proteins
- Templates: 3D atom coordinates from related proteins
Data source: Protein Data Bank (PDB) + other databases for MSA & templates
Outputs:
- Atom coordinates
- Distogram (distance histogram)
- Per-residue confidence scores
Pre-processing steps: (every time we see a protein)
- Self-distillation
- Filtering training pairs
- MSA block deletion
- MSA clustering
- Residue cropping
MSA Clustering
Problem: we want to leverage a large/variable number of MSA sequences, but in a smaller/fixed-size batch
Solution: MSA clustering
MSA clustering:
- Fixed number of sequences randomly added to batch (always inc. original seq.)
- Stochastic masking applied
- Remaining sequences matched with the closest chosen sequence via Hamming distance ➡️ to create clusters (sketched below)
- Fixed-size cluster aggregate statistics appended to each cluster-centre sequence representation
- Set of random non-centre sequences appended to batch
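A minimal sketch of the cluster-assignment step above, assuming an integer-encoded MSA; the function and variable names are illustrative, not AlphaFold's actual code, and the masking/aggregate-statistics steps are omitted:

```python
import numpy as np

def assign_msa_clusters(msa, n_clusters, rng):
    """Toy sketch of the cluster-assignment step.

    msa: [N_seq, N_res] integer-encoded alignment; row 0 is the target sequence.
    Returns the chosen cluster-centre row indices and a mapping from every
    remaining row to its nearest centre by Hamming distance.
    """
    n_seq = msa.shape[0]
    # Always keep the target sequence, then sample the other centres at random.
    others = rng.permutation(np.arange(1, n_seq))[: n_clusters - 1]
    centres = np.concatenate([[0], others])

    rest = np.setdiff1d(np.arange(n_seq), centres)
    # Hamming distance = number of aligned positions where the residues differ.
    dists = (msa[rest, None, :] != msa[None, centres, :]).sum(-1)  # [N_rest, N_clusters]
    nearest_centre = centres[dists.argmin(axis=1)]
    return centres, dict(zip(rest.tolist(), nearest_centre.tolist()))

# Example: 10 aligned sequences of length 7, grouped into 4 clusters.
rng = np.random.default_rng(0)
msa = rng.integers(0, 21, size=(10, 7))
centres, assignment = assign_msa_clusters(msa, n_clusters=4, rng=rng)
```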
Model Inputs
Main inputs to the model:
- Target features
- MSA features
- Extra MSA features
- Template pair features
- Template angle features
Target features:
- Per:
- amino acid in the target sequence
- amino acid type (20 possible)
- one-hot feature (over amino acid types)
MSA Features:
- Per:
- amino acid in the target sequence
- cluster
- amino acid type (20 possible)
- one-hot representation (over amino acid types) for the cluster-centre sequence
- + continuous representation for cluster histogram
Extra MSA Features: like MSA features but without histogram
Template pair features:
- Per:
- template
- residue pair
- Feature vector based on particular 3D structural properties
Template angle features:
- Per:
- template
- residue
- Feature vector based on angles
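A hedged sketch of how these inputs could be laid out as arrays; the channel sizes below are placeholders chosen for illustration, not the paper's exact values:

```python
import numpy as np

# Illustrative dimensions only; the real channel counts in the paper differ.
N_res, N_clust, N_extra, N_templ = 256, 128, 1024, 4
N_aa = 21           # 20 amino acid types + unknown
C_templ_pair = 64   # placeholder channel size for template pair features
C_templ_angle = 16  # placeholder channel size for template angle features

features = {
    # one-hot amino-acid type per residue of the target sequence
    "target_feat":         np.zeros((N_res, N_aa)),
    # per cluster x residue: one-hot centre sequence + continuous cluster profile
    "msa_feat":            np.zeros((N_clust, N_res, 2 * N_aa)),
    # like msa_feat but for the extra (non-centre) sequences, without the profile
    "extra_msa_feat":      np.zeros((N_extra, N_res, N_aa)),
    # per template x residue pair: 3D structural features of the template
    "template_pair_feat":  np.zeros((N_templ, N_res, N_res, C_templ_pair)),
    # per template x residue: torsion-angle features of the template
    "template_angle_feat": np.zeros((N_templ, N_res, C_templ_angle)),
}
for name, arr in features.items():
    print(name, arr.shape)
```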
Self-Distillation
75% of training examples from self-distillation set, 25% from the Protein Data Bank (PDB)
Self-distillation outline:
- Train model on basic dataset
- Perform inference on unlabelled dataset
- Measure confidence of predictions
- Add the most confident (sequence, predicted structure) pairs to the training set
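A high-level sketch of this pseudo-labelling loop; `model.predict` and the confidence threshold are hypothetical stand-ins, not the real training pipeline:

```python
def build_distillation_set(model, unlabelled_sequences, confidence_threshold):
    """Run the trained model over unlabelled sequences and keep only the
    predictions it is confident about as pseudo-labelled training examples.
    `model.predict` and its return values are illustrative stand-ins."""
    distillation_set = []
    for seq in unlabelled_sequences:
        structure, confidence = model.predict(seq)   # hypothetical API
        if confidence >= confidence_threshold:
            distillation_set.append((seq, structure))
    return distillation_set

# Training then samples ~75% of examples from the distillation set
# and ~25% from the PDB.
```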
Inference
Inference-specific processes:
- Recycling
- Ensembling
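A rough sketch of recycling at inference time, assuming a model callable that accepts its own previous outputs as extra inputs (the call signature here is hypothetical):

```python
def run_with_recycling(model, inputs, n_recycle=3):
    """Feed the model's previous outputs back in as extra inputs, refining the
    prediction over several passes (sketch; the signature is hypothetical)."""
    prev, outputs = None, None
    for _ in range(n_recycle + 1):
        outputs = model(inputs, prev=prev)  # hypothetical call signature
        prev = outputs                      # recycled representations / coordinates
    return outputs

# (Ensembling would additionally run the network on several re-sampled
#  versions of the inputs and combine the results; not shown here.)
```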
Input Feature Embeddings
Relpos
relpos is just an embedding corresponding to the relative distance between two residue indices
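A minimal sketch of a relpos-style embedding: clip the signed index offset, one-hot it, and project it linearly into the pair representation (the clip value and channel size below are placeholders):

```python
import numpy as np

def relpos_embedding(n_res, c_pair, max_offset=32, rng=None):
    """One-hot of the clipped relative index offset, linearly projected.
    Returns an [n_res, n_res, c_pair] array added to the initial pair rep."""
    rng = rng or np.random.default_rng(0)
    idx = np.arange(n_res)
    offset = idx[:, None] - idx[None, :]                   # signed distance i - j
    offset = np.clip(offset, -max_offset, max_offset) + max_offset
    one_hot = np.eye(2 * max_offset + 1)[offset]           # [n_res, n_res, 2*max_offset+1]
    w = rng.standard_normal((2 * max_offset + 1, c_pair))  # stand-in for a learned linear layer
    return one_hot @ w

pair_init = relpos_embedding(n_res=64, c_pair=128)
```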
Template embedding
- Each template independently processed by bottom-half evoformer blocks
- Attention computed between the main pair representation (as queries) and the per-template outputs (as keys/values)
- ↪️ this is done independently for each pair position ("pointwise"), attending over the template dimension
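A single-head sketch of that pointwise attention, with randomly initialised stand-in weights; the real module is multi-headed and parameterised differently:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def template_pointwise_attention(z, t, rng=None):
    """For each pair position (i, j) independently, the main pair rep z[i, j]
    attends over the per-template embeddings t[:, i, j] (single-head sketch).

    z: [N_res, N_res, c_z] main pair representation
    t: [N_templ, N_res, N_res, c_t] per-template outputs
    """
    rng = rng or np.random.default_rng(0)
    c_z, c_t = z.shape[-1], t.shape[-1]
    wq = rng.standard_normal((c_z, c_t)) / np.sqrt(c_z)  # stand-in weights
    wo = rng.standard_normal((c_t, c_z)) / np.sqrt(c_t)

    q = z @ wq                                            # [N_res, N_res, c_t]
    logits = np.einsum("ijc,sijc->ijs", q, t) / np.sqrt(c_t)
    attn = softmax(logits, axis=-1)                       # over the template dimension
    out = np.einsum("ijs,sijc->ijc", attn, t)             # weighted template values
    return z + out @ wo                                   # folded back into the pair rep
```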
Extra MSA stack
Modified evoformer blocks:
- ⬇️ representation sizes as ⬆️ number of sequences
- Uses global variant of column-wise self-attention:
- Keys and values share a single representation across attention heads
- Per-head queries are averaged over the sequence dimension, giving one query per head per column
(I assume this is a compute-saving measure as we have ⬆️ sequences)
(we use just the pairwise output representation)
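A simplified single-head sketch of the global column-wise attention; in the real module the pooled result is broadcast back to every sequence via gating, which is omitted here:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def global_column_attention(msa, rng=None):
    """Simplified single-head sketch of 'global' column-wise attention.

    msa: [N_seq, N_res, c]. For each column (residue position) the per-sequence
    queries are averaged into a single query, and keys/values come from one
    shared projection, so the cost grows linearly with the number of sequences.
    """
    rng = rng or np.random.default_rng(0)
    c = msa.shape[-1]
    wq, wk, wv = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(3))

    q = (msa @ wq).mean(axis=0)                  # [N_res, c]  one query per column
    k, v = msa @ wk, msa @ wv                    # [N_seq, N_res, c]
    logits = np.einsum("rc,src->rs", q, k) / np.sqrt(c)
    attn = softmax(logits, axis=-1)              # over sequences, per column
    return np.einsum("rs,src->rc", attn, v)      # one pooled vector per column
```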
Evoformer
Overview
Key principle: view protein structure prediction as graph inference problem where the presence of an edge indicates nearby residues
48 evoformer blocks - the relationship between the two matrices is refined many times!
Done via the two representations outlined above:
Pair representations: relationship info
MSA representations: relationship with homologous sequences per-residue
All layers have residual connections & many have dropout
MSA uses axial/criss-cross attention: attention for 3-D tensors where we compute attention independently across rows and then across columns
MSA row-wise gated self-attention with pair bias
What does row-wise mean? Attention requires a 2D matrix, with one sequence dimension and one channel dimension. We have a set channel dimension already here, so we just need to select one of the other two dimensions as our "sequence". The one we don't select is then processed independently across each slice.
Row-wise attention gives weights for residue-pairs
➡️ means we can combine in pair representations here
➡️ do this via added bias term
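A single-head sketch of row-wise gated self-attention with pair bias; the weights are random stand-ins and the multi-head structure is omitted:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def msa_row_attention_with_pair_bias(msa, pair, rng=None):
    """Single-head sketch: within each MSA row, residues attend to each other,
    with a bias derived from the pair representation added to the attention
    logits and a sigmoid gate on the output.

    msa:  [N_seq, N_res, c_m]
    pair: [N_res, N_res, c_z]
    """
    rng = rng or np.random.default_rng(0)
    c_m, c_z = msa.shape[-1], pair.shape[-1]
    wq, wk, wv, wg = (rng.standard_normal((c_m, c_m)) / np.sqrt(c_m) for _ in range(4))
    wb = rng.standard_normal((c_z, 1)) / np.sqrt(c_z)

    q, k, v = msa @ wq, msa @ wk, msa @ wv       # [N_seq, N_res, c_m]
    bias = (pair @ wb)[..., 0]                   # [N_res, N_res] bias from the pair rep
    logits = np.einsum("sic,sjc->sij", q, k) / np.sqrt(c_m) + bias
    attn = softmax(logits, axis=-1)              # over residues j, within each row s
    gate = sigmoid(msa @ wg)                     # output gating
    return gate * np.einsum("sij,sjc->sic", attn, v)
```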
MSA column-wise gated self-attention
Transition
Same operation for MSA & pairwise features
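A sketch of the transition as a position-wise two-layer MLP, assuming a 4x channel expansion (illustrative weights):

```python
import numpy as np

def transition(x, expansion=4, rng=None):
    """Position-wise two-layer MLP applied identically to every element of the
    MSA or pair representation (sketch; expansion factor assumed)."""
    rng = rng or np.random.default_rng(0)
    c = x.shape[-1]
    w1 = rng.standard_normal((c, expansion * c)) / np.sqrt(c)
    w2 = rng.standard_normal((expansion * c, c)) / np.sqrt(expansion * c)
    return np.maximum(x @ w1, 0.0) @ w2  # ReLU between the two linear layers
```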
Outer product mean
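The outer product mean turns the MSA representation into an update for the pair representation: project each position into two small vectors, take their outer product for every residue pair, and average over sequences. A sketch with illustrative channel sizes:

```python
import numpy as np

def outer_product_mean(msa, c_hidden=32, c_z=128, rng=None):
    """Sketch: project the MSA rep into two small per-residue vectors, take
    their outer product for every residue pair, and average over sequences to
    get an update for the pair representation (channel sizes illustrative).

    msa: [N_seq, N_res, c_m]  ->  returns [N_res, N_res, c_z]
    """
    rng = rng or np.random.default_rng(0)
    c_m = msa.shape[-1]
    wa = rng.standard_normal((c_m, c_hidden)) / np.sqrt(c_m)
    wb = rng.standard_normal((c_m, c_hidden)) / np.sqrt(c_m)
    wo = rng.standard_normal((c_hidden * c_hidden, c_z)) / np.sqrt(c_hidden * c_hidden)

    a, b = msa @ wa, msa @ wb                                 # [N_seq, N_res, c_hidden]
    outer = np.einsum("sic,sjd->ijcd", a, b) / msa.shape[0]   # mean over sequences
    n_res = outer.shape[0]
    return outer.reshape(n_res, n_res, -1) @ wo
```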
Triangular multiplicative update
The "incoming edges" variant is identical but uses the columns of the pair matrix instead of the rows (sketched below)
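For context, a sketch of the "outgoing edges" version (weights are illustrative stand-ins); the incoming version swaps rows for columns as noted above:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def triangle_multiplication_outgoing(pair, rng=None):
    """Sketch of the 'outgoing edges' multiplicative update: the new value for
    edge (i, j) combines rows i and j of the pair matrix, i.e. edges (i, k) and
    (j, k) for every third node k.

    pair: [N_res, N_res, c_z]
    """
    rng = rng or np.random.default_rng(0)
    c = pair.shape[-1]
    wa, wb, wga, wgb, wo = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(5))

    a = sigmoid(pair @ wga) * (pair @ wa)        # gated projection of edges (i, k)
    b = sigmoid(pair @ wgb) * (pair @ wb)        # gated projection of edges (j, k)
    update = np.einsum("ikc,jkc->ijc", a, b)     # combine over the third node k
    # Incoming variant instead combines column entries (k, i) and (k, j):
    # np.einsum("kic,kjc->ijc", a, b)
    return pair + update @ wo
```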
Triangular self-attention
- Row-wise gated self-attention + triangle update
- Row-wise means we calculate attention values between all pairs of column values for the given row ➡️ matrix of outgoing edge attn. values
- To complete the triangle we add these attention values to (a projection of) the values in the original tensor
Same for the "around ending node" variant, but computed across columns (incoming edges)
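A single-head sketch of the starting-node (row-wise) triangular self-attention, with the third-edge bias closing the triangle; weights are illustrative stand-ins:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def triangle_attention_starting_node(pair, rng=None):
    """Single-head sketch of row-wise (starting-node) triangular self-attention:
    within row i, edge (i, j) attends over edges (i, k), with an extra bias from
    the third edge (j, k) to close the triangle.

    pair: [N_res, N_res, c_z]
    """
    rng = rng or np.random.default_rng(0)
    c = pair.shape[-1]
    wq, wk, wv = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(3))
    wb = rng.standard_normal((c, 1)) / np.sqrt(c)

    q, k, v = pair @ wq, pair @ wk, pair @ wv
    bias = (pair @ wb)[..., 0]                                # [N_res, N_res], from edge (j, k)
    logits = np.einsum("ijc,ikc->ijk", q, k) / np.sqrt(c) + bias[None, :, :]
    attn = softmax(logits, axis=-1)                           # over k, for each edge (i, j)
    # The ending-node variant is the column-wise analogue of the same operation.
    return pair + np.einsum("ijk,ikc->ijc", attn, v)
```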