25  Geometric Deep Learning

Every neural network you have encountered in this book operates on one of two specific types of data: sequences (text) or dense grids (images). But the world is full of data that does not fit neatly into sequences or grids: molecules, social networks, protein structures, road maps, 3D point clouds, and the mesh surfaces of physical objects. How do you build neural networks for data with arbitrary, irregular structure?

This is the domain of geometric deep learning (Bronstein et al. 2021), a unifying framework that extends the core operations of neural networks (convolution, attention, pooling) from regular grids to non-Euclidean domains: graphs, manifolds, and other geometric objects.

Important: The Geometric Deep Learning Blueprint

Michael Bronstein, Joan Bruna, Taco Cohen, and Petar Veličković's monograph “Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges” (2021) is the foundational reference. It unifies most of deep learning under a single mathematical framework based on symmetries and invariances. If you want to understand why convolutions work for images, why transformers work for sequences, and how to design architectures for entirely new data types, this is the text to read. A free version is available at geometricdeeplearning.com.

25.1 Why Geometry Matters

The central insight: effective architectures exploit the symmetries of the data. A CNN works for images because images have translation symmetry (a cat is a cat regardless of where it appears in the frame). A transformer works for sequences because self-attention treats its input as a set of tokens, so it is equivariant to reordering them, and sequence order is supplied explicitly through positional encodings. But what if your data has different symmetries, like the rotational symmetry of a molecule, or the permutation symmetry of a graph?

Real-world problems with geometric structure:

  • Molecular graphs: Atoms are nodes, chemical bonds are edges. Drug discovery requires predicting properties from molecular structure: will this molecule bind to this protein? Is it toxic? Is it synthesizable?
  • Social networks: Users are nodes, relationships are edges. Link prediction, community detection, and influence modeling all require reasoning about graph structure.
  • Protein structures: Amino acids connected in 3D space. AlphaFold (Jumper et al. 2021) revolutionized structural biology by predicting protein 3D structures with near-experimental accuracy.
  • 3D point clouds: LiDAR sensors in autonomous vehicles produce point clouds, which are unordered sets of 3D coordinates. Because the order of the points carries no meaning, processing them calls for permutation-invariant architectures.
  • Physical simulations: Particles, fluid elements, and mesh nodes interact through spatial relationships. Learning physics directly from data requires respecting the geometric structure of the simulation domain.

25.2 Symmetry, Equivariance, and Invariance

The mathematical foundation of geometric deep learning rests on group theory and the concept of equivariance.

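A function \(f\) is equivariant with respect to a group \(G\) if, for every element \(g \in G\): \[f(g \cdot x) = g \cdot f(x)\] Here \(g \cdot x\) denotes the action of \(g\) on the input, and \(g \cdot f(x)\) the corresponding action on the output. The output transforms consistently with the input. If you rotate the input, the output rotates correspondingly.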

A function is invariant if: \[f(g \cdot x) = f(x)\] The output does not change under the group action. A molecule's energy is the same regardless of how you orient it in space.
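Both definitions can be checked numerically. The sketch below is illustrative only: it assumes a toy setting in which the group is the set of cyclic shifts of a 1D signal, \(f\) is a circular convolution, and the invariant function is that convolution followed by a sum over positions. The helper names `circular_conv` and `shift` are made up for this example.

```python
import numpy as np

def circular_conv(x, kernel):
    """1D circular cross-correlation: out[i] = sum_k kernel[k] * x[(i + k) % n]."""
    n = len(x)
    return np.array([
        sum(kernel[k] * x[(i + k) % n] for k in range(len(kernel)))
        for i in range(n)
    ])

def shift(x, s):
    """Group action g: cyclically shift the signal by s positions."""
    return np.roll(x, s)

rng = np.random.default_rng(0)
x = rng.normal(size=8)
kernel = rng.normal(size=3)

# Equivariance: f(g . x) == g . f(x)
lhs = circular_conv(shift(x, 2), kernel)
rhs = shift(circular_conv(x, kernel), 2)
equivariant = bool(np.allclose(lhs, rhs))

# Invariance: summing the feature map over positions removes the dependence on the shift.
pooled_original = circular_conv(x, kernel).sum()
pooled_shifted = circular_conv(shift(x, 2), kernel).sum()
invariant = bool(np.isclose(pooled_original, pooled_shifted))

print(equivariant, invariant)
```

The pattern of an equivariant layer followed by an invariant pooling step is the same one used when a model must produce a single prediction for a whole input.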

Note: Equivariance in Practice

Why equivariance matters practically: it means the network does not waste capacity learning the same function for every orientation, position, or permutation. A rotationally equivariant network for molecular property prediction needs to learn “what makes a molecule active” only once, not separately for every possible 3D orientation. This can substantially reduce the data and compute needed for learning.

Examples:

  • CNNs: Equivariant to translation (shifting the input shifts feature maps correspondingly).
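  • GNNs: Equivariant to node permutation at the node level (reordering the nodes reorders the node outputs in the same way, because the graph structure is unchanged). Graph-level readouts such as sum or mean pooling are permutation invariant.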
  • SE(3)-equivariant models: Equivariant to 3D rotations and translations, essential for molecular and physics applications.
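For the SE(3) case, a small numerical check makes the distinction concrete. This is a sketch under assumptions chosen for illustration: a random five-atom "molecule", pairwise distances as the invariant quantity, and the centroid as the equivariant quantity. The function names are hypothetical and do not come from any particular library.

```python
import numpy as np

def random_rotation(rng):
    """Sample a proper 3D rotation matrix (determinant +1) via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))   # fix the sign ambiguity of the QR factors
    if np.linalg.det(q) < 0:      # flip one axis to turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q

def pairwise_distances(coords):
    """Matrix of Euclidean distances between all pairs of points."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)

def centroid(coords):
    return coords.mean(axis=0)

rng = np.random.default_rng(0)
coords = rng.normal(size=(5, 3))   # toy atom positions
R = random_rotation(rng)
t = rng.normal(size=3)
moved = coords @ R.T + t           # apply the rigid motion x -> R x + t to every atom

# Invariant: interatomic distances ignore the rigid motion.
invariant = bool(np.allclose(pairwise_distances(coords), pairwise_distances(moved)))

# Equivariant: the centroid moves exactly as the atoms do.
equivariant = bool(np.allclose(centroid(moved), R @ centroid(coords) + t))

print(invariant, equivariant)
```

Scalar targets such as energy call for invariant outputs, while vector targets such as forces or atom positions call for equivariant ones.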

25.3 The 5G Framework

Bronstein et al. organize geometric deep learning around five domains (the “5Gs”), each defined by its symmetry group:

  1. Grids (\(\mathbb{Z}^n\)): Regular lattices with translation symmetry. Standard CNNs for images and audio.
  2. Groups: Extend convolutions to richer symmetries: rotation-equivariant CNNs, scale-equivariant networks. Cohen and Welling's group-equivariant CNNs (Cohen and Welling 2016) showed how to build convolutions that respect a chosen symmetry group, demonstrated on discrete rotations and reflections of the image plane.
  3. Graphs: Arbitrary connectivity with permutation symmetry. Graph Neural Networks (GNNs).
  4. Geodesics: Intrinsic operations on manifold surfaces using geodesic distances. Mesh CNNs for 3D shape analysis.
  5. Gauges: The most general setting, covering data on fiber bundles with gauge symmetry. Relevant for physics simulations, climate modeling, and data on arbitrary curved surfaces.

25.4 Graph Neural Networks

GNNs are the most practically important class of geometric deep learning models, and the most accessible starting point.

25.4.1 The Message Passing Framework

Most GNNs follow the message-passing paradigm: each node updates its representation by aggregating information (“messages”) from its neighbors: \[h_v^{(l+1)} = \text{UPDATE}\!\left(h_v^{(l)},\; \text{AGG}\!\left(\left\{m_{u \to v}^{(l)} : u \in \mathcal{N}(v)\right\}\right)\right)\] where \(m_{u \to v}^{(l)} = \text{MSG}(h_u^{(l)}, h_v^{(l)}, e_{uv})\) is the message from neighbor \(u\) to node \(v\), and \(\mathcal{N}(v)\) is the set of neighbors of \(v\).
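The update rule above can be sketched in a few lines of NumPy. This is a minimal illustration under simplifying assumptions, not a reference implementation: MSG is a linear map of the sender's features only (it ignores \(h_v\) and the edge features \(e_{uv}\)), AGG is a sum, and UPDATE is a linear map of the node's own features plus the aggregate, followed by a ReLU. The names `message_passing_layer`, `W_msg`, and `W_self` are invented for this example.

```python
import numpy as np

def message_passing_layer(h, edges, W_msg, W_self):
    """One round of message passing with sum aggregation.

    h:     (num_nodes, d_in) node features h_v^{(l)}
    edges: list of directed (u, v) pairs, meaning u sends a message to v
    """
    agg = np.zeros((h.shape[0], W_msg.shape[1]))
    for u, v in edges:
        agg[v] += h[u] @ W_msg                 # MSG (linear in h_u) and sum AGG
    return np.maximum(0.0, h @ W_self + agg)   # UPDATE: combine self and neighbors, then ReLU

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))
W_msg = rng.normal(size=(3, 2))
W_self = rng.normal(size=(3, 2))

# Path graph 0 - 1 - 2 - 3; each undirected edge is stored in both directions.
undirected = [(0, 1), (1, 2), (2, 3)]
edges = undirected + [(v, u) for u, v in undirected]

out = message_passing_layer(h, edges, W_msg, W_self)

# Permutation equivariance: relabel the nodes and the output rows are relabeled identically.
perm = np.array([2, 0, 3, 1])   # new node i is old node perm[i]
inv = np.argsort(perm)          # old node j becomes new node inv[j]
edges_perm = [(int(inv[u]), int(inv[v])) for u, v in edges]
out_perm = message_passing_layer(h[perm], edges_perm, W_msg, W_self)
is_equivariant = bool(np.allclose(out_perm, out[perm]))

print(out.shape, is_equivariant)
```

Because the sum does not depend on the order in which neighbors are visited, the layer is permutation equivariant by construction; replacing the sum with a mean or a max would preserve this property.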

Tip: Analogy: Gossip in a Social Network

Each person (node) learns about the wider network by talking to their friends (neighbors). After one round, everyone knows about their immediate friends. After two rounds, everyone knows about friends of friends. After \(k\) rounds, information has spread \(k\) hops outward. The “receptive field” of a GNN grows with depth in exactly the same way a CNN's receptive field grows with layers.
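The hop-by-hop spread can be traced on a small example. The sketch below assumes a path graph of six nodes with a single piece of information starting at node 0, and uses one multiplication by the adjacency matrix (with self-loops) as a stand-in for one round of message passing.

```python
import numpy as np

n = 6
A = np.zeros((n, n), dtype=int)
for i in range(n - 1):              # path graph 0 - 1 - 2 - 3 - 4 - 5
    A[i, i + 1] = A[i + 1, i] = 1
A_hat = A + np.eye(n, dtype=int)    # self-loops so a node keeps what it already knows

signal = np.zeros(n, dtype=int)
signal[0] = 1                       # the "rumour" starts at node 0

reached = []
for k in range(1, 4):
    signal = A_hat @ signal         # one round: every node hears from its neighbors
    reached.append(np.flatnonzero(signal).tolist())

print(reached)   # [[0, 1], [0, 1, 2], [0, 1, 2, 3]]
```

After \(k\) rounds the nonzero entries are exactly the nodes within \(k\) hops of node 0, so a node five hops away is still untouched after three layers.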

25.4.2 Key GNN Architectures

  • GCN (Graph Convolutional Network) (Kipf and Welling 2017): The foundational architecture. Derived as a localized first-order approximation of spectral graph convolutions (a truncated Chebyshev expansion). Simple, efficient, and surprisingly effective. The update rule averages each node's features with those of its neighbors, using symmetric degree normalization, before applying a learned linear map.
  • GAT (Graph Attention Network) (Veličković et al. 2018): Adds attention coefficients to weight messages from different neighbors. Some neighbors are more important than others, and GAT learns which ones to attend to.
  • GraphSAGE (Hamilton et al. 2017): Samples a fixed-size neighborhood and aggregates via mean, LSTM, or max pooling. Crucially enables inductive learning: the model can process nodes that were not present during training.
  • GIN (Graph Isomorphism Network): Provably maximally expressive among message-passing GNNs, matching the discriminative power of the 1-dimensional Weisfeiler-Lehman (1-WL) graph isomorphism test.
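As a concrete instance of the first entry in this list, the GCN propagation rule from Kipf and Welling (2017), \(H' = \sigma(\tilde{D}^{-1/2}(A + I)\tilde{D}^{-1/2} H W)\), can be written directly in NumPy. This is a dense sketch for small graphs, with ReLU assumed as the nonlinearity; practical implementations use sparse operations. The two-node example is chosen so the result is easy to verify by hand.

```python
import numpy as np

def gcn_layer(A, H, W):
    """One GCN layer: ReLU(D^{-1/2} (A + I) D^{-1/2} H W), with D the degree matrix of A + I."""
    A_tilde = A + np.eye(A.shape[0])       # add self-loops
    deg = A_tilde.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    A_hat = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]   # symmetric normalization
    return np.maximum(0.0, A_hat @ H @ W)

# Two nodes joined by one edge, one feature each, identity weight.
A = np.array([[0.0, 1.0],
              [1.0, 0.0]])
H = np.array([[1.0],
              [3.0]])
W = np.array([[1.0]])

out = gcn_layer(A, H, W)
print(out)   # both nodes end up with the average of 1 and 3, i.e. 2.0
```

With the weight matrix set to the identity, the layer reduces to degree-normalized neighborhood averaging, which is the smoothing behavior the description refers to.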

25.4.3 Limitations of Message Passing

Standard message-passing GNNs have known limitations:

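Expressiveness: They cannot distinguish certain non-isomorphic graphs, because their discriminative power is bounded by the 1-dimensional Weisfeiler-Lehman (1-WL) test. For example, any two regular graphs with the same number of nodes and the same degree look identical to them when node features are uninformative. Higher-order GNNs and subgraph methods address this but at higher computational cost.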
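A classic failure case can be reproduced with a short implementation of 1-WL color refinement. The sketch below assumes all nodes start with the same color and keeps the full nested signature as the color (no hashing or compression), which is only practical for tiny graphs and a few rounds. It compares a six-node cycle with two disjoint triangles, and uses a six-node path as a contrast.

```python
from collections import Counter

def wl_colors(adj, rounds=3):
    """1-WL color refinement. adj maps each node to a list of its neighbors.

    Returns the histogram of node colors after `rounds` iterations. Colors are
    kept as nested tuples so that they are directly comparable across graphs.
    """
    colors = {v: 0 for v in adj}   # uniform initial color
    for _ in range(rounds):
        colors = {
            v: (colors[v], tuple(sorted(colors[u] for u in adj[v])))
            for v in adj
        }
    return Counter(colors.values())

def cycle(n, offset=0):
    """Adjacency lists of an n-cycle whose nodes are numbered from `offset`."""
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n] for i in range(n)}

hexagon = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}
path = {i: [j for j in (i - 1, i + 1) if 0 <= j < 6] for i in range(6)}

same = wl_colors(hexagon) == wl_colors(two_triangles)   # True: 1-WL cannot tell them apart
different = wl_colors(hexagon) != wl_colors(path)       # True: the path's end nodes have degree 1

print(same, different)
```

The hexagon is connected and the two triangles are not, yet every node in both graphs has exactly two neighbors with the same color at every round, so a standard message-passing GNN with identical input features produces the same graph-level output for both.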

Over-smoothing: As you stack more GNN layers, node representations converge toward a common value and eventually become indistinguishable. Deep GNNs (beyond 5 to 10 layers) often perform worse than shallow ones.
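The effect is easy to observe in a stripped-down setting. The sketch below assumes a ring of eight nodes and a "layer" that only averages each node with its two neighbors, with no learned weights or nonlinearities, so it isolates the smoothing behavior rather than modeling a trained GNN. It tracks how much the node features still differ from one another as layers are stacked.

```python
import numpy as np

n = 8
A = np.zeros((n, n))
for i in range(n):                                   # ring graph
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1.0
A_tilde = A + np.eye(n)                              # add self-loops
P = A_tilde / A_tilde.sum(axis=1, keepdims=True)     # each node averages itself and its neighbors

rng = np.random.default_rng(0)
H = rng.normal(size=(n, 4))                          # random initial node features

checkpoints = (1, 2, 4, 8, 16, 32)
spread = []
for layer in range(1, 33):
    H = P @ H
    if layer in checkpoints:
        # Standard deviation across nodes, averaged over feature dimensions.
        spread.append(float(H.std(axis=0).mean()))

for layer, s in zip(checkpoints, spread):
    print(f"after {layer:2d} layers: spread across nodes = {s:.6f}")
```

The spread shrinks at every checkpoint and is close to zero after 32 rounds: every node has drifted toward the graph-wide average, leaving little for a classifier to separate.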

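Over-squashing: Information from distant nodes must pass through many intermediate nodes, and the number of nodes in a receptive field can grow exponentially with the number of hops while each node representation has a fixed size. Compressing that much information into one vector creates a bottleneck. This is especially problematic in graphs with high diameter and in tasks that depend on long-range interactions.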
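To get a feel for the size of the bottleneck, the sketch below counts how many nodes fall within \(k\) hops of the root of a complete binary tree. The tree is an assumption chosen because its neighborhoods grow as fast as possible for a fixed degree; the helper names are hypothetical.

```python
from collections import deque

def complete_binary_tree(depth):
    """Adjacency lists for a complete binary tree with heap-style node numbering."""
    n = 2 ** (depth + 1) - 1
    adj = {v: [] for v in range(n)}
    for child in range(1, n):
        parent = (child - 1) // 2
        adj[parent].append(child)
        adj[child].append(parent)
    return adj

def nodes_within(adj, source, hops):
    """Count nodes reachable from `source` in at most `hops` steps (breadth-first search)."""
    seen = {source}
    frontier = deque([(source, 0)])
    while frontier:
        v, d = frontier.popleft()
        if d == hops:
            continue
        for u in adj[v]:
            if u not in seen:
                seen.add(u)
                frontier.append((u, d + 1))
    return len(seen)

tree = complete_binary_tree(depth=6)
sizes = [nodes_within(tree, source=0, hops=k) for k in range(1, 7)]
print(sizes)   # [3, 7, 15, 31, 63, 127]
```

After six rounds of message passing the root's single feature vector has to summarize 127 nodes, and the count roughly doubles with every extra layer.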

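Graph Transformers: One response to these limitations is to apply self-attention over all nodes, treating the graph as fully connected and reintroducing its structure through structural or positional encodings. Every node can then exchange information with every other node in a single layer, at a cost that grows quadratically with the number of nodes. Graphormer injects structural information such as shortest-path distances and node degrees into global attention, while GPS combines local message passing with global attention in each layer.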

25.5 Applications

AlphaFold (Jumper et al. 2021): Perhaps the most impactful application of geometric deep learning. AlphaFold 2 predicts protein 3D structures from amino acid sequences with near-experimental accuracy, a problem that had been open for about 50 years. Its Evoformer module uses attention to jointly refine a multiple sequence alignment representation and a representation of residue pairs. A separate structure module then applies invariant point attention, which is unaffected by global rotations and translations of the protein, to iteratively refine 3D coordinates.

Drug discovery: GNNs predict molecular properties (binding affinity, toxicity, solubility) from molecular graphs. Virtual screening with GNNs can evaluate millions of candidate molecules in hours, a process that would take years in a wet lab.

Recommender systems: User-item interactions form a bipartite graph. GNN-based recommenders (PinSage at Pinterest, LightGCN) propagate information through the interaction graph to learn embeddings for collaborative filtering.
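The propagation idea behind LightGCN can be sketched as follows: embeddings are repeatedly smoothed over the symmetrically normalized user-item graph with no weight matrices or nonlinearities, the per-layer embeddings are averaged, and scores are dot products. This is a dense toy version under stated assumptions (a tiny binary interaction matrix and random initial embeddings in place of trained ones), not the authors' implementation, and the function name is invented for this example.

```python
import numpy as np

def lightgcn_embeddings(R, E_user, E_item, num_layers=2):
    """Propagate embeddings over the bipartite graph defined by R (num_users x num_items)."""
    n_u, n_i = R.shape
    A = np.zeros((n_u + n_i, n_u + n_i))
    A[:n_u, n_u:] = R                      # user -> item edges
    A[n_u:, :n_u] = R.T                    # item -> user edges
    deg = A.sum(axis=1)
    d_inv_sqrt = np.zeros_like(deg)
    d_inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5   # isolated nodes keep a zero scale
    A_hat = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

    E = np.vstack([E_user, E_item])
    layers = [E]
    for _ in range(num_layers):
        E = A_hat @ E                      # neighborhood smoothing only: no weights, no nonlinearity
        layers.append(E)
    final = np.mean(layers, axis=0)        # average the layer-0 .. layer-K embeddings
    return final[:n_u], final[n_u:]

rng = np.random.default_rng(0)
R = np.array([[1.0, 0.0, 1.0],             # user 0 interacted with items 0 and 2
              [0.0, 1.0, 1.0]])            # user 1 interacted with items 1 and 2
E_user = rng.normal(size=(2, 4))
E_item = rng.normal(size=(3, 4))

users, items = lightgcn_embeddings(R, E_user, E_item, num_layers=2)
scores = users @ items.T                   # predicted affinity of every user for every item
print(scores.shape)   # (2, 3)
```

In a real system the initial embeddings are learned with a ranking loss, and the dense matrices are replaced by sparse ones, since interaction graphs typically have millions of nodes.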

Physics simulation: Graph Network-based Simulators (GNS, from DeepMind) represent particles or mesh nodes as graph nodes and learn physical dynamics through message passing between them. They can simulate complex fluid, rigid-body, and deformable-body systems after training on example simulations.

Tip: AlphaFold's Impact

AlphaFold 2 has been used to predict 3D structures for over 200 million proteins, covering nearly every catalogued protein sequence. This has been described as one of the most significant contributions of AI to science to date. It has enabled new research in drug design, enzyme engineering, understanding disease mechanisms, and even archaeology (analyzing ancient proteins). The AlphaFold Protein Structure Database is freely available to all researchers.

25.6 Exercises

  1. Install PyTorch Geometric and implement a GCN (Kipf and Welling 2017) for node classification on the Cora citation dataset. Report accuracy and compare with a simple MLP baseline that ignores graph structure. How much does graph structure help?
  2. Replace GCN layers with GAT layers (Veličković et al. 2018) and compare performance. Visualize the attention weights: which neighbors receive the highest attention?
  3. Train a GNN for molecular property prediction on the ESOL or Tox21 dataset (available in PyTorch Geometric). Compare GCN, GAT, and GIN architectures.
  4. Visualize learned node embeddings (via t-SNE) before and after GNN training. How does the graph structure influence the embedding space?
  5. Read the AlphaFold 2 paper (Jumper et al. 2021). Identify which components use equivariant operations and explain why equivariance matters for predicting 3D protein structure from a 1D amino acid sequence.

References

Bronstein, Michael M, Joan Bruna, Taco Cohen, and Petar Veličković. 2021. “Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges.” arXiv Preprint arXiv:2104.13478.
Cohen, Taco, and Max Welling. 2016. “Group Equivariant Convolutional Networks.” International Conference on Machine Learning.
Hamilton, William L., Rex Ying, and Jure Leskovec. 2017. “Inductive Representation Learning on Large Graphs.” Advances in Neural Information Processing Systems.
Jumper, John, Richard Evans, Alexander Pritzel, et al. 2021. “Highly Accurate Protein Structure Prediction with AlphaFold.” Nature 596: 583-89.
Kipf, Thomas N, and Max Welling. 2017. “Semi-Supervised Classification with Graph Convolutional Networks.” arXiv Preprint arXiv:1609.02907.
Veličković, Petar, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. 2018. “Graph Attention Networks.” arXiv Preprint arXiv:1710.10903.