Contents
- Introduction
- Basics
  - Einsum syntax
  - Einsum key terminology
  - Einsum algorithm
  - What does it mean to "sum across" a dimension in a tensor?
- Advanced
  - How does einsum support broadcasting? Give an example for batch matrix multiplication.
  - How do NumPy and PyTorch einsum differ?
- Examples
- How to think about multi-rank tensor operations
  - How to think about a tensor
  - Outer product (our ‘default’)
  - Element-wise (dim-sharing)
  - Inner product
  - Convolution
  - Particular types of convolution
Introduction
Dealing with linear algebra for tensors of rank > 2 (i.e. beyond scalars, vectors and matrices) can be a challenge. We're used to matrix operations, but things get trickier once our tensors have more dimensions than that.
Fortunately, someone has helped us out: none other than Albert Einstein! Using Einstein notation, we can express these more complex linear algebra operations in a compact mathematical form.
Inspired by this, the folks at NumPy introduced einsum, a simple way of dealing with multi-rank tensors. Here's how it works.
Basics
Einsum syntax
Einsum enables multiplication, summation, and permutation of input arrays (operands). For each operand we supply a string of subscript labels, one label per dimension. We also provide output labels, which determine how the operands are combined: labels that appear in the output are kept (in the order written), while labels missing from the output are summed over.
np.einsum(pattern, A, B, ...)
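pattern takes the form of comma-separated subscript labels, one group per operand, followed by -> and the output labels, e.g. 'ij,jk->ik' for matrix multiplication.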
Einsum key terminology
- Operand: an input array
- Subscript label: a name assigned to one dimension (axis) of an operand
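Einsum algorithm
If -> is omitted (NumPy's implicit mode), the output labels are filled in with the letters that appear only once, in alphabetical order. This is different from writing -> with an empty right-hand side (RHS), which sums over every label and returns a scalar.
The logic of einsum is then as follows: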
- Assign an index counter to every unique letter.
- Count across the letters on the RHS (if there are no letters on the RHS, still do this once). At each step:
  - Start a running sum at zero.
  - Count across the letters not on the RHS (if every letter is on the RHS, still do this once). At each step:
    - Index into each array and multiply the resulting scalars.
    - Add the product to the running sum.
  - Store the sum in the output array at the position given by the RHS letters, in the order they appear on the RHS.
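The counting procedure above can be sketched as nested Python loops. This is a deliberately slow, minimal sketch: `naive_einsum` is a hypothetical helper (not part of NumPy), it only supports plain letter labels (no ellipsis), and it applies the implicit-mode rule described above when `->` is omitted. `np.einsum` is used only to check the result.

```python
import itertools

import numpy as np


def naive_einsum(pattern, *operands):
    """Hypothetical loop-based einsum; letter labels only, no '...' support."""
    pattern = pattern.replace(" ", "")
    if "->" in pattern:
        lhs, rhs = pattern.split("->")
    else:
        # Implicit mode: output = letters appearing exactly once, alphabetically.
        lhs = pattern
        letters = lhs.replace(",", "")
        rhs = "".join(sorted(c for c in set(letters) if letters.count(c) == 1))
    in_labels = lhs.split(",")
    if len(in_labels) != len(operands):
        raise ValueError("one label string is needed per operand")

    # Assign an index counter (here: its range) to every unique letter.
    sizes = {}
    for labels, op in zip(in_labels, operands):
        for letter, size in zip(labels, op.shape):
            if sizes.setdefault(letter, size) != size:
                raise ValueError(f"inconsistent size for label {letter!r}")
    summed = [c for c in sizes if c not in rhs]

    out = np.zeros([sizes[c] for c in rhs])
    # Count across the RHS letters; product() of nothing still yields once.
    for out_idx in itertools.product(*(range(sizes[c]) for c in rhs)):
        idx = dict(zip(rhs, out_idx))
        total = 0.0
        # Count across the non-RHS letters (also runs once if there are none).
        for sum_idx in itertools.product(*(range(sizes[c]) for c in summed)):
            idx.update(zip(summed, sum_idx))
            prod = 1.0
            # Index into each array and multiply the resulting scalars.
            for labels, op in zip(in_labels, operands):
                prod *= op[tuple(idx[c] for c in labels)]
            total += prod  # add the product to the running sum
        # Store at the position given by the RHS letters, in RHS order.
        out[out_idx] = total
    return out


A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
print(np.allclose(naive_einsum("ij,jk->ik", A, B), A @ B))  # True
```

The loops make the cost visible: the work is the product of the sizes of every label, which is why real implementations reorder and vectorise these loops rather than running them literally.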
What does it mean to "sum across" a dimension in a tensor?
Index into the tensor at every position along that dimension, take the resulting sub-tensors, and return their element-wise sum.
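A small NumPy check of this definition, using an arbitrary 2x3x4 tensor: summing the sub-tensors by hand matches `T.sum(axis=1)`, and also the einsum pattern that simply leaves `j` off the RHS.

```python
import numpy as np

T = np.arange(24).reshape(2, 3, 4)  # rank-3 tensor with dims (i, j, k)

# Index into dim j at every position; each T[:, j, :] is a rank-2 sub-tensor.
subtensors = [T[:, j, :] for j in range(T.shape[1])]

# The element-wise sum of those sub-tensors is "summing across" dim j.
summed = sum(subtensors)

print(summed.shape)  # (2, 4)
```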
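Advanced
How does einsum support broadcasting? Give an example for batch matrix multiplication.
Replace the broadcast (batch) dimensions with an ellipsis. Batch matrix multiplication then becomes einsum('...ij,...jk', A, B), or einsum('...ij,...jk->...ik', A, B) with an explicit RHS.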
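A sketch of batch matrix multiplication with NumPy's ellipsis, assuming random placeholder matrices. It compares the explicit and implicit forms against `A @ B`, and shows the ellipsis broadcasting a single matrix (with no batch dims) across the whole batch.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 2, 3))  # a batch of five 2x3 matrices
B = rng.standard_normal((5, 3, 4))  # a batch of five 3x4 matrices

# '...' stands in for the leading batch dim(s); 'j' is summed away.
C = np.einsum("...ij,...jk->...ik", A, B)

# Implicit mode gives the same result: '...' first, then the unshared letters.
C_implicit = np.einsum("...ij,...jk", A, B)

# The ellipsis also broadcasts: a single 3x4 matrix with no batch dims
# is reused for every matrix in the batch.
B_single = rng.standard_normal((3, 4))
C_broadcast = np.einsum("...ij,...jk->...ik", A, B_single)

print(C.shape, C_broadcast.shape)  # (5, 2, 4) (5, 2, 4)
```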
How do NumPy and PyTorch einsum differ?
PyTorch allows the dimensions covered by the ellipsis to be summed over (i.e. left off the RHS).
Examples
Let A and B be two 1D arrays of compatible shapes.
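For the 1D case, a few common patterns and their NumPy equivalents, sketched with two arbitrary length-3 vectors:

```python
import numpy as np

A = np.array([1.0, 2.0, 3.0])
B = np.array([4.0, 5.0, 6.0])

inner = np.einsum("i,i->", A, B)         # sum_i A_i B_i, i.e. A @ B
elementwise = np.einsum("i,i->i", A, B)  # A * B
outer = np.einsum("i,j->ij", A, B)       # np.outer(A, B)
total = np.einsum("i->", A)              # A.sum()

print(inner)  # 32.0
```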
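Let A and B be two 2D arrays of compatible shapes. Each einsum call is listed with its NumPy equivalent:
- np.einsum('ij,ij->ij', A, B) is equivalent to A * B
- np.einsum('ij,ji->ij', A, B) is equivalent to A * B.T
- np.einsum('ij,jk', A, B) is equivalent to A @ B
- np.einsum('ij,kj->ik', A, B) is equivalent to A @ B.T
- np.einsum('ij,kj->ikj', A, B) is equivalent to A[:, None] * B
- np.einsum('ij,kl->ijkl', A, B) is equivalent to A[:, :, None, None] * B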
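These equivalences can be checked numerically. The sketch below assumes random placeholder data and picks, for each pattern, a B whose shape makes the shared labels line up with A's.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))

# One B per pattern, shaped so that shared labels have matching sizes.
cases = [
    ("ij,ij->ij", rng.standard_normal((2, 3)), lambda A, B: A * B),
    ("ij,ji->ij", rng.standard_normal((3, 2)), lambda A, B: A * B.T),
    ("ij,jk", rng.standard_normal((3, 4)), lambda A, B: A @ B),
    ("ij,kj->ik", rng.standard_normal((4, 3)), lambda A, B: A @ B.T),
    ("ij,kj->ikj", rng.standard_normal((4, 3)), lambda A, B: A[:, None] * B),
    ("ij,kl->ijkl", rng.standard_normal((4, 5)),
     lambda A, B: A[:, :, None, None] * B),
]

for pattern, B, equivalent in cases:
    result = np.einsum(pattern, A, B)
    # Each einsum result should match its NumPy equivalent exactly (up to float error).
    assert np.allclose(result, equivalent(A, B)), pattern
    print(f"{pattern:<12} -> shape {result.shape}")
```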
How to think about multi-rank tensor operations
Thinking about these multi-rank operations can be confusing and sometimes unintuitive. The aim of this section is to provide a way of thinking about:
- Einsums (i.e. tensor products)
- Spatially separable convolutions
- Depth-wise separable convolutions
- Any generalisation of these
We are always counting by stepping through indices. The question is how each operation changes this counting procedure, and so how to think about the effect of a single operation on the procedure as a whole.
At each counting step we have a single scalar selected. Stepping through a dimension (for an operation) means fixing all the other indices and changing only that one, which selects the corresponding scalar at each step. We then do the same for the indices of every other dimension.
We can also think of this purely in terms of indexing into the dimension we're stepping through. When we do this for a rank-n tensor, we cycle through a series of rank-(n-1) tensors.
If we then cycle through several dimensions at once, we can simply imagine getting back a series of correspondingly lower-rank tensors, one per combination of indices.
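To make the "series of tensors" picture concrete, here is a NumPy sketch using an arbitrary rank-3 tensor: stepping through one dimension yields rank-2 tensors, stepping through two yields rank-1 tensors, and fixing every index yields a scalar.

```python
import numpy as np

T = np.arange(24).reshape(2, 3, 4)  # rank-3 tensor

# Stepping through dim 1 alone: a series of rank-2 tensors, one per index.
series_one_dim = [T[:, j, :] for j in range(T.shape[1])]

# Stepping through dims 0 and 1 together: a series of rank-1 tensors,
# one per (i, j) combination.
series_two_dims = [T[i, j, :] for i in range(T.shape[0]) for j in range(T.shape[1])]

# Fixing every index selects a single scalar.
scalar = T[1, 2, 3]
print(scalar)  # 23
```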
How to think about a tensor
- Every scalar in a tensor shares one fundamental interpretation.
- The dims represent discrete (categorical or ordinal) variables.
- The dim variables must be orthogonal, i.e. every value of each dim must have a valid interpretation in combination with every value of all the other dims.
- To extract a scalar, you simply provide a value for each of these variables.
Example 1: Schools
Our go-to example will be information about schools. First, consider a tensor with the following variables:
- District
- Subject
- Yeargroup
The scalar's interpretation is the district's spend on a given subject for a given yeargroup.
Next, consider a second tensor with the fields:
- Student rank
- Subject
- Yeargroup
Here the scalar's interpretation is the student's end-of-year grade.
(To make this a bit more interpretable, in both cases we normalise the values by their L2 norm over the shared subject and yeargroup dims. This means that inner products over those dims give us cosine similarity.)
Various tensor operations using these two operands will now tell us about the relationship between school spend and student grades.
Example 2: Weight multiplication (XW view)
We take the first operand, X, to be an activation matrix, and the second, W, to be a weight matrix.
Outer product (our ‘default’)
Mechanism: cycle through the indices of each operand, multiplying every combination and writing each product to the output tensor.
Interpretation: this simply gives us the products of all combinations of scalars from the two tensors.
School example: n/a
XW view: we get a copy of the X tensor multiplied (as a whole) by each scalar in the W tensor independently.
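A NumPy sketch of the outer product and its XW view, assuming random placeholder X and W. The labels here (a, b for X and c, d for W) are chosen for this sketch only, so that nothing is shared.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))  # activations, dims (a, b)
W = rng.standard_normal((4, 5))  # weights, dims (c, d): nothing shared with X

# Outer product: every X scalar times every W scalar, nothing summed.
out = np.einsum("ab,cd->abcd", X, W)

# XW view: fixing the W indices (c, d) leaves a copy of all of X scaled by W[c, d].
c, d = 1, 2
print(np.allclose(out[:, :, c, d], X * W[c, d]))  # True
```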
Element-wise (dim-sharing)
Mechanism: as before, but we now count the shared dim-pair in lockstep.
Interpretation: as before, but we no longer consider all combinations of indices along the shared dims, only matching pairs (the same index in both operands).
School example: n/a
XW view: for each index in the c dimensions, we get the remaining X tensor (in this case a vector in the a dimension) multiplied by each scalar in the remaining W tensor.
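A NumPy sketch of the dim-sharing case, assuming random placeholder data with X over (a, c) and W over (c, d), so c is the shared dim. For each index along c, the output slice is the outer product of the remaining X vector with the remaining W vector.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))  # activations, dims (a, c)
W = rng.standard_normal((3, 4))  # weights, dims (c, d); c is shared

# Dim-sharing: c is counted in lockstep for both operands and kept in the output.
out = np.einsum("ac,cd->acd", X, W)

# XW view: for index k along c, the X vector X[:, k] (along a) is scaled by
# each scalar of W[k, :] (along d), i.e. one outer product per shared index.
k = 1
print(np.allclose(out[:, k, :], np.outer(X[:, k], W[k, :])))  # True
```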
Inner product
Mechanism: as before, but we now sum over the products along the shared dim-pair.
Interpretation: for every value of the non-shared dims (d) we compute the scaled similarity (i.e. inner product) between the vectors (or tensors flattened to vectors) remaining in the shared dims.
School example: for each combination of district and student rank, this computes the spend-grade similarity across the subject-yeargroup tensor. For a given district/student-rank pair, each operand contributes a subject-yeargroup tensor of values, and the inner product turns this pair of tensors into a single score reflecting how similar the spend is to the grade.
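A sketch of the school example in NumPy. The data is random placeholder data, not real school figures, and `normalise_shared` is a hypothetical helper implementing the L2 normalisation over the shared subject and yeargroup dims. Labels are chosen for this sketch: d = district, r = student rank, s = subject, y = yeargroup.

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholder data: random values standing in for real school figures.
spend = rng.random((4, 5, 6))    # (district, subject, yeargroup)
grades = rng.random((10, 5, 6))  # (student rank, subject, yeargroup)


def normalise_shared(t):
    """Divide each (subject, yeargroup) slice by its L2 norm."""
    norms = np.sqrt(np.einsum("xsy,xsy->x", t, t))
    return t / norms[:, None, None]


spend_n = normalise_shared(spend)
grades_n = normalise_shared(grades)

# Inner product over the shared (subject, yeargroup) dims:
# one cosine-similarity score per (district, student rank) pair.
scores = np.einsum("dsy,rsy->dr", spend_n, grades_n)
print(scores.shape)  # (4, 10)
```

Because both operands are normalised over exactly the dims that get summed, each entry of `scores` is the cosine similarity between a district's spend profile and a student rank's grade profile.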
XW view: for each index in the a dimension, take the remaining X tensor and do a similarity computation with a set of weights in the c dimensions. This gives one scalar for that index of a. We repeat this for each of the d ‘output channels’. In all, we can see this as ‘replacing’ the dims that X shares with W by the dims of W that X does not share.
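A NumPy sketch of the inner product, assuming random placeholder X over (a, c) and W over (c, d). It shows that the result is the dim-sharing result summed over c, and that each output channel d is a dot product of every X row with the weight column W[:, d].

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))  # activations, dims (a, c)
W = rng.standard_normal((3, 4))  # weights, dims (c, d); c is shared

# Inner product: c is counted in lockstep and summed away.
out = np.einsum("ac,cd->ad", X, W)

# Equivalent to summing the dim-sharing result over c...
via_sharing = np.einsum("ac,cd->acd", X, W).sum(axis=1)

# ...and, per output channel d, to a dot product of each X row with W[:, d].
d = 2
channel = X @ W[:, d]

print(np.allclose(out, X @ W))  # True
```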
Simplification: we can flatten the shared dims into one shared dim, and if an operand has multiple non-shared dims, those can be flattened too. This turns everything into a matmul!
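A sketch of this simplification with random placeholder tensors: X has two non-shared dims (p, q) and two shared dims (r, s), W has the shared (r, s) plus one non-shared dim d. After flattening, a single matmul reproduces the einsum.

```python
import numpy as np

rng = np.random.default_rng(0)
# X: non-shared dims (p, q), shared dims (r, s); W: shared (r, s), non-shared d.
X = rng.standard_normal((2, 3, 4, 5))
W = rng.standard_normal((4, 5, 6))

out = np.einsum("pqrs,rsd->pqd", X, W)

# Flatten (p, q) into one row dim and (r, s) into one shared dim, then matmul.
# C-order reshapes flatten (r, s) in the same order for both operands.
flat = X.reshape(2 * 3, 4 * 5) @ W.reshape(4 * 5, 6)

print(np.allclose(flat.reshape(2, 3, 6), out))  # True
```

The only requirement is that the shared dims are flattened in the same order in both operands; with default C-order reshapes and the shared dims contiguous, that holds automatically.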