Transformers are becoming the go-to model architecture in deep learning. With vision transformers (see ViT) achieving performance competitive with CNNs on the largest image datasets, and a proliferation of transformer variants emerging, the field is moving closer to a unified architecture across modalities.
An important question, then, is what we can do to improve the transformer. One line of thought is simplicity: even if simplifying models initially hurts, perhaps in the long run it enables us to find breakthroughs that more complex architectures wouldn’t allow. Given the field seems to be moving towards simpler architectures, this is a reasonable research direction.
With this in mind, several transformer alternatives have been proposed that adapt the architecture to only use MLPs. We explore some of them here, and the trade-offs they bring.
For reference, here’s a standard ViT:
What they change w.r.t. the standard transformer encoder:
- tokens + embeddings ➡️ flattened patches + linear projections
That’s more-or-less it!
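To make the "flattened patches + linear projections" step concrete, here is a minimal NumPy sketch. The image size, patch size, embedding width and the names `patchify`, `W_embed` and `b_embed` are illustrative assumptions, not values from any particular ViT configuration.

```python
import numpy as np

def patchify(image, patch):
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    H, W, C = image.shape
    assert H % patch == 0 and W % patch == 0, "image must tile evenly"
    # (H/p, p, W/p, p, C) -> (H/p, W/p, p, p, C) -> (num_patches, p*p*C)
    x = image.reshape(H // patch, patch, W // patch, patch, C)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, patch * patch * C)

rng = np.random.default_rng(0)
image = rng.normal(size=(32, 32, 3))   # toy image (assumed size)
patch, d_model = 8, 64                 # assumed patch size and embedding width

# Hypothetical embedding weights: one linear projection shared by all patches.
W_embed = rng.normal(size=(patch * patch * 3, d_model)) * 0.02
b_embed = np.zeros(d_model)

patches = patchify(image, patch)       # one row per patch
tokens = patches @ W_embed + b_embed   # one row per token
print(patches.shape, tokens.shape)
```

Each row of `tokens` plays the role that an embedded word token plays in a text transformer.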
How to think about the token-mixing MLP
The channel-mixing MLP is just our regular FFN. It’s easiest to think of this in terms of a single token: the MLP takes the token, feeds its channels into two layers that mix them, and returns a new token of the same dimensionality. The token-mixing MLP does exactly the same thing, but with a new “token” created by slicing our batch vertically, i.e. taking a single channel across all the tokens. This token is a strange thing, but it simply represents the information spread across the batch in a single channel. The MLP then has the effect of mixing across the tokens, for each channel independently.
What they change w.r.t. ViT:
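The "same MLP, transposed input" view can be sketched in a few lines of NumPy. The sizes, the tanh-approximated GELU and the parameter names below are assumptions chosen for illustration; residual connections and normalisation are left out.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp(x, w1, b1, w2, b2):
    """Two-layer MLP applied independently to each row of x."""
    return gelu(x @ w1 + b1) @ w2 + b2

rng = np.random.default_rng(0)
T, C = 16, 64            # tokens, channels (assumed sizes)
D_tok, D_ch = 32, 256    # hidden widths (assumed)
x = rng.normal(size=(T, C))

# Channel-mixing parameters act on rows of length C.
ch_params = (rng.normal(size=(C, D_ch)) * 0.02, np.zeros(D_ch),
             rng.normal(size=(D_ch, C)) * 0.02, np.zeros(C))
# Token-mixing parameters act on rows of length T.
tok_params = (rng.normal(size=(T, D_tok)) * 0.02, np.zeros(D_tok),
              rng.normal(size=(D_tok, T)) * 0.02, np.zeros(T))

# Channel mixing: every token (row of x) goes through the same MLP.
channel_mixed = mlp(x, *ch_params)

# Token mixing: transpose so every channel becomes a "token" of length T,
# apply the same kind of MLP, then transpose back.
token_mixed = mlp(x.T, *tok_params).T

print(channel_mixed.shape, token_mixed.shape)
```

Perturbing a single token changes only that row of `channel_mixed`, while perturbing a single channel changes only that column of `token_mixed`, which is the "each channel independently" property described above.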
- self-attention ➡️ token mixing MLP
- single output head ➡️ average of outputs + MLP (a.k.a. Global Average Pooling (GAP))
- doesn’t use positional embeddings (token-mixing MLPs are sensitive to the order of the input tokens)
- network loses the variable-sequence-length property (unlike an SA layer, the input dimension of the token-mixing layer is fixed)
- they claim each MLP’s hidden size is independent of number of tokens & channels respectively, giving linear complexity (🧐 this seems a bit dubious as we typically set hidden size as a multiple of MLP input)
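The doubt in the last bullet can be checked with some back-of-the-envelope arithmetic. The sketch below counts weights and multiply-accumulates for a two-layer token-mixing MLP (biases ignored); the channel count, the fixed hidden size of 256 and the 4x multiplier are assumed values for illustration only.

```python
def token_mixing_cost(num_tokens, channels, hidden):
    """Weights and multiply-accumulates of a two-layer token-mixing MLP."""
    params = 2 * num_tokens * hidden   # (tokens x hidden) + (hidden x tokens)
    macs = channels * params           # the MLP is applied once per channel
    return params, macs

channels = 512  # assumed
for num_tokens in (196, 392, 784):
    # Hidden size held constant, as the linear-complexity claim assumes.
    fixed = token_mixing_cost(num_tokens, channels, hidden=256)
    # Hidden size set to a multiple of the input size, as is common practice.
    scaled = token_mixing_cost(num_tokens, channels, hidden=4 * num_tokens)
    print(num_tokens, "fixed hidden:", fixed, "scaled hidden:", scaled)
```

With a fixed hidden size the cost doubles when the number of tokens doubles; with a hidden size proportional to the number of tokens it quadruples, so the linear-complexity claim only holds if the hidden size really is chosen independently of the sequence length.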
Comparison with CNNs:
Token-mixing: single-channel depth-wise convolutions with a full receptive field and parameter sharing across channels
Channel-mixing: 1×1 convolutions
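Both correspondences can be checked numerically on a toy grid. This is a sketch with linear layers only (no hidden layer or activation), and the function `conv1x1` and all sizes are assumptions made for the example.

```python
import numpy as np

def conv1x1(x, kernel):
    """Naive 1x1 convolution. x: (H, W, C_in), kernel: (1, 1, C_in, C_out)."""
    H, W, _ = x.shape
    out = np.zeros((H, W, kernel.shape[-1]))
    for i in range(H):
        for j in range(W):
            out[i, j] = x[i, j] @ kernel[0, 0]
    return out

rng = np.random.default_rng(0)
H, W, C_in, C_out = 4, 4, 8, 16
grid = rng.normal(size=(H, W, C_in))

# Channel mixing: a linear layer over channels is a 1x1 convolution.
w_ch = rng.normal(size=(C_in, C_out))
as_conv = conv1x1(grid, w_ch[None, None])
as_linear = (grid.reshape(H * W, C_in) @ w_ch).reshape(H, W, C_out)
print(np.allclose(as_conv, as_linear))

# Token mixing: one (tokens x tokens) matrix spans the whole grid (full
# receptive field) and is reused for every channel (parameter sharing).
w_tok = rng.normal(size=(H * W, H * W))
tokens = grid.reshape(H * W, C_in)
all_at_once = w_tok @ tokens
per_channel = np.stack([w_tok @ tokens[:, c] for c in range(C_in)], axis=1)
print(np.allclose(all_at_once, per_channel))
```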
I think this is similar, but looks rather complicated. Will leave this for now...
What they change w.r.t. ViT:
- Just the token mixing MLP
What they change w.r.t. MLP-Mixer:
- Token mixer is a linear layer
- LayerNorm is replaced with a learned per-channel affine transformation (a scale and a shift, with no normalisation statistics)
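A rough sketch of a token-mixing sublayer with these two changes is below. The residual connection, the placement of the affine step before the mixer, and all names and sizes are assumptions for illustration rather than a faithful reproduction of the published block.

```python
import numpy as np

def affine(x, alpha, beta):
    """Per-channel scale and shift; unlike LayerNorm, no mean or variance is computed."""
    return alpha * x + beta

def linear_token_mixer(x, w, b):
    """x: (T, C), w: (T, T), b: (T,). Mixes tokens with a single linear layer."""
    return w @ x + b[:, None]

def token_mixing_sublayer(x, alpha, beta, w, b):
    # Assumed layout: affine -> linear token mixing -> residual add.
    return x + linear_token_mixer(affine(x, alpha, beta), w, b)

rng = np.random.default_rng(0)
T, C = 16, 64                          # assumed sizes
x = rng.normal(size=(T, C))
alpha, beta = np.ones(C), np.zeros(C)  # identity affine as a starting point
w = rng.normal(size=(T, T)) * 0.02
b = np.zeros(T)

out = token_mixing_sublayer(x, alpha, beta, w, b)
print(out.shape)
```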
What they change w.r.t. ViT:
- Remove self-attention entirely & modify the FFN
- Add spatial gating unit after activation
- Split the input (TODO: what size & ratio?)
- Do norm and spatial projection (linear token mixing) on one half, use it to gate the other
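The gating step in the last two bullets can be sketched as follows. The split is taken to be an even one along the channel axis, following "one half" above; the simplified `layer_norm` (no learned gain or bias), the sizes and the initial values are assumptions.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Simplified LayerNorm over the channel axis, without learned gain/bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def spatial_gating_unit(z, w_spatial, b_spatial):
    """z: (T, D) activations. Returns a (T, D/2) gated output."""
    u, v = np.split(z, 2, axis=-1)            # split channels into two halves
    v = layer_norm(v)                         # normalise the gating half
    v = w_spatial @ v + b_spatial[:, None]    # linear mixing across tokens
    return u * v                              # gate the other half element-wise

rng = np.random.default_rng(0)
T, D = 16, 128                                # assumed sizes
z = rng.normal(size=(T, D))
w_spatial = rng.normal(size=(T, T)) * 0.02
b_spatial = np.ones(T)

gated = spatial_gating_unit(z, w_spatial, b_spatial)
print(gated.shape)
```

In this sketch, zero spatial weights with a bias of one make the gate pass the first half through unchanged, which is a convenient sanity check.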
X: (H,W,C)  # Whole image!
X^ = X.reshape(H,W,N,S)
X_H = X^.einsum((H,W,N,S),(H,H^,S,S^)->(H^,W,N,S^))
X_W = X^.einsum((H,W,N,S),(W,W^,S,S^)->(H,W^,N,S^))
X_C = X .einsum((H,W,N,S),(N,N^,S,S^)->(H,W,N^,S^))  # same as (H,W,C),(C,C^)->(H,W,C^)
X_o = (X_H + X_W + X_C).reshape(H,W,C).proj(C,C^)
What’s this actually doing?
So we start with the entire image (assuming ) and split it into segments of size - these are groups in the channel dimension. We do the split such that .
[image]
And then we’re basically just going to take each axis and do a unique projection on that axis for each segment. Take the vertical axis. We’ll flatten into one dim, and into the other dim, then project/matmul on the latter. This is like doing a standard projection on , but we switch to a different projection every time we cross over a sub-patch boundary.
It’s not too hard to see that if we do this on the “depth” axis , this is just a standard channel projection for the whole image (the kind we do in the FFN layer). We then sum all these at the end, and do another channel projection.
What have we done overall? Basically a sum of three different segment-grouped matmuls.
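For reference, the pseudocode above can be written as runnable NumPy. This is only a sketch: the toy sizes are assumptions, output dimensions are taken to equal input dimensions (so H^ = H, and so on), and the weight names `W_h`, `W_w`, `W_c` and `W_o` are made up for the example. N and S are simply the two factors of the channel dimension implied by the reshape, with C = N * S.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, N, S = 4, 4, 2, 3          # toy sizes (assumed)
C = N * S                        # required by the reshape below

X = rng.normal(size=(H, W, C))   # whole image
Xs = X.reshape(H, W, N, S)       # split the channel axis into (N, S)

# One weight tensor per axis, laid out as in the pseudocode.
W_h = rng.normal(size=(H, H, S, S))   # (H, H^, S, S^)
W_w = rng.normal(size=(W, W, S, S))   # (W, W^, S, S^)
W_c = rng.normal(size=(N, N, S, S))   # (N, N^, S, S^)
W_o = rng.normal(size=(C, C))         # final channel projection

X_h = np.einsum('hwns,hgst->gwnt', Xs, W_h)   # mixes H and S jointly
X_w = np.einsum('hwns,wgst->hgnt', Xs, W_w)   # mixes W and S jointly
X_c = np.einsum('hwns,nmst->hwmt', Xs, W_c)   # mixes N and S, i.e. all of C

X_o = (X_h + X_w + X_c).reshape(H, W, C) @ W_o
print(X_o.shape)

# The third branch is an ordinary channel projection with a reshaped weight.
W_c_flat = W_c.transpose(0, 2, 1, 3).reshape(C, C)
print(np.allclose(X_c.reshape(H, W, C), X @ W_c_flat))
```

The final check confirms the comment in the pseudocode: contracting over both N and S is the same as a plain (C, C) projection of every pixel.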
ViT (assuming linear projection):
Recalling that we flatten each patch’s width, height and channel into a single token of size :
Params: , ,
Flops: , ,
Flops/param: , ,
It makes sense that the flops/param ratio is so much higher, as we basically “stripe” our matmuls across almost the whole image. I’m not really sure what I make of this. The grouping thing seems interesting, a lot like a grouped convolution. I need to think more about convs to understand this better, and just generally think about it a bit more...
- Why does MLP-Mixer use GAP when the default in ViT is a single output head (though they found it didn’t matter)? I guess it’s conceptually a bit nicer, so why not?