Overview
When should I use E3x?
Whenever you are working with data that “lives” in three-dimensional space (for example molecules, point clouds, or polygon meshes), using an \(\mathrm{E}(3)\)-equivariant model is probably beneficial. E3x contains \(\mathrm{E}(3)\)-equivariant implementations of many typical neural network building blocks, saving you the time of implementing them yourself.
Mathematical background
What is \(\mathrm{E}(3)\)?
\(\mathrm{E}(3)\) is the name of a group. A group is a set equipped with an operation that combines any two elements of the set to produce a third element of the set, in such a way that the operation is associative, an identity element exists, and every element has an inverse. In three-dimensional space, the following groups are especially relevant:
- \(\mathrm{E}(3)\): Euclidean group in three dimensions, comprising all translations, rotations, and reflections of the three-dimensional Euclidean space \(\mathbb{E}^3\) and arbitrary finite combinations of them.
- \(\mathrm{SE}(3)\): Special Euclidean group in three dimensions (translations and rotations).
- \(\mathrm{O}(3)\): Orthogonal group in three dimensions (rotations and reflections).
- \(\mathrm{SO}(3)\): Special orthogonal group in three dimensions (rotations).
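For intuition, here is a minimal NumPy sketch (not part of E3x) contrasting elements of \(\mathrm{SO}(3)\) and \(\mathrm{O}(3)\): both are orthogonal \(3\times3\) matrices, but rotations have determinant \(+1\), whereas transformations involving a reflection have determinant \(-1\) (elements of \(\mathrm{E}(3)\) additionally include a translation).

```python
import numpy as np

# A rotation about the z-axis (an element of SO(3): orthogonal, det = +1).
theta = np.pi / 3
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])

# A reflection through the x-y plane (an element of O(3) but not SO(3): det = -1).
M = np.diag([1.0, 1.0, -1.0])

for g in (R, M):
    print(np.allclose(g.T @ g, np.eye(3)), round(float(np.linalg.det(g)), 6))
# True 1.0   (rotation)
# True -1.0  (reflection)
```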
What does equivariant mean?
Equivariance is a property of functions from \(n\) real numbers to \(m\) real numbers. To define what it means exactly, we first need to introduce the concept of a representation. A (linear) representation \(\rho\) of a group \(G\) is a function from \(G\) to square matrices such that for all \(g,h \in G\)

\[\rho(g)\,\rho(h) = \rho(gh)\,.\]
A function \(f: \mathcal{X} \rightarrow \mathcal{Y}\), where \(\mathcal{X}=\mathbb{R}^n\) and \(\mathcal{Y}=\mathbb{R}^m\) are vector spaces, is called equivariant with respect to a group \(G\) and representations \(\rho^\mathcal{X}\) and \(\rho^\mathcal{Y}\) (meaning representations on the vector spaces \(\mathcal{X}\) and \(\mathcal{Y}\)) if for all \(g \in G\) and \(\mathbf{x} \in \mathcal{X}\):

\[f\left(\rho^\mathcal{X}(g)\,\mathbf{x}\right) = \rho^\mathcal{Y}(g)\,f(\mathbf{x})\,.\]
When \(\rho^\mathcal{Y}(g)\) is the identity matrix for all \(g \in G\), \(f\) is called invariant with respect to \(G\).
To put it simply: if a function \(f\) is equivariant, it does not matter whether a transformation is applied to its input or to its output; the result is the same.
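As a minimal numerical illustration (plain NumPy, not E3x-specific), consider the function \(f(\mathbf{v}) = \lVert\mathbf{v}\rVert\,\mathbf{v}\), which is equivariant with respect to rotations, and the norm \(\lVert\mathbf{v}\rVert\), which is invariant:

```python
import numpy as np

# A rotation about the z-axis by 90 degrees (an element of SO(3)).
R = np.array([[0.0, -1.0, 0.0],
              [1.0,  0.0, 0.0],
              [0.0,  0.0, 1.0]])

x = np.array([1.0, 2.0, 3.0])

# f scales its input by its norm; rotating the input and then applying f
# gives the same result as applying f first and then rotating the output.
def f(v):
    return np.linalg.norm(v) * v

print(np.allclose(f(R @ x), R @ f(x)))  # True: f is equivariant

# The norm itself does not change at all under rotation.
print(np.allclose(np.linalg.norm(R @ x), np.linalg.norm(x)))  # True: invariant
```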
(Irreducible) representations of \(\mathrm{O}(3)\)
In a geometrical context, we are typically interested in representations of the group \(\mathrm{O}(3)\). This is because the numerical values of physical properties in three–dimensional space, e.g. (angular) velocities/momenta, electric/magnetic dipoles, etc. depend on the chosen coordinate system. By applying an orthogonal transformation to the underlying basis vectors, a different (but equally valid) coordinate system is obtained. So whenever we describe a physical property by \(m\) numbers, there is a rule that determines how these \(m\) numbers transform when going to a different coordinate system. Usually, this transformation is linear, and thus describes a representation of \(\mathrm{O}(3)\) on the vector space \(\mathbb{R}^m\).
Given two representations of \(\mathrm{O}(3)\) on \(\mathbb{R}^m\) and \(\mathbb{R}^n\), we also get a representation on the concatenated tuples of numbers on \(\mathbb{R}^{m+n}\) called a “direct sum representation”. An \(\mathrm{O}(3)\)–representation is called irreducible if it cannot be written as a direct sum of smaller representations (by introducing an appropriate coordinate system).
Irreducible representations (or irreps) of \(\mathrm{O}(3)\) on \(\mathbb{R}^m\) exist only for odd dimensions \(m=2\ell+1\), where \(\ell\) is called the degree of the representation. Furthermore, the element \(-Id\in \mathrm{O}(3)\) (the reflection through the origin; \(Id\) is the identity element) can only act as multiplication by a number \(p \in \{+1, -1\}\) in an irreducible representation on \(\mathbb{R}^{2\ell+1}\). The number \(p\) is called the parity of the representation. Representations with \(p=+1\) are also called even, and those with \(p=-1\) odd. One can show that for every \(\ell=0,1,2,\dots\) and \(p=\pm1\), there is exactly one irrep of degree \(\ell\) and parity \(p\) (up to change of coordinates in \(\mathbb{R}^{2\ell+1}\)). As a short-hand notation, we write particular irreps as double-struck digits with a subscript ‘\(+\)’ or ‘\(-\)’, e.g. \(\mathbb{0}_+\) stands for an irrep of degree \(\ell = 0\) and parity \(p = +1\).
For example, the trivial one–dimensional representation \(\mathrm{O}(3)\rightarrow\{+1\}\) is \(\mathbb{0}_+\), and the canonical three–dimensional representation given by the identity map \(\mathrm{O}(3)\rightarrow\mathrm{O}(3)\) is \(\mathbb{1}_-\). These irreps correspond to scalars and vectors, whereas the irreps \(\mathbb{0}_-\) and \(\mathbb{1}_+\) correspond to “pseudoscalars” and “pseudovectors”, respectively. An example of a pseudoscalar would be the (signed) volume (see triple product), an example of a pseudovector would be the angular velocity. In general, irreps with even/odd degree and even/odd parity correspond to (proper) tensors, whereas irreps with even/odd degree and odd/even parity are referred to as “pseudotensors”.
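The difference between proper tensors and pseudotensors can be checked numerically; the following plain-NumPy sketch applies the inversion \(-Id\) (i.e. \(\mathbf{r}\mapsto-\mathbf{r}\)) to a vector, a cross product (a pseudovector), and a triple product (a pseudoscalar):

```python
import numpy as np

u = np.array([1.0, 0.0, 0.0])
v = np.array([0.0, 1.0, 0.0])
w = np.array([0.0, 0.0, 1.0])

# Ordinary vectors (irrep 1_-) change sign under the inversion r -> -r.
print(-u)  # [-1.  0.  0.]

# The cross product of two vectors is a pseudovector (irrep 1_+): it is
# unchanged when both inputs are inverted, since (-u) x (-v) = u x v.
print(np.cross(u, v), np.cross(-u, -v))  # [0. 0. 1.] [0. 0. 1.]

# The triple product is a pseudoscalar (irrep 0_-): it flips sign under
# inversion, whereas an ordinary dot product (irrep 0_+) does not.
print(np.dot(np.cross(u, v), w), np.dot(np.cross(-u, -v), -w))  # 1.0 -1.0
print(np.dot(u, u), np.dot(-u, -u))                             # 1.0 1.0
```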
As mentioned above, the group \(\mathrm{O}(3)\) consists of rotations \(g\in \mathrm{SO}(3)\subset \mathrm{O}(3)\) and reflections, which can be written as \(-Id \cdot g\) for some \(g\in \mathrm{SO}(3)\). A representation of \(\mathrm{O}(3)\) also gives a representation of the rotations \(\mathrm{SO}(3)\), simply by restricting it to the subset \(\mathrm{SO}(3)\subset \mathrm{O}(3)\). The irreps of \(\mathrm{SO}(3)\) can all be obtained from irreps of \(\mathrm{O}(3)\) in this way and are characterized only by a degree \(\ell\); the parity \(p\) specifies how an irrep of \(\mathrm{SO}(3)\) extends to an irrep of \(\mathrm{O}(3)\).
For computations in E3x, we fix one particular basis in a particular \((2\ell+1)\)–dimensional vector space, which can be described by (real-valued) spherical harmonics of degree \(\ell\).
How does E3x work?
Irrep features
Ordinary neural networks typically operate on features \(\mathbf{x} \in \mathbb{R}^F\), where \(F\) is the dimension of the feature space. Usually, we think of \(\mathbf{x}\) as an \(F\)-dimensional feature vector, but to understand how E3x works, it is helpful to instead think of \(\mathbf{x}\) as a collection of \(F\) scalar features. In E3x, a single feature is not a single scalar anymore, but instead consists of irreps of \(\mathrm{O}(3)\) with even and odd parities for all degrees \(\ell = 0,\dots,L\), where \(L\) is the maximum degree. Features \(\mathbf{x} \in \mathbb{R}^{2\times (L+1)^2\times F}\) are stored in an ndarray whose first axis indexes the parity (even/odd), whose second axis stacks the \(2\ell+1\) components of every degree \(\ell = 0,\dots,L\), and whose last axis indexes the \(F\) feature channels (for example, \(L=2\) and \(F=8\) gives shape \(2\times 9\times 8\)).
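For concreteness, here is a small JAX sketch of this storage layout. It only uses array shapes that follow from the description above; the convention that the components of degree \(\ell\) occupy indices \(\ell^2\) to \((\ell+1)^2-1\) along the irrep axis is stated here as an assumption.

```python
import jax.numpy as jnp

max_degree = 2    # L
num_features = 8  # F

# Irrep features: shape (..., 2, (L+1)**2, F). Axis -3 indexes parity
# (even/odd), axis -2 stacks the 2*l+1 components of every degree l = 0..L,
# and axis -1 indexes the F feature channels.
x = jnp.zeros((2, (max_degree + 1) ** 2, num_features))
print(x.shape)  # (2, 9, 8)

# Assumed ordering: the components of degree l occupy indices l**2:(l+1)**2
# along the irrep axis (1 component for l=0, 3 for l=1, 5 for l=2).
for l in range(max_degree + 1):
    print(l, x[..., l**2:(l + 1) ** 2, :].shape)
# 0 (2, 1, 8)
# 1 (2, 3, 8)
# 2 (2, 5, 8)
```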
Here, \(\mathbf{x}_i\) denotes the \(i\)-th feature (a single channel along the last axis), and we use the short-hand notation \(\mathbf{x}^{(\ell_\pm)}\) to refer to all irreps of degree \(\ell\) with parity \(p = \pm 1\), and \(\mathbf{x}^{(\pm)}\) to refer to all irreps of all degrees with parity \(p = \pm 1\).
In code, the different feature subsets can be conveniently accessed with array slicing (the slices below assume the layout described above, with even parity stored at index 0 of the parity axis and the components of degree \(\ell\) occupying indices \(\ell^2\) to \((\ell+1)^2-1\) along the irrep axis):
| notation | code |
|---|---|
| \(\mathbf{x}_i \in \mathbb{R}^{2\times(L+1)^2\times 1}\) | `x[..., i:i+1]` |
| \(\mathbf{x}^{(+)} \in \mathbb{R}^{1\times(L+1)^2\times F}\) | `x[..., 0:1, :, :]` |
| \(\mathbf{x}^{(-)} \in \mathbb{R}^{1\times(L+1)^2\times F}\) | `x[..., 1:2, :, :]` |
| \(\mathbf{x}_i^{(+)} \in \mathbb{R}^{1\times(L+1)^2\times 1}\) | `x[..., 0:1, :, i:i+1]` |
| \(\mathbf{x}_i^{(-)} \in \mathbb{R}^{1\times(L+1)^2\times 1}\) | `x[..., 1:2, :, i:i+1]` |
| \(\mathbf{x}^{(\ell_+)} \in \mathbb{R}^{1\times(2\ell+1)\times F}\) | `x[..., 0:1, l**2:(l+1)**2, :]` |
| \(\mathbf{x}^{(\ell_-)} \in \mathbb{R}^{1\times(2\ell+1)\times F}\) | `x[..., 1:2, l**2:(l+1)**2, :]` |
| \(\mathbf{x}_i^{(\ell_+)} \in \mathbb{R}^{1\times(2\ell+1)\times 1}\) | `x[..., 0:1, l**2:(l+1)**2, i:i+1]` |
| \(\mathbf{x}_i^{(\ell_-)} \in \mathbb{R}^{1\times(2\ell+1)\times 1}\) | `x[..., 1:2, l**2:(l+1)**2, i:i+1]` |
Sometimes, it is useful to work with features that omit all pseudotensors and only contain proper tensor components. For better memory efficiency, such features \(\mathbf{x} \in \mathbb{R}^{1\times (L+1)^2\times F}\) are stored in an ndarray with a parity axis of size one (in this example, \(L=2\) and \(F=8\) would give shape \(1\times 9\times 8\)). They behave equivalently to ordinary features \(\mathbf{x} \in \mathbb{R}^{2\times (L+1)^2\times F}\) in which all pseudotensor components are zero, but without the need for explicit zero-padding. All operations in E3x automatically detect from the shape which kind of features is used, and computations are adapted accordingly.
Another way to think about features in E3x is to imagine them as three-dimensional shapes with positive and negative lobes (similar to the familiar plots of spherical harmonics), oriented relative to the x-, y-, and z-axes.
When applying a rotation, the shapes of the features stay the same, but they might be oriented differently after the transformation. For example, under a rotation about the z-axis by \(\frac{\pi}{2}\), all components of degree \(\ell > 0\) may change, whereas the degree-zero components (scalars and pseudoscalars) stay the same.
When applying a reflection, all \(\mathbf{x}^{(-)}\) change into their opposite, whereas all \(\mathbf{x}^{(+)}\) stay the same.
This makes it possible to predict output quantities that automatically transform in the correct way when a transformation is applied to the inputs of a model (see below).
Modifications to neural network components
In contrast to ordinary neural networks, when working with equivariant features, not every operation is “allowed” – some operations “destroy the equivariance”. For example, nonlinear activation functions may only be applied to scalar features \(\mathbf{x}^{(0_+)}\) (technically, odd activation functions could also be applied to pseudoscalar features \(\mathbf{x}^{(0_-)}\), but E3x handles all activation functions on an equal footing for simplicity). Fortunately, it is possible to express many common activation functions in a way that preserves equivariance and is equivalent to their ordinary formulation when the features contain only scalars (see here for more details). Make sure to use one of the already implemented activation functions, or write your own activation function using the principle described here when building your architecture.
Care must even be taken when applying linear operations. Namely, all “feature channels” of the same degree must be treated equally: even though e.g. features \(\mathbf{x}^{(1_-)}\) of degree \(\ell=1\) consist of \(2\ell+1=3\) individual numbers (see above), they should be regarded as a single entity. For example, when multiplying \(\mathbf{x}^{(1_-)}\) with a scalar, all three components must be multiplied by the same factor – just as in ordinary scalar multiplication of vectors. While scalar multiplication of irrep features is fine, there is no valid “scalar addition” (just as adding a scalar to all three components of a vector is not really meaningful). Concretely, this means that bias terms may only be applied to scalar features \(\mathbf{x}^{(0_+)}\). E3x contains an implementation of `Dense` layers that handles all these details for you. Thus, building \(\mathrm{E}(3)\)-equivariant feedforward neural networks with E3x is as easy as replacing all ordinary dense layers and activations with their E3x counterparts.
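As a hedged sketch of what this looks like in practice (the module and function names `e3x.nn.Dense` and `e3x.nn.silu` and their exact signatures are assumptions based on the description above and should be checked against the E3x API documentation), an equivariant feedforward block built with Flax might look as follows:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import e3x


class FeedForward(nn.Module):
  """A small E(3)-equivariant feedforward block (illustrative sketch)."""

  @nn.compact
  def __call__(self, x):
    # x has shape (..., 1 or 2, (L+1)**2, F).
    x = e3x.nn.Dense(features=32)(x)  # equivariant replacement for a dense layer
    x = e3x.nn.silu(x)                # equivariance-preserving activation
    x = e3x.nn.Dense(features=32)(x)
    return x


# Features with L = 2 (irrep axis of size 9), both parities, and F = 8 channels.
x = jnp.zeros((2, 9, 8))
model = FeedForward()
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
print(y.shape)  # expected: (2, 9, 32)
```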
In addition, E3x contains equivariant versions of many other common building blocks, such as `MessagePass` or `SelfAttention` layers (see here for a complete overview of all implemented modules). Many of these more advanced layers “couple” irrep features in a way that has no analogue in ordinary networks; see below.
Coupling irreps
It is possible to “couple” two irreps by computing their tensor product to produce new irreps. This uses the fact that the tensor product of two irreps can be written as a direct sum of (new) irreps. For example

\[\mathbb{1}_-\otimes\mathbb{1}_- = \mathbb{0}_+\oplus\mathbb{1}_+\oplus\mathbb{2}_+\,,\]
or in general:

\[\mathbb{a}_{p_a}\otimes\mathbb{b}_{p_b} = \bigoplus_{\ell=\lvert a-b\rvert}^{a+b} \ell_{\,p_a p_b}\,,\]

i.e. the tensor product of irreps with degrees \(a\) and \(b\) and parities \(p_a\) and \(p_b\) contains one new irrep of every degree from \(\lvert a-b\rvert\) to \(a+b\), each with parity \(p_a \cdot p_b\).
To simplify notation, we introduce the short-hand ‘\(\mathbb{a}\otimes^{(\ell_p)}\mathbb{b}\)’ to refer to the irrep of degree \(\ell\) and parity \(p\) in the direct sum representation of the tensor product \(\mathbb{a}\otimes\mathbb{b}\). For example, \(\mathbb{1}_-\otimes^{(2_+)}\mathbb{1}_-\) refers to the \(\mathbb{2}_+\) irrep in the direct sum representation \(\mathbb{0}_+\oplus\mathbb{1}_+\oplus\mathbb{2}_+\) of \(\mathbb{1}_-\otimes\mathbb{1}_-\) from above.
Coupling two irreps to produce new irreps may seem to be a foreign concept at first glance. However, it is analogous to familiar operations on vectors in \(\mathbb{R}^3\) (which correspond to irreps \(\mathbb{1}_-\), see above). Consider two vectors \(\mathbf{u},\mathbf{v}\in\mathbb{R}^3\). Their tensor product is

\[\mathbf{u}\otimes\mathbf{v} = \mathbf{u}\mathbf{v}^\top=\begin{bmatrix} u_x v_x & u_x v_y & u_x v_z \\ u_y v_x & u_y v_y & u_y v_z \\ u_z v_x & u_z v_y & u_z v_z\end{bmatrix}\]
and the irreps \(\mathbb{0}_+\) and \(\mathbb{1}_+\) in the direct sum representation of their tensor product correspond to their scaled dot product

\[\mathbf{u}\cdot\mathbf{v} = u_x v_x + u_y v_y + u_z v_z\]

and their scaled cross product

\[\mathbf{u}\times\mathbf{v} = \begin{bmatrix} u_y v_z - u_z v_y \\ u_z v_x - u_x v_z \\ u_x v_y - u_y v_x \end{bmatrix}\]

(each up to a normalization constant).
The irrep \(\mathbb{2}_+\) cannot be easily written in terms of familiar operations like dot or cross products, but can be related to the entries of the traceless symmetric matrix

\[\mathbf{S} = \frac{1}{2}\left(\mathbf{u}\mathbf{v}^\top + \mathbf{v}\mathbf{u}^\top\right) - \frac{1}{3}\mathrm{Tr}\left(\mathbf{u}\mathbf{v}^\top\right)I_3\,,\]

where \(I_3\) is the \(3\times3\) identity matrix and \(\mathrm{Tr}\) is the trace. Then, the irrep \(\mathbb{2}_+\) is given (up to normalization of the individual components) by

\[\begin{bmatrix} S_{xy} & S_{yz} & S_{zz} & S_{xz} & S_{xx}-S_{yy}\end{bmatrix}^\top\]

or, re-written directly in terms of the components of \(\mathbf{u}\) and \(\mathbf{v}\),

\[\begin{bmatrix} u_x v_y + u_y v_x \\ u_y v_z + u_z v_y \\ 2u_z v_z - u_x v_x - u_y v_y \\ u_x v_z + u_z v_x \\ u_x v_x - u_y v_y \end{bmatrix}\]

(again up to normalization of the individual components).
The dot product produces a scalar, corresponding to a one-dimensional irrep \(\mathbb{0}_+\), and the cross product produces a pseudovector, corresponding to a three-dimensional irrep \(\mathbb{1}_+\). It might be less apparent why a traceless symmetric matrix \(\mathbf{S}\) is related to a five-dimensional representation, but any traceless symmetric matrix can be written as

\[\mathbf{S} = \begin{bmatrix} a & b & c \\ b & d & e \\ c & e & -a-d \end{bmatrix}\,,\]

and only five numbers \(a\), \(b\), \(c\), \(d\), and \(e\) are necessary to fully specify it.
Importantly, all irreps in the direct sum representation of \(\mathbf{u} \otimes \mathbf{v}\) can be written as linear combinations of entries in the tensor product matrix with appropriate coefficients (and vice versa). This is also true for other tensor products of irreps, and the corresponding coefficients are called Clebsch-Gordan coefficients.
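The decomposition of \(\mathbf{u}\otimes\mathbf{v}\) described above can be written out explicitly in plain NumPy. This sketch ignores the normalization constants and does not use E3x's `clebsch_gordan` array; it only illustrates that the nine entries of the tensor product split into a scalar, a pseudovector, and a five-component traceless symmetric part:

```python
import numpy as np

u = np.array([1.0, 2.0, 3.0])
v = np.array([-1.0, 0.5, 2.0])

# Tensor product of two vectors (degree-1 irreps): a 3x3 matrix with 9 entries.
uv = np.outer(u, v)

# Degree-0 part (irrep 0_+): the dot product (trace of the tensor product).
scalar = np.trace(uv)

# Degree-1 part (irrep 1_+): the cross product, i.e. the antisymmetric part.
pseudovector = np.array([uv[1, 2] - uv[2, 1],
                         uv[2, 0] - uv[0, 2],
                         uv[0, 1] - uv[1, 0]])

# Degree-2 part (irrep 2_+): the symmetric traceless part, 5 independent numbers.
S = 0.5 * (uv + uv.T) - (np.trace(uv) / 3.0) * np.eye(3)

print(scalar, pseudovector)
print(np.allclose(np.trace(S), 0.0), np.allclose(S, S.T))  # True True

# 1 + 3 + 5 = 9 numbers: together they carry the same information as uv.
```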
In E3x, all possible irreps in the direct sum representation of a tensor product can be computed in parallel by multiplication with the `clebsch_gordan` array (containing the required coefficients) and subsequent summation. The `Tensor` and `FusedTensor` layers are the basic operations in E3x to couple two feature representations; they have no analogue in ordinary neural networks. The `FusedTensor` layers have a lower computational complexity compared to `Tensor` layers (\(O(L^4)\) vs. \(O(L^6)\)), with the downside of being less expressive (they contain fewer learnable parameters). When working with small maximum degrees \(L \leq 4\), the ordinary `Tensor` layers still tend to be faster on accelerators (GPUs/TPUs) due to their simpler implementation. When working with large maximum degrees \(L > 4\) (or on CPUs), `FusedTensor` layers can be significantly faster.
Predicting equivariant quantities
The prediction of equivariant quantities with E3x is as straightforward as the prediction of scalar quantities with ordinary neural networks. For example, a single vector can be predicted by slicing the output features `y` of the last layer appropriately, i.e. `y[..., 1, 1:4, 0]`, to produce an array of shape `(..., 3)` that transforms as a vector under rotations/reflections of the input to your neural network (see the section on slicing irrep features for reference).
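A minimal shape check of this slicing (plain JAX, with an arbitrary placeholder output array):

```python
import jax.numpy as jnp

# Placeholder output features of the last layer: batch of 10, both parities,
# max_degree L = 2 (irrep axis of size 9), and 4 feature channels.
y = jnp.zeros((10, 2, 9, 4))

# Odd-parity, degree-1 components of channel 0: an ordinary vector (irrep 1_-)
# for every item in the batch.
vector = y[..., 1, 1:4, 0]
print(vector.shape)  # (10, 3)
```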
Further, E3x contains convenience functions for converting from irreps to traceless symmetric tensors and vice versa (see `irreps_to_tensor` and `tensor_to_irreps`), which is a more typical format for many quantities of interest in physics.