Introduction
Hardware challenges of LLM inference:
- Inference with 16-bit 175B-param models has massive memory (350 GB of weights, i.e. ~5x A100-80GB) & latency overheads (arithmetic spelled out below)
- INT8 quant of weights & acts is a good approach to reduce this
- roughly halves memory & reduces latency
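Quick back-of-envelope arithmetic behind those numbers (assuming FP16 weights, 80 GB A100s, and ignoring activations/KV cache):

$$
175\,\mathrm{B} \times 2\ \mathrm{bytes\ (FP16)} = 350\ \mathrm{GB} \approx 5 \times 80\ \mathrm{GB\ (A100)},\qquad
175\,\mathrm{B} \times 1\ \mathrm{byte\ (INT8)} = 175\ \mathrm{GB}
$$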
LLM-specific issues:
- Unlike in smaller models, activations become hard to quantise to INT8, due to systematic outlier channels
- Leads to large quantization errors and accuracy degradation
INT8 quantisation solutions:
- ZeroQuant:
- activation quant = per-token
- weight quant = groups of channels
- Fine for GPT-J (6B)
- Degrades for OPT (175B)
- LLM.int8():
- Fixes this using mixed-precision decomposition (see the sketch after this list)
- Outlier channels are kept in a separate FP16 tensor/matmul
- Not hardware-friendly: the decomposition adds significant latency
- SmoothQuant is best of both (without needing QAT)
- Observation: although activations contain problematic outliers, they are in consistent channels
- Based on this, SmoothQuant scales activation channels to be more similar in magnitude, and adjusts the weights accordingly
- This moves the quantisation difficulty from the (varying) activations to the (static) weights
- Simple to implement & integrated into FasterTransformer lib
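As a rough illustration of the LLM.int8() idea above (a toy numpy sketch, not the bitsandbytes implementation; the threshold value and function names are my own):

```python
import numpy as np

def quantize_per_tensor(t):
    """Symmetric per-tensor INT8 quantisation: int8 values plus one FP scale."""
    scale = np.abs(t).max() / 127.0
    q = np.clip(np.round(t / scale), -127, 127).astype(np.int8)
    return q, scale

def llm_int8_style_matmul(X, W, threshold=6.0):
    """Toy mixed-precision decomposition: input channels of X containing outliers
    stay in floating point; the remaining channels use an INT8 matmul."""
    outlier = np.abs(X).max(axis=0) > threshold           # [in_features] bool mask
    y_fp = X[:, outlier] @ W[outlier, :]                  # FP path for outlier channels
    Xq, sx = quantize_per_tensor(X[:, ~outlier])          # INT8 path for the rest
    Wq, sw = quantize_per_tensor(W[~outlier, :])
    y_int8 = (Xq.astype(np.int32) @ Wq.astype(np.int32)) * (sx * sw)
    return y_fp + y_int8

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 64))
X[:, 3] *= 50.0                                           # one systematic outlier channel
W = rng.normal(size=(64, 16))
print(np.abs(llm_int8_style_matmul(X, W) - X @ W).max())  # small error despite the outlier
```

The extra split plus the separate FP16 GEMM is roughly why this decomposition costs latency on real hardware.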
Problem Analysis
- Recall that there are two steps here:
- Scale down for the INT8 quant
- Re-scale the values back up at the mathematically appropriate time (depends on type of quant)
- Hardware can efficiently support two kinds of quantisation scale: per-tensor, and outer-dim (per-token for activations; per-output-channel for weights) (see the sketch after this list):
- Outer-dim re-scaling can be applied entirely after the matmul
- However, inner-dim re-scaling would have to happen inside the sum-reduction of the matmul itself; since the INT8 GEMM is a fixed hardware kernel, this isn't feasible
- Activations are hard to quantise (see Fig. 3 of the paper) due to outlier channels
- Per-tensor quantisation is distorted by the outlier channels
- The only other hardware-friendly option is per-token, but each token still shares one scale across all channels, so the same problem remains
- If we could do inner-dim (per-channel) quantisation of activations it would fix the issue, but as above it isn't hardware-feasible
- Weights are much easier to quantise as they have no outliers
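A small numpy sketch (my own illustration, not code from the paper) of why outer-dim scales are hardware-friendly: per-token scales for X and per-output-channel scales for W factor out of the INT8 matmul as a rescale applied afterwards, whereas a per-input-channel scale would sit inside the reduction.

```python
import numpy as np

def quantize_rowwise(t):
    """Symmetric INT8 quantisation, one scale per row (per-token for activations)."""
    scale = np.abs(t).max(axis=1, keepdims=True) / 127.0   # [rows, 1]
    q = np.clip(np.round(t / scale), -127, 127).astype(np.int8)
    return q, scale

def quantize_colwise(t):
    """Symmetric INT8 quantisation, one scale per column (per-output-channel for weights)."""
    scale = np.abs(t).max(axis=0, keepdims=True) / 127.0   # [1, cols]
    q = np.clip(np.round(t / scale), -127, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 32))    # [tokens, in_channels]
W = rng.normal(size=(32, 16))   # [in_channels, out_channels]

Xq, sx = quantize_rowwise(X)    # outer dim of X = tokens
Wq, sw = quantize_colwise(W)    # outer dim of W = output channels

# INT8 GEMM accumulated in INT32, then one outer-dim rescale applied afterwards:
Y = (Xq.astype(np.int32) @ Wq.astype(np.int32)) * sx * sw   # broadcasts to [tokens, out_channels]
print(np.abs(Y - X @ W).max())  # small quantisation error

# A per-INPUT-channel scale would instead sit between Xq and Wq inside the
# reduction (X @ diag(s) @ W), so it cannot be applied after the hardware GEMM.
```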
Method
SmoothQuant is a pre-processing step, after which you can use a standard quantisation scheme (they use three efficiency levels, see Table 3).
Note: it is not inner-dim quantisation!
Taking the original activations X and weights W, with a per-input-channel smoothing factor s, consider the following equivalence:
$$Y = (X \operatorname{diag}(s)^{-1}) \cdot (\operatorname{diag}(s) W) = \hat{X} \hat{W}$$
Mathematically, equivalence is preserved, but computationally:
LHS factor ($\hat{X} = X \operatorname{diag}(s)^{-1}$): the scale is applied before the matmul (potentially fused into a previous op, e.g. the LayerNorm)
RHS factor ($\hat{W} = \operatorname{diag}(s) W$): the scale is folded into the weights (offline)
Hence this transformation can all be done cheaply.
As $s_j$ increases above 1, it begins to transfer channel-outlierness from activations to weights!
The choice of $s$ determines this trade-off. At the extremes, it pushes all of the problem into one of the two tensors. We can interpolate between these using the migration strength $\alpha$:
$$s_j = \frac{\max(|X_j|)^{\alpha}}{\max(|W_j|)^{1-\alpha}}$$
Typically $\alpha = 0.5$ is a good balance. Quantisation suffers if either tensor is too "spiky" along the inner dim.
We can understand SmoothQuant as effectively sharing the outlier problem between weights and acts to mitigate it.
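A minimal numpy sketch of the smoothing step following those formulas (the function name and toy shapes are my own; in practice the division by s is fused into the preceding op rather than applied at runtime):

```python
import numpy as np

def smooth(X, W, alpha=0.5, eps=1e-8):
    """SmoothQuant-style smoothing: divide activations and multiply weights by a
    per-input-channel factor s, so neither side keeps extreme outlier channels.
    X: [tokens, in_channels] calibration activations; W: [in_channels, out_channels]."""
    act_max = np.abs(X).max(axis=0)                      # per-channel activation range
    w_max = np.abs(W).max(axis=1)                        # per-channel weight range
    s = (act_max ** alpha) / np.maximum(w_max ** (1 - alpha), eps)
    s = np.maximum(s, eps)
    X_hat = X / s                                        # applied online / fused into the previous op
    W_hat = W * s[:, None]                               # folded into the weights offline
    return X_hat, W_hat, s

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 64))
X[:, 7] *= 100.0                                         # a systematic outlier channel
W = rng.normal(size=(64, 16))

X_hat, W_hat, s = smooth(X, W, alpha=0.5)
print(np.allclose(X_hat @ W_hat, X @ W))                 # True: mathematically equivalent
spread = lambda A: np.abs(A).max(axis=0).max() / np.abs(A).max(axis=0).min()
print(spread(X), spread(X_hat))                          # channel ranges of X_hat are far flatter
```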
Three efficiency levels of quantisation (O1-O3, Table 3) are considered for use with SmoothQuant. Note that we no longer need to resort to per-channel/group quantisation of the weights.
Experiments
Experiments show no degradation even for SmoothQuant-O3 across tasks and models (OPT, BLOOM, GLM).
The only other method without degradation (the others degrade badly) is LLM.int8(), but it has nearly 2x the latency.
The alpha sweet spot is relatively narrow, so this method doesn't seem super robust! Would it work at 1T-parameter scale?