High Performance GNNs in JAX

Speaker: Jonathan Godwin (DeepMind)

šŸ¦’ Jraph

A high-performance GNN library.
Built in JAX, powered by XLA.
Easily forkable.
(More) easily sharded for ā¬†ļø models (i.e. large graphs).

JAX Overview

Key features:
  • Normal numpy + autodiff, e.g. twice-differentiate a function just by calling grad() twice!
  • Built-in XLA (recall: XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra) - compilation is a one-liner
  • Easy SPMD parallelism - see the pmap function, which gives automatic cross-device parallelism
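A quick illustration of all three (plain JAX; the toy function is mine):

```python
import jax
import jax.numpy as jnp

f = lambda x: x ** 3

# Autodiff: the second derivative is just grad() called twice.
d2f = jax.grad(jax.grad(f))
print(d2f(2.0))  # 12.0, since f''(x) = 6x

# XLA compilation is a one-liner.
fast_f = jax.jit(f)

# SPMD: pmap runs the computation across all local devices in parallel.
xs = jnp.arange(jax.local_device_count(), dtype=jnp.float32)
print(jax.pmap(f)(xs))
```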

What do we optimise for?

Interested in a wide range of applications - the focus is on the most general Graph Nets framework (i.e. the Battaglia paper).
Note: this is a bit different for a GCN, where compute and storage scale just with the number of nodes - there are no explicit edge features. Although Jraph does have an efficient GCN-style implementation (see the sketch below).
Examples given of fluid/materials & mesh physics prediction papers. I recall that the former used the (old?) deepmind graph nets library. Assume the latter paper uses Jraph?
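A hedged sketch of that GCN-style layer via jraph.GraphConvolution (the toy graph and feature sizes are invented):

```python
import jax
import jax.numpy as jnp
import jraph
import haiku as hk

def gcn_forward(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  # GCN-style layer: per-node linear transform, then a (symmetrically
  # normalised) sum over neighbours - no edge features needed.
  gcn = jraph.GraphConvolution(update_node_fn=hk.Linear(16),
                               add_self_edges=True)
  return gcn(graph)

# Node features only; edges carry no features in a GCN.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)), edges=None, globals=None,
    senders=jnp.array([0, 1]), receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]), n_edge=jnp.array([2]))

net = hk.without_apply_rng(hk.transform(gcn_forward))
params = net.init(jax.random.PRNGKey(0), graph)
out = net.apply(params, graph)
```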

Jraph Design

šŸ§  What are the key design criteria?
Functional: we don't want explicit state - we want to outline a compute pattern āž”ļø very much in favour of this approach; abstracting state away from the flow is not only a computational gain, it's also much more intuitive conceptually
Forkable: copy-pastable code; limit the number of abstractions āž”ļø I really like this criterion, often overlooked
Padding & Masking: XLA has a static shape requirement for tensors, which filters down into JAX and Jraph - so we have to pad graphs if they're dynamically shaped and we need masking to make sure we don't accidentally pass messages "from padding"
Easy to use Jax parallelism for new models
Example code. Uses haiku (Sonnet for JAX) to convert the graph net into the kind of functional form required by JAX. make_graph() constructs the input data; forward() defines a GN with a single linear layer for all update functions.
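A reconstruction of that example (not the slide's exact code; the concatenated-args pattern and feature sizes follow Jraph's public examples):

```python
import jax
import jax.numpy as jnp
import jraph
import haiku as hk

def make_graph() -> jraph.GraphsTuple:
  # A toy graph: 3 nodes, 2 directed edges, 1 graph-level ("global") feature.
  return jraph.GraphsTuple(
      nodes=jnp.ones((3, 4)),
      edges=jnp.ones((2, 4)),
      globals=jnp.ones((1, 4)),
      senders=jnp.array([0, 1]),
      receivers=jnp.array([1, 2]),
      n_node=jnp.array([3]),
      n_edge=jnp.array([2]))

@jraph.concatenated_args
def update_fn(features: jnp.ndarray) -> jnp.ndarray:
  # A single linear layer for every update function.
  return hk.Linear(4)(features)

def forward(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  net = jraph.GraphNetwork(
      update_edge_fn=update_fn,
      update_node_fn=update_fn,
      update_global_fn=update_fn)
  return net(graph)

net = hk.without_apply_rng(hk.transform(forward))
graph = make_graph()
params = net.init(jax.random.PRNGKey(42), graph)
out = net.apply(params, graph)
```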

Functional GNNs

Key idea: treat GNNs as functions that define a computation pattern (if you want to use state ā†’ Haiku or Flax).
Update functions are how learned params are injected. They're inserted via a closure during network construction (sketch below).
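A bare-bones illustration of that closure pattern (hypothetical names, no Haiku involved):

```python
import jax.numpy as jnp

def make_update_fn(w: jnp.ndarray):
  # `w` is captured by the closure: the returned function stays pure,
  # but the learned parameters ride along with it.
  def update_fn(features: jnp.ndarray) -> jnp.ndarray:
    return features @ w
  return update_fn

# At network-construction time we close over the params...
update_edges = make_update_fn(jnp.ones((4, 4)))
# ...and the GNN itself only ever sees a stateless callable.
messages = update_edges(jnp.ones((2, 4)))
```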

Forkable

Simple enough that you could just copy-paste a model and then modify it to run your own version, without having to worry about various abstractions.

Static Shape Utilities

Provide tools to deal with necessary padding and masking (XLA req.).
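These live in Jraph's public API - a hedged sketch (the target sizes are arbitrary):

```python
import jraph

graph = make_graph()  # the toy GraphsTuple from the example above

# Pad to fixed totals so every batch has a static shape for XLA.
# n_node/n_graph must leave room for at least one padding node/graph.
padded = jraph.pad_with_graphs(graph, n_node=10, n_edge=12, n_graph=2)

# Masks ensure we never pass messages "from padding".
node_mask = jraph.get_node_padding_mask(padded)
edge_mask = jraph.get_edge_padding_mask(padded)
graph_mask = jraph.get_graph_padding_mask(padded)
```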

Support For Nests

Can easily make recurrent models work - e.g. LSTM edge-to-node aggregation (pictured), which lets the LSTM state and node embedding be handled together as a single nest (sketch below).
Compare this to the equivalent code in the TF graphnets library, which is much more verbose.
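My reconstruction of the idea (sizes invented; the point is that GraphsTuple fields can be arbitrary nests, so the LSTM state travels with the node embedding):

```python
import jax
import jax.numpy as jnp
import jraph
import haiku as hk

def forward(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:

  def update_edge_fn(edges, senders, receivers, globals_):
    # Messages computed from the sender's embedding leaf of the nest.
    return hk.Linear(8)(senders['embedding'])

  def update_node_fn(nodes, sent, received, globals_):
    # One LSTM step per node: aggregated messages in, state threaded through.
    out, new_state = hk.LSTM(8)(received, nodes['lstm_state'])
    return {'embedding': out, 'lstm_state': new_state}

  net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
                           update_node_fn=update_node_fn)
  return net(graph)

graph = jraph.GraphsTuple(
    nodes={'embedding': jnp.ones((3, 8)),
           'lstm_state': hk.LSTMState(hidden=jnp.zeros((3, 8)),
                                      cell=jnp.zeros((3, 8)))},
    edges=jnp.ones((2, 8)), globals=jnp.zeros((1, 8)),
    senders=jnp.array([0, 1]), receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]), n_edge=jnp.array([2]))

net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), graph)
out = net.apply(params, graph)
```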

Advanced Features Coming Soon: Massive GNNs

How might you implement a distributed GNN?
E.g. graph has many millions of edges/nodes and can't fit on a single device.
But we know that edge and node updates are highly data parallel!
  1. Split up edge tensors onto different devices
  2. Update edges on the separate devices
  3. For node updates, naively we might send individual edge rows between devices (expensive! avoid!)
  4. Instead, replicate all nodes on all devices and do a partial reduce of the edge updates available within each device
  5. Before doing a second cross-device sum reduce
    • This is very simple and very neat - it takes us from sending individual edge rows between devices to a single per-node cross-device reduce, and spreads the workload out nicely (see the sketch below).
Also, note that these node summations simply send the entire node tensor (no transpose, permutations, etc) improving utilisation of network bandwidth.
This is all based on the assumption that we can fit nodes in mem but not edges.
The pmap and psum commands in Jax will take care of edge and node updates respectively šŸ™‚
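A toy version of that pattern in raw JAX (my sketch, not Jraph's actual implementation; the edge update is a stand-in):

```python
import jax
import jax.numpy as jnp
import jraph

def sharded_message_passing(nodes, edges, senders, receivers):
  # Stand-in edge update: any per-edge function of the edge + its endpoints.
  messages = edges + nodes[senders] + nodes[receivers]
  # Partial reduce: sum the messages available on *this* device into node slots.
  partial = jraph.segment_sum(messages, receivers, num_segments=nodes.shape[0])
  # Second, cross-device sum-reduce completes the node aggregation.
  return jax.lax.psum(partial, axis_name='devices')

n_dev = jax.local_device_count()
nodes = jnp.ones((16, 32))                # all nodes fit on every device
edges = jnp.ones((n_dev, 128, 32))        # edge shards, one per device
senders = jnp.zeros((n_dev, 128), dtype=jnp.int32)
receivers = jnp.zeros((n_dev, 128), dtype=jnp.int32)

# pmap shards the edge work; in_axes=None replicates the node tensor.
aggregated = jax.pmap(
    sharded_message_passing, axis_name='devices',
    in_axes=(None, 0, 0, 0))(nodes, edges, senders, receivers)
```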

Results

(on TPUv1)
Very large systems can now be trained! On the path to GPT-3 for graph nets (maybe šŸ˜)

Q & A

What are the considerations for data (batch) vs model vs pipeline parallelism?
Depends on the problem at hand. In graph-land(!) we have lots of large graphs, e.g. using large embeddings as well as many nodes/edges, so data-parallelism will be required.
Models themselves tend not to be very deep or to have many params, so model parallelism is perhaps not useful.
Pipeline model parallelism may be more useful. Benefit: you can send the entire graph across each stage in your pipeline, which will improve bandwidth utilisation.
TODO: skim this and this for more info
Padding & masking may introduce computational overheads - how do we optimise that, compared to batching everything into a single graph (what PyTorch Geometric would do)?
Can't really get away from it if you want to run on TPUs or use XLA.
We would still take the approach of padding everything into a single batch - e.g. multiple graphs in a batch might have different sizes/shapes.
But we can use a dynamic number of graphs per batch (sketch below).
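Jraph has a utility along exactly these lines - a hedged sketch (budget sizes invented; my_graphs is a hypothetical list of GraphsTuples):

```python
import jraph

# Pack a stream of variously-sized graphs into fixed-budget padded batches:
# the number of graphs per batch is dynamic, the tensor shapes are static.
batch_iter = jraph.dynamically_batch(
    graphs_tuple_iterator=iter(my_graphs),
    n_node=128, n_edge=256, n_graph=16)

for padded_batch in batch_iter:
  ...  # static shapes throughout - safe for jit / TPUs
```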