Where Does the Memory Go?
(Taken from the DeepSpeed ZeRO paper - note: this is a really good source)
The following needs to be stored:
Permanent:
- parameters
- optimiser states
Temporary:
- batches of training data
- activations (can then forget training data, but activations needed for 🔙-wards pass)
- batches of labels (just to compute loss)
- gradients (can then forget activations)
The three big things that take up memory (as they are O(#param)) are:
- parameters (can be fp16) → 2Ψ bytes
- gradients (can be fp16) → 2Ψ bytes
- optimiser states (in examples I've seen, fp32) → 12Ψ bytes for Adam:
  - params (in higher precision) → 4Ψ bytes
  - 1st moment estimates → 4Ψ bytes
  - 2nd moment estimates → 4Ψ bytes
Total = 16Ψ bytes, where Ψ is the number of parameters
Note: for CNN-based models we have smaller models (e.g. 500M params) and the bottleneck is activation memory. However, for transformer-based models it's the model states (i.e. the O(#param) quantities above).
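As a sanity check on the byte counts above, a minimal sketch in plain Python; the 1.5B-parameter example reproduces the GPT-2 figure (~24 GB of model states) quoted in the ZeRO paper.

```python
def model_state_gb(num_params: int) -> dict:
    """Model-state memory for mixed-precision Adam, as counted in the ZeRO paper:
    2 (fp16 params) + 2 (fp16 grads) + 12 (fp32 params + 1st/2nd moments) = 16 bytes/param."""
    params_b = 2 * num_params
    grads_b = 2 * num_params
    optim_b = (4 + 4 + 4) * num_params
    gb = 1e9
    return {
        "params_GB": params_b / gb,
        "grads_GB": grads_b / gb,
        "optimiser_GB": optim_b / gb,
        "total_GB": (params_b + grads_b + optim_b) / gb,
    }

# e.g. GPT-2 (1.5B params) needs ~24 GB of model states alone,
# before counting activations, buffers and fragmentation.
print(model_state_gb(1_500_000_000))
```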
Parallelism
Data parallelism:
- Same model/weights for each device/thread
- Different mini-batches
- Gradients need to be synchronized (averaged) after each pass through a mini-batch
Model parallelism:
- Same mini-batches for each device/thread
- Split model/weights
- Outputs need to be synchronised (stacked) after each layer
Data Parallelism
Drawback:
During backwards pass whole gradient must be sent between GPUs at each layer.
We therefore experience a slowdown relative to the number of weights in the model.
(very rough) example: we want to send 1GB of weight gradients across 32 GB/s PCIe connection → this will take 1/32 seconds!
Some of this overhead can be hidden by overlapping communication with computation: while the gradients of one layer (starting from the last layer, which is the first to finish in the backward pass) are being sent, the backward pass for the next (earlier) layer can already run.
Also if batch size becomes too small we no longer utilise hardware well.
➡️ diminishing returns for scaling:
Data parallelism scales well for models like ResNet-50, which have a large number of convolutional layers with compact weight representations, but scales less well for other models with LSTM or fully-connected layers, which have more dense weight representations. (PipeDream paper)
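As an illustration, a minimal sketch of the synchronisation step using torch.distributed directly (in practice DistributedDataParallel does this for you and also overlaps communication with the backward pass as described above); it assumes the process group has already been initialised, e.g. via torchrun.

```python
import torch
import torch.distributed as dist

def allreduce_gradients(model: torch.nn.Module) -> None:
    """Average gradients across all data-parallel workers after loss.backward()."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this gradient tensor over every worker, then divide to average.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

# Per-step usage (each worker sees a different mini-batch):
#   loss = criterion(model(inputs), targets)
#   loss.backward()
#   allreduce_gradients(model)  # every worker now holds identical averaged grads
#   optimizer.step()
```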
Model Parallelism
(sometimes referred to as Tensor Parallelism)
There are two ways we can split a weight matrix across devices: along either the input or output dimensions:
(k devices, weight W with dims (n_in × n_out), batch size b)
Split on input dim
Forward pass:
- split the batch on the input dim too: device i holds X_i of shape (b, n_in/k) and W_i of shape (n_in/k, n_out)
- across the devices, multiply: Y_i = X_i · W_i, each of shape (b, n_out)
- these represent the contribution to the output of the device's selected input dims
- add (all-reduce) the k partial outputs to get the total output Y
- Total exchange size ≈ b·n_out per device
Backward pass:
- the batch of output errors ∂L/∂Y (shape (b, n_out)) is not split
- across the devices, multiply: ∂L/∂X_i = ∂L/∂Y · W_iᵀ, each of shape (b, n_in/k)
- these represent the contribution to the error of the device's selected input dims
- stack (all-gather) the k pieces to get the total error ∂L/∂X
- Total exchange size ≈ b·n_in/k per device
Total exchange size ≈ b·(n_out + n_in/k) ≈ b·n_out for large k. Best choice for large n_in.
Split on output dim
Forward pass:
- batch not split: every device holds the full X of shape (b, n_in), and W_i has shape (n_in, n_out/k)
- across the devices, multiply: Y_i = X · W_i, each of shape (b, n_out/k)
- these represent the contribution to the output of the device's selected output dims
- stack (all-gather) the k pieces to get the total output Y
- Total exchange size ≈ b·n_out/k per device
Backward pass:
- split the batch of output errors on the output dim: device i holds ∂L/∂Y_i of shape (b, n_out/k)
- across the devices, multiply: ∂L/∂X_i = ∂L/∂Y_i · W_iᵀ, each of shape (b, n_in)
- these represent the contribution to the error of the device's selected output dims
- add (all-reduce) the k partial errors to get the total error ∂L/∂X
- Total exchange size ≈ b·n_in per device
Total exchange size ≈ b·(n_in + n_out/k) ≈ b·n_in for large k. Best choice for large n_out.
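To make the two splits concrete, here is a tiny single-process sketch in NumPy, with "devices" simulated as list entries and the symbols b, n_in, n_out, k as above; it checks that the input-dim split needs a sum over partial outputs while the output-dim split needs a concatenation.

```python
import numpy as np

rng = np.random.default_rng(0)
b, n_in, n_out, k = 4, 8, 6, 2      # batch size, input dim, output dim, "devices"
X = rng.normal(size=(b, n_in))
W = rng.normal(size=(n_in, n_out))
Y = X @ W                           # reference: full (unsplit) forward pass

# Split on input dim: device i holds a (n_in/k, n_out) slice of W and the matching
# (b, n_in/k) slice of X; the partial outputs must be *added* (all-reduce).
X_parts = np.split(X, k, axis=1)
W_rows = np.split(W, k, axis=0)
Y_input_split = sum(Xi @ Wi for Xi, Wi in zip(X_parts, W_rows))

# Split on output dim: device i holds a (n_in, n_out/k) slice of W and the full X;
# the partial outputs must be *stacked* (all-gather).
W_cols = np.split(W, k, axis=1)
Y_output_split = np.concatenate([X @ Wi for Wi in W_cols], axis=1)

assert np.allclose(Y, Y_input_split)
assert np.allclose(Y, Y_output_split)
```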
Compare this to data parallelism, which exchanges n_in·n_out values (the full weight gradient). Hence as long as b < n_in (input split) or b < n_out (output split), model parallelism looks promising.
It's harder to implement though, and whereas speedups (e.g. asynchronous updates) are possible for data parallelism, it's not clear that the same can be done for model parallelism.
For a fixed layer size, splitting across more and more devices eventually becomes too granular to be worthwhile.
Pipelining
(Confusingly, this is often known as model parallelism)
Subsequent groups of layers on different devices.
Benefits:
- Only have to communicate activations and their gradients across the network (small relative to total weights)
- One-to-one communication, rather than all-to-all
Problem: standard pipelining has very poor utilisation!
Solution: use an async approach (see image)
Further problem: this can lead to stale params ➡️
- The forward pass and the backward pass end up using different versions of the weights
- The param updates are applied to a changed set of weights
- Earlier layers are more stale
Further solution: we can address the first issue (the most damaging one) by storing multiple versions of the weights (memory issues!?)
Hybrid approach: use gradient accumulation across a number of microbatches, interleaving forward and backward passes → then occasionally do full synchronous (blocking) update (see image)
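For the basic idea (before any scheduling tricks), a naive two-stage sketch in PyTorch of placing consecutive groups of layers on different devices; the device names are assumptions (swap in "cpu" to run without GPUs), and real pipeline engines (GPipe, PipeDream, DeepSpeed) add the micro-batch scheduling discussed above.

```python
import torch
import torch.nn as nn

class TwoStagePipeline(nn.Module):
    """First half of the layers on one device, second half on another."""

    def __init__(self, dim: int = 1024, dev0: str = "cuda:0", dev1: str = "cuda:1"):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        self.stage0 = nn.Sequential(nn.Linear(dim, dim), nn.ReLU()).to(dev0)
        self.stage1 = nn.Sequential(nn.Linear(dim, dim), nn.ReLU()).to(dev1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stage0(x.to(self.dev0))
        # Only the activations cross the device boundary, never the weights.
        return self.stage1(x.to(self.dev1))

# Usage:
#   out = TwoStagePipeline()(torch.randn(32, 1024)); out.sum().backward()
# While stage1 works on a batch, stage0 is idle (the pipeline "bubble") unless we
# feed it the next micro-batch, which is exactly what pipeline schedules arrange.
```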
Distributed SGD
Synchronous
Nodes compute gradients locally and send to a central server, which aggregates and sends out updates to waiting nodes.
Drawbacks:
- As slow as the slowest node (scales poorly) → mitigate with backup replicas
- Single point of failure → mitigate with an all-reduce-type algorithm
- Network congestion → mitigate with tree-like updates
Asynchronous
Nodes no longer wait for "synchronisation barrier".
Drawbacks:
- Stale gradients can lead to performance drop
- Still single point of failure
Gradient accumulation
- We want to optimise over a minibatch of size B
- But we can't fit a batch that large into memory
- So we split it into m microbatches of size B/m, run a forward and backward pass for each, accumulate the gradients, and only apply the optimiser step once all m microbatches have been processed
- Note that if we instead ran all the forward passes before any backward pass, we would store just as many activations as for the full minibatch; interleaving forward and backward per microbatch is what keeps activation memory down
In a distributed setup we would ideally do a local reduce before the all-reduce step.
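A minimal sketch of the accumulation loop in PyTorch, using the per-microbatch-backward formulation described above; the loss is scaled by 1/m so the accumulated gradient matches the full-minibatch average.

```python
import torch

def accumulation_step(model, optimizer, criterion, microbatches):
    """One optimiser step over a minibatch that has been split into m microbatches."""
    optimizer.zero_grad()
    m = len(microbatches)
    for inputs, targets in microbatches:
        loss = criterion(model(inputs), targets)
        # Scale so the accumulated gradient equals the full-minibatch average.
        (loss / m).backward()   # gradients accumulate into param.grad
    optimizer.step()
```

With DistributedDataParallel, wrapping all but the last microbatch in the model's no_sync() context skips the per-microbatch all-reduce, giving exactly the "local reduce before the all-reduce" mentioned above.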
Gradient Checkpointing / Rematerialisation
For a simple feed-forward NN with n layers, the (forward and backward) computation graph is as follows:
The standard (computationally efficient) approach keeps the following purple values in memory at any given time:
We have to hold all the forward activations at once!
The checkpointing (memory efficient) approach keeps only selected (in this case the first) activations and re-computes as appropriate:
If we select k evenly-spaced checkpoints:
- memory from O(n) ➡️ O(k + n/k)
- computation from O(n) ➡️ O(n) (roughly one extra forward pass, since each activation is recomputed at most once)
If we set k = √n then:
- memory from O(n) ➡️ O(√n)
- computation from O(n) ➡️ O(n)
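PyTorch exposes exactly this trade-off via torch.utils.checkpoint; a minimal sketch with checkpoint_sequential, which keeps only the activations at segment boundaries and recomputes everything inside a segment during the backward pass (the layer count and sizes here are arbitrary).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

n_layers, dim = 16, 512
model = nn.Sequential(*[nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
                        for _ in range(n_layers)])
x = torch.randn(32, dim, requires_grad=True)

# With segments ≈ √n_layers, only the segment-boundary activations are kept live;
# activations inside each segment are recomputed during the backward pass.
out = checkpoint_sequential(model, 4, x)
out.sum().backward()
```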
Given a complex computation graph, checkpointing certain nodes will give more of an advantage than others (depending on how many nodes it allows us to recompute). Specifically, we ideally want to checkpoint nodes which if removed would give us disjoint subgraphs. We call these articulation points (if none satisfy this, we may need to select a group). E.g.:
Load Balancing
Problem: in the above image a lot of work is done solely on GPU-1 → we have to store the losses and gradients at the output layer across all distributed batches. This can add significant overhead.
Solution: simply calculate the loss on each individual machine by spreading the labels:
DeepSpeed
Features:
- Mixed precision
- Model, data and pipeline parallelism (3D-parallelism)
- Zero Redundancy Optimizer (ZeRO)
- Dense transformer kernels
- Sparse attention
- 1-bit Adam & 1-bit LAMB
- Smart Gradient Accumulation
- Communication/Computation Overlap
- Advanced Parameter Search
- Progressive Layer Dropping
3D parallelism
All three types of parallelism in-one:
An example of micro-batch training with pipeline and data parallelism. Here Step is the all-reduce operation that updates parameters:
Zero-redundancy optimiser (ZeRO)
Uses data parallelism, but without replicating data. Instead partitions optimiser, gradient and parameter state and broadcasts when necessary:
- The processor/node that has the "current" params broadcasts them to all others
- They store them in a temporary buffer and then compute
- For the backwards pass, weights are broadcast in the same way
- So are final gradients, with each processor/node storing the aggregate grads of the params they "own"
- Optimisation done in parallel on each node - nothing needs to be communicated here!
- Activations can also be checkpointed and partitioned. An all-gather is used at the start of re-construction.
Linear memory reduction - e.g. with 64 GPUs, each holds only 1/64 of the model states needed under plain data parallelism, with only a 1.5× increase in communication volume.
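Concretely, taking the ZeRO paper's running example of Ψ = 7.5B parameters on N_d = 64 GPUs, partitioning all three model states gives:

16Ψ = 16 × 7.5B ≈ 120 GB of model states unpartitioned → 16Ψ / N_d = 120 GB / 64 ≈ 1.9 GB per GPU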
ZeRO vs Model Parallelism:
ZeRO:
- Doesn't have as bad communication issues as MP
- Less complex to implement
MP:
- Can reduce activation footprint
- Doesn't have the DP/ZeRO problem of min per-device minibatch size causing total minibatch size to increase
Can combine both!
Results: super-linear throughput scaling for a fixed 60B-parameter model, owing to the reduced memory footprint allowing larger per-GPU batch sizes.
ZeRO-Offload
Can reduce the huge memory requirements of Adam by offloading the optimiser states and update step to CPU memory and CPU compute (using an optimised CPU implementation of the algorithm).
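For reference, a sketch of the relevant part of a DeepSpeed config (shown as a Python dict; it can equally live in a JSON file passed to deepspeed.initialize()) enabling ZeRO with the optimiser offloaded to CPU. The batch-size numbers are placeholders, and the exact stage/offload combination should be checked against the DeepSpeed documentation for your version.

```python
# Sketch of a DeepSpeed config: ZeRO stage 2 (partition optimiser states + gradients)
# with the Adam states and update step offloaded to CPU (ZeRO-Offload).
ds_config = {
    "train_micro_batch_size_per_gpu": 4,   # placeholder
    "gradient_accumulation_steps": 8,      # placeholder
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
    },
}
```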
ZeRO-Infinity
Training a trillion-parameter model currently requires ~800 V100s just to satisfy the memory requirements (e.g. using 3D parallelism).
ZeRO-Infinity can fit models with tens and even hundreds of trillions of parameters for training on current generation GPU clusters
Can be used to fine-tune trillion parameter models on a single NVIDIA DGX-2 node (16 GPUs)
GPT-3
We use a mixture of model parallelism within each matrix multiply and model parallelism across the layers of the network.
Also uses huge batch sizes:
Apparently with larger models, larger batch sizes can be used, although it starts smaller than reported here (see scaling laws paper).