Overview

When should I use E3x?

Whenever you are working with data that “lives” in three-dimensional space (for example molecules, point clouds, or polygon meshes), using an \(\mathrm{E}(3)\)-equivariant model is probably beneficial. E3x contains \(\mathrm{E}(3)\)-equivariant implementations of many typical neural network building blocks, saving you the time of implementing them yourself.

Mathematical background

What is \(\mathrm{E}(3)\)?

\(\mathrm{E}(3)\) is the name of a group. A group is a set equipped with an operation that combines any two elements of the set to produce a third element of the set, in such a way that the operation is associative, an identity element exists, and every element has an inverse. In three-dimensional space, the following groups are especially relevant:

\(\mathrm{E}(3)\)

Euclidean group in three dimensions, comprising all translations, rotations, and reflections of the three-dimensional Euclidean space \(\mathbb{E}^3\) and arbitrary finite combinations of them.

\(\mathrm{SE}(3)\)

Special Euclidean group in three dimensions (translations and rotations).

\(\mathrm{O}(3)\)

Orthogonal group in three dimensions (rotations and reflections).

\(\mathrm{SO}(3)\)

Special orthogonal group in three dimensions (rotations).

What does equivariant mean?

Equivariance is a property of functions from \(n\) real numbers to \(m\) real numbers. To define what it means exactly, we first need to introduce the concept of a representation. A (linear) representation \(\rho\) of a group \(G\) is a function from \(G\) to square matrices such that for all \(g,h \in G\)

\[\rho(g)\rho(h) = \rho(gh)\,.\]

A function \(f: \mathcal{X} \rightarrow \mathcal{Y}\), where \(\mathcal{X}=\mathbb{R}^n\) and \(\mathcal{Y}=\mathbb{R}^m\) are vector spaces, is called equivariant with respect to a group \(G\) and representations \(\rho^\mathcal{X}\) and \(\rho^\mathcal{Y}\) (meaning representations on the vector spaces \(\mathcal{X}\) and \(\mathcal{Y}\)) if for all \(g \in G\) and \(\mathbf{x} \in \mathcal{X}\):

\[f\left(\rho^\mathcal{X}(g)\mathbf{x}\right) = \rho^\mathcal{Y}(g)f(\mathbf{x})\,.\]

When \(\rho^\mathcal{Y}(g)\) is the identity matrix for all \(g \in G\), \(f\) is called invariant with respect to \(G\).

To put it simply: if a function \(f\) is equivariant, it does not matter whether a transformation is applied to its input or to its output; the result is the same:

[Figure: equivariance]
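
As a concrete numerical illustration (plain NumPy, independent of E3x), consider the map \(f(\mathbf{x}) = \mathbf{x}\,\lVert\mathbf{x}\rVert^2\), which is equivariant with respect to rotations (and reflections) when \(\rho^\mathcal{X} = \rho^\mathcal{Y}\) is the usual \(3\times3\) orthogonal matrix:

import numpy as np

# f(x) = x * ||x||^2 is equivariant under orthogonal transformations, because
# they preserve the norm: f(Rx) = Rx * ||Rx||^2 = R (x * ||x||^2) = R f(x).
def f(x):
  return x * np.dot(x, x)

theta = 0.3
rho = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta),  np.cos(theta), 0.0],
                [0.0,            0.0,           1.0]])  # rotation about the z-axis

x = np.array([1.0, 2.0, -0.5])
print(np.allclose(f(rho @ x), rho @ f(x)))  # True: transform input or output, same result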

(Irreducible) representations of \(\mathrm{O}(3)\)

In a geometrical context, we are typically interested in representations of the group \(\mathrm{O}(3)\). This is because the numerical values of physical properties in three–dimensional space, e.g. (angular) velocities/momenta, electric/magnetic dipoles, etc. depend on the chosen coordinate system. By applying an orthogonal transformation to the underlying basis vectors, a different (but equally valid) coordinate system is obtained. So whenever we describe a physical property by \(m\) numbers, there is a rule that determines how these \(m\) numbers transform when going to a different coordinate system. Usually, this transformation is linear, and thus describes a representation of \(\mathrm{O}(3)\) on the vector space \(\mathbb{R}^m\).

Given two representations of \(\mathrm{O}(3)\) on \(\mathbb{R}^m\) and \(\mathbb{R}^n\), we also get a representation on the concatenated tuples of numbers on \(\mathbb{R}^{m+n}\) called a “direct sum representation”. An \(\mathrm{O}(3)\)–representation is called irreducible if it cannot be written as a direct sum of smaller representations (by introducing an appropriate coordinate system).

Irreducible representations (or irreps) of \(\mathrm{O}(3)\) on \(\mathbb{R}^m\) exist only for odd dimensions \(m=2\ell+1\), where \(\ell\) is called the degree of the representation. Furthermore, the element \(-Id\in \mathrm{O}(3)\) (a reflection, \(Id\) is the identity element) can only act as multiplication by a number \(p \in \{+1, -1\}\) in an irreducible representation on \(\mathbb{R}^{2\ell+1}\). The number \(p\) is called the parity of the representation. Representations with \(p=+1\) are also called even, and those with \(p=-1\) odd. One can show that for every \(\ell=0,1,2,\dots\) and \(p=\pm1\), there is exactly one irrep of degree \(\ell\) and parity \(p\) (up to a change of coordinates in \(\mathbb{R}^{2\ell+1}\)). As a short-hand notation, we write particular irreps as double-struck digits with a subscript ‘\(+\)’ or ‘\(-\)’, e.g. \(\mathbb{0}_+\) stands for an irrep of degree \(\ell = 0\) and parity \(p = +1\).

For example, the trivial one–dimensional representation \(\mathrm{O}(3)\rightarrow\{+1\}\) is \(\mathbb{0}_+\), and the canonical three–dimensional representation given by the identity map \(\mathrm{O}(3)\rightarrow\mathrm{O}(3)\) is \(\mathbb{1}_-\). These irreps correspond to scalars and vectors, whereas the irreps \(\mathbb{0}_-\) and \(\mathbb{1}_+\) correspond to “pseudoscalars” and “pseudovectors”, respectively. An example of a pseudoscalar is the (signed) volume spanned by three vectors (see triple product); an example of a pseudovector is the angular velocity. In general, irreps whose parity matches their degree (even degree with even parity, or odd degree with odd parity) correspond to (proper) tensors, whereas irreps with mismatched degree and parity (even degree with odd parity, or odd degree with even parity) are referred to as “pseudotensors”.

As mentioned above, the group \(\mathrm{O}(3)\) consists of rotations \(g\in \mathrm{SO}(3)\subset \mathrm{O}(3)\) and reflections, which can be written as \(-Id \cdot g\) for some \(g\in \mathrm{SO}(3)\). A representation of \(\mathrm{O}(3)\) also gives a representation of the rotations \(\mathrm{SO}(3)\) simply by restricting it to the subset \(\mathrm{SO}(3)\subset \mathrm{O}(3)\). The irreps of \(\mathrm{SO}(3)\) can all be obtained from irreps of \(\mathrm{O}(3)\) in this way and are characterized only by a degree \(\ell\); the parity \(p\) specifies how an irrep of \(\mathrm{SO}(3)\) extends to an irrep of \(\mathrm{O}(3)\).

For computations in E3x, we fix one particular basis in a particular \((2\ell+1)\)–dimensional vector space, which can be described by (real-valued) spherical harmonics of degree \(\ell\).

How does E3x work?

Irrep features

Ordinary neural networks typically operate on features \(\mathbf{x} \in \mathbb{R}^F\), where \(F\) is the dimension of the feature space. Typically, we think of \(\mathbf{x}\) as an \(F\)–dimensional feature vector, but to understand how E3x works, it is helpful to instead think of \(\mathbf{x}\) as a collection of \(F\) scalar features. In E3x, a single feature is not a single scalar anymore, but instead consists of irreps of \(\mathrm{O}(3)\) with even and odd parities for all degrees \(\ell = 0,\dots,L\), where \(L\) is the maximum degree. Features \(\mathbf{x} \in \mathbb{R}^{2\times (L+1)^2\times F}\) are stored in an ndarray like this (for this example, \(L=2\) and \(F=8\)):

[Figure: feature shape diagram]

Here, the column \(\mathbf{x}_i\) is the \(i\)-th feature and we use the short-hand notation \(\mathbf{x}^{(\ell_\pm)}\) to refer to all irreps of degree \(\ell\) with parity \(p = \pm 1\), and \(\mathbf{x}^{(\pm)}\) to refer to all irreps of all degrees with parity \(p = \pm 1\).

In code, the different feature subsets can be conveniently accessed with array slicing:

\(\mathbf{x}_i \in \mathbb{R}^{2\times(L+1)^2\times 1}\): x[:, :, i:i+1]
\(\mathbf{x}^{(+)} \in \mathbb{R}^{1\times(L+1)^2\times F}\): x[0:1, :, :]
\(\mathbf{x}^{(-)} \in \mathbb{R}^{1\times(L+1)^2\times F}\): x[1:2, :, :]
\(\mathbf{x}_i^{(+)} \in \mathbb{R}^{1\times(L+1)^2\times 1}\): x[0:1, :, i:i+1]
\(\mathbf{x}_i^{(-)} \in \mathbb{R}^{1\times(L+1)^2\times 1}\): x[1:2, :, i:i+1]
\(\mathbf{x}^{(\ell_+)} \in \mathbb{R}^{1\times(2\ell+1)\times F}\): x[0:1, l**2:(l+1)**2, :]
\(\mathbf{x}^{(\ell_-)} \in \mathbb{R}^{1\times(2\ell+1)\times F}\): x[1:2, l**2:(l+1)**2, :]
\(\mathbf{x}_i^{(\ell_+)} \in \mathbb{R}^{1\times(2\ell+1)\times 1}\): x[0:1, l**2:(l+1)**2, i:i+1]
\(\mathbf{x}_i^{(\ell_-)} \in \mathbb{R}^{1\times(2\ell+1)\times 1}\): x[1:2, l**2:(l+1)**2, i:i+1]
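
For example, for \(L=2\) and \(F=8\) as above, the slices can be taken from any array of the corresponding shape (random features are used here purely for illustration; in practice, x would be the output of an E3x layer):

import jax

max_degree, num_features = 2, 8  # L and F from the example above
# Random features of shape (2, (L+1)**2, F) = (parity, irrep components, features).
x = jax.random.normal(jax.random.PRNGKey(0), (2, (max_degree + 1)**2, num_features))

l, i = 1, 3
x_even     = x[0:1, :, :]                    # all even-parity irreps
x_odd      = x[1:2, :, :]                    # all odd-parity irreps
x_l_minus  = x[1:2, l**2:(l + 1)**2, :]      # degree-l irreps with parity -1
x_i_l_plus = x[0:1, l**2:(l + 1)**2, i:i+1]  # degree-l, parity +1 part of feature i

print(x_even.shape, x_l_minus.shape, x_i_l_plus.shape)
# (1, 9, 8) (1, 3, 8) (1, 3, 1)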

Sometimes, it is useful to work with features that omit all pseudotensors and only contain proper tensor components. For better memory efficiency, such features \(\mathbf{x} \in \mathbb{R}^{1\times (L+1)^2\times F}\) are stored in an ndarray as follows (in this example, \(L=2\) and \(F=8\)):

[Figure: proper tensor features]

They behave equivalently to ordinary features \(\mathbf{x} \in \mathbb{R}^{2\times (L+1)^2\times F}\), where all pseudotensor components are zero (but without the need for explicit zero-padding):

[Figure: padded features]

All operations in E3x automatically detect which kind of features is used from their shape and adapt the computation accordingly.

Another way to think about features in E3x is to imagine them as three–dimensional shapes (blue: positive, red: negative, arrows show the x-, y-, and z-axes).

When applying rotations, the shapes of features stay the same, but they might be oriented differently after the transformation. For example, a rotation about the z-axis by \(\frac{\pi}{2}\) transforms the features like this (all numbers that may change when applying rotations are highlighted in red):

[Figure: the rotated features]


When applying a reflection, all \(\mathbf{x}^{(-)}\) change sign, whereas all \(\mathbf{x}^{(+)}\) stay the same:

[Figure: the reflected features]
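
In terms of the ndarray layout described above, this sign flip is just a broadcasted multiplication along the parity axis (a plain jax.numpy sketch for illustration, not an E3x function):

import jax
import jax.numpy as jnp

# Features of shape (2, (L+1)**2, F); index 0 of the first axis holds the
# even-parity irreps, index 1 the odd-parity irreps (here L = 2 and F = 8).
x = jax.random.normal(jax.random.PRNGKey(0), (2, 9, 8))

# Under a point reflection (-Id), the even block stays the same and the odd
# block changes sign.
parity_signs = jnp.array([1.0, -1.0])[:, None, None]  # shape (2, 1, 1)
x_reflected = parity_signs * x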


This makes it possible to predict output quantities that automatically transform in the correct way when a transformation is applied to the inputs of a model (see below).

Modifications to neural network components

In contrast to ordinary neural networks, not every operation is “allowed” when working with equivariant features – some operations destroy the equivariance. For example, nonlinear activation functions may only be applied to scalar features \(\mathbf{x}^{(0_+)}\) (technically, odd activation functions could also be applied to pseudoscalar features \(\mathbf{x}^{(0_-)}\), but for simplicity, E3x handles all activation functions on an equal footing). Fortunately, many common activation functions can be expressed in a way that preserves equivariance and is equivalent to their ordinary formulation when the features contain only scalars (see here for more details). When building your architecture, make sure to use one of the already implemented activation functions, or write your own using the principle described here.

Care must be taken even when applying linear operations. Namely, all \(2\ell+1\) components of an irrep of degree \(\ell\) must be treated equally: even though e.g. features \(\mathbf{x}^{(1_-)}\) of degree \(\ell=1\) consist of \(2\ell+1=3\) individual numbers (see above), they should be regarded as a single entity. For example, when multiplying \(\mathbf{x}^{(1_-)}\) with a scalar, all three components must be multiplied by the same factor – just as in ordinary scalar multiplication of vectors. While scalar multiplication of irrep features is allowed, there is no valid “scalar addition” (just as adding a scalar to all three components of a vector is not meaningful). Concretely, this means that bias terms may only be applied to scalar features \(\mathbf{x}^{(0_+)}\). E3x contains an implementation of Dense layers that handles all these details for you. Thus, building \(\mathrm{E}(3)\)-equivariant feedforward neural networks with E3x is as easy as replacing all ordinary dense layers and activations with their E3x counterparts (see the sketch below).
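
A minimal equivariant feedforward block might look like the following sketch. It assumes that e3x.nn.Dense and e3x.nn.relu behave as described in this section and follow the usual Flax linen calling convention; consult the module documentation for the exact interfaces:

import e3x
import flax.linen as nn
import jax
import jax.numpy as jnp

# Sketch of an E(3)-equivariant feedforward block built from E3x layers.
# The exact constructor arguments of e3x.nn.Dense are assumed here.
class FeedForward(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):  # x has shape (..., 1 or 2, (L+1)**2, F)
    x = e3x.nn.Dense(features=self.features)(x)  # equivariant linear layer (bias only on scalars)
    x = e3x.nn.relu(x)                           # equivariance-preserving activation
    return e3x.nn.Dense(features=self.features)(x)

x = jnp.zeros((2, 9, 8))  # L = 2, F = 8
params = FeedForward().init(jax.random.PRNGKey(0), x)
y = FeedForward().apply(params, x)  # same shape as x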

In addition, E3x contains equivariant versions of many other common building blocks, such as MessagePass or SelfAttention layers (see here for a complete overview of all implemented modules). Many of these more advanced layers “couple” irrep features in a way that has no analogue in ordinary networks, see below.

Coupling irreps

It is possible to “couple” two irreps by computing their tensor product to produce new irreps. This uses the fact that the tensor product of two irreps can be written as a direct sum of (new) irreps. For example

\[\begin{split}\begin{align*} \mathbb{0_+} \otimes \mathbb{1_-} &= \mathbb{1_-} \,, \\ \mathbb{1_-} \otimes \mathbb{1_-} &= \mathbb{0_+} \oplus \mathbb{1_+} \oplus \mathbb{2_+} \,, \\ \mathbb{2_-} \otimes \mathbb{3_+} &= \mathbb{1_-} \oplus \mathbb{2_-} \oplus \mathbb{3_-} \oplus \mathbb{4_-} \oplus \mathbb{5_-} \,, \end{align*}\end{split}\]

or in general:

\[\begin{split}\mathbb{a}_{\alpha} \otimes \mathbb{b}_{\beta} = \begin{cases} \left( \mathbb{\lvert a-b \rvert} \right)\mathbb{_+} \oplus \left( \mathbb{\lvert a-b \rvert + 1} \right)\mathbb{_+} \oplus \dots \oplus \left( \mathbb{a+b} \right)\mathbb{_+} & \alpha = \beta \\ \left( \mathbb{\lvert a-b \rvert} \right)\mathbb{_-} \oplus \left( \mathbb{\lvert a-b \rvert + 1} \right)\mathbb{_-} \oplus \dots \oplus \left( \mathbb{a+b} \right)\mathbb{_-} & \alpha \neq \beta\,. \\ \end{cases}\end{split}\]
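
The selection rule above is easy to express in code; the following plain Python helper (not part of E3x) lists the degrees and the parity of the irreps contained in \(\mathbb{a}_\alpha \otimes \mathbb{b}_\beta\):

def coupled_irreps(a, alpha, b, beta):
  """Degrees and parity of the irreps in the tensor product a_alpha ⊗ b_beta.

  Illustrates the general rule above: degrees run from |a - b| to a + b and
  the resulting parity is the product alpha * beta.
  """
  parity = alpha * beta
  return [(l, parity) for l in range(abs(a - b), a + b + 1)]

print(coupled_irreps(1, -1, 1, -1))  # [(0, 1), (1, 1), (2, 1)]
print(coupled_irreps(2, -1, 3, +1))  # [(1, -1), (2, -1), (3, -1), (4, -1), (5, -1)]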

To simplify notation, we introduce the short-hand ‘\(\mathbb{a}\otimes^{(\ell_p)}\mathbb{b}\)’ to refer to the irrep of degree \(\ell\) and parity \(p\) in the direct sum representation of the tensor product \(\mathbb{a}\otimes\mathbb{b}\). For example:

\[\mathbb{1_-} \otimes \mathbb{1_-} = \overbrace{\mathbb{0_+}}^{\mathbb{1_-}\otimes^{(0_+)}\mathbb{1_-}} \oplus \overbrace{\mathbb{1_+}}^{\mathbb{1_-}\otimes^{(1_+)}\mathbb{1_-}} \oplus \overbrace{\mathbb{2_+}}^{\mathbb{1_-}\otimes^{(2_+)}\mathbb{1_-}} \,.\]

Coupling two irreps to produce new irreps may seem to be a foreign concept at first glance. However, it is analogous to familiar operations on vectors in \(\mathbb{R}^3\) (which correspond to irreps \(\mathbb{1_-}\), see above). Consider two vectors \(\mathbf{u},\mathbf{v}\in\mathbb{R}^3\). Their tensor product is

\[\begin{split}\mathbf{u} \otimes \mathbf{v} = \begin{bmatrix} u_{x} \\ u_{y} \\ u_{z} \end{bmatrix} \otimes \begin{bmatrix} v_{x} \\ v_{y} \\ v_{z} \end{bmatrix} = \begin{bmatrix} u_{x}v_{x} & u_{x}v_{y} & u_{x}v_{z} \\ u_{y}v_{x} & u_{y}v_{y} & u_{y}v_{z} \\ u_{z}v_{x} & u_{z}v_{y} & u_{z}v_{z} \end{bmatrix}\end{split}\]

and the irreps \(\mathbb{0_+}\) and \(\mathbb{1_+}\) in the direct sum representation of their tensor product correspond to their scaled dot product

\[\mathbf{u} \otimes^{(0_+)} \mathbf{v} = \frac{1}{\sqrt{3}} \langle\mathbf{u},\mathbf{v}\rangle = \frac{1}{\sqrt{3}} \left(u_x v_x + u_y v_y + u_z v_z \right)\,,\]

and their scaled cross product

\[\begin{split}\mathbf{u} \otimes^{(1_+)} \mathbf{v} = \frac{1}{\sqrt{2}} \left(\mathbf{u}\times \mathbf{v}\right) = \frac{1}{\sqrt{2}} \begin{bmatrix} u_y v_z - u_z v_y\\ u_z v_x - u_x v_z\\ u_x v_y - u_y v_x \end{bmatrix}\,.\end{split}\]

The irrep \(\mathbb{2_+}\) cannot be easily written in terms of familiar operations like dot or cross products, but can be related to the entries of the traceless symmetric matrix

\[\begin{split}\mathbf{S} = \begin{bmatrix} S_{xx} & S_{xy} & S_{xz} \\ S_{xy} & S_{yy} & S_{yz} \\ S_{xz} & S_{yz} & S_{zz} \end{bmatrix} := \frac{1}{2} \left(\mathbf{u}\mathbf{v}^\intercal + \mathbf{v}\mathbf{u}^\intercal\right) - \frac{1}{6} \mathrm{Tr} \left(\mathbf{u}\mathbf{v}^\intercal + \mathbf{v}\mathbf{u}^\intercal\right) \cdot I_3\,,\end{split}\]

where \(I_3\) is the \(3\times3\) identity matrix and \(\mathrm{Tr}\) is the trace. Then, the irrep \(\mathbb{2_+}\) is given by

\[\begin{split}\mathbf{u} \otimes^{(2_+)} \mathbf{v} = \frac{1}{\sqrt{2}} \begin{bmatrix} S_{xx} - S_{yy}\\ 2 S_{xy}\\ 2 S_{xz}\\ 2 S_{yz}\\ \sqrt{3} S_{zz} \end{bmatrix}\,,\end{split}\]

or, re-written directly in terms of the components of \(\mathbf{u}\) and \(\mathbf{v}\),

\[\begin{split}\mathbf{u} \otimes^{(2_+)} \mathbf{v} = \frac{1}{\sqrt{2}} \begin{bmatrix} u_x v_x - u_y v_y\\ u_x v_y + u_y v_x\\ u_x v_z + u_z v_x\\ u_y v_z + u_z v_y\\ \frac{1}{\sqrt{3}} \left(2 u_z v_z - u_x v_x - u_y v_y\right) \end{bmatrix}\,.\end{split}\]
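
These relations can be verified numerically with plain NumPy (independent of E3x):

import numpy as np

# Numerical check of the formulas above: build the irreps 0_+, 1_+ and 2_+ in the
# direct sum decomposition of u ⊗ v directly from u and v.
u = np.array([0.3, -1.2, 0.7])
v = np.array([1.5,  0.4, -0.8])

irrep0 = np.dot(u, v) / np.sqrt(3)    # scaled dot product (0_+)
irrep1 = np.cross(u, v) / np.sqrt(2)  # scaled cross product (1_+)

S = 0.5 * (np.outer(u, v) + np.outer(v, u))
S -= np.trace(S) / 3 * np.eye(3)      # traceless symmetric matrix S
irrep2 = np.array([S[0, 0] - S[1, 1],
                   2 * S[0, 1],
                   2 * S[0, 2],
                   2 * S[1, 2],
                   np.sqrt(3) * S[2, 2]]) / np.sqrt(2)

# Same result expressed directly in components (cf. the last equation above):
irrep2_direct = np.array([u[0]*v[0] - u[1]*v[1],
                          u[0]*v[1] + u[1]*v[0],
                          u[0]*v[2] + u[2]*v[0],
                          u[1]*v[2] + u[2]*v[1],
                          (2*u[2]*v[2] - u[0]*v[0] - u[1]*v[1]) / np.sqrt(3)]) / np.sqrt(2)
print(np.allclose(irrep2, irrep2_direct))  # True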

The dot product produces a scalar, corresponding to a one–dimensional irrep \(\mathbb{0_+}\), and the cross product produces a pseudovector, corresponding to a three–dimensional irrep \(\mathbb{1_+}\). It might be less apparent why a traceless symmetric matrix \(\mathbf{S}\) is related to a five–dimensional representation, but any traceless symmetric matrix can be written as

\[\begin{split}\begin{bmatrix} \phantom{1-}a\phantom{-d} & b & c \\ b & \phantom{1-}d\phantom{-d} & e \\ c & e & -a-d \\ \end{bmatrix}\end{split}\]

and only five numbers \(a\), \(b\), \(c\), \(d\), and \(e\) are necessary to fully specify it.

Importantly, all irreps in the direct sum representation of \(\mathbf{u} \otimes \mathbf{v}\) can be written as linear combinations of entries in the tensor product matrix

\[\begin{split}\mathbf{u} \otimes \mathbf{v} = \begin{bmatrix} u_{x}v_{x} & u_{x}v_{y} & u_{x}v_{z} \\ u_{y}v_{x} & u_{y}v_{y} & u_{y}v_{z} \\ u_{z}v_{x} & u_{z}v_{y} & u_{z}v_{z} \end{bmatrix}\end{split}\]

with appropriate coefficients (and vice versa). This is also true for other tensor products of irreps and the corresponding coefficients are called Clebsch-Gordan coefficients. In E3x, all possible irreps in the direct sum representation of a tensor product can be computed in parallel by multiplication with the clebsch_gordan array (containing the required coefficients) and subsequent summation. The Tensor and FusedTensor layers are the basic operations in E3x to couple two feature representations; they have no analogue in ordinary neural networks. The FusedTensor layers have a lower computational complexity compared to Tensor layers (\(O(L^4)\) vs. \(O(L^6)\)), with the downside of being less expressive (they contain fewer learnable parameters). When working with small maximum degrees \(L \leq 4\), the ordinary Tensor layers still tend to be faster on accelerators (GPUs/TPUs) due to their simpler implementation. When working with large maximum degrees \(L > 4\), (or on CPUs), FusedTensor layers can be significantly faster.
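
As a rough sketch of the idea (ignoring the parity axis and feature channels, which the Tensor and FusedTensor layers handle for you), coupling via the Clebsch-Gordan coefficients could look like this. The signature and coefficient layout of e3x.so3.clebsch_gordan are assumptions for illustration; check the API reference for the exact interface:

import e3x
import jax.numpy as jnp

# Couple irreps of maximum degree 1 with irreps of maximum degree 1 into irreps of
# maximum degree 2. The Clebsch-Gordan array is assumed to have shape
# ((max_degree1+1)**2, (max_degree2+1)**2, (max_degree3+1)**2).
cg = e3x.so3.clebsch_gordan(1, 1, 2)          # assumed shape (4, 4, 9)
a = jnp.array([0.5, 1.0, -2.0, 0.3])          # irreps of degrees 0 and 1
b = jnp.array([1.0, 0.0,  0.5, -1.0])
coupled = jnp.einsum('a,b,abc->c', a, b, cg)  # irreps of degrees 0, 1 and 2
print(coupled.shape)  # (9,)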

Predicting equivariant quantities

The prediction of equivariant quantities with E3x is as straightforward as the prediction of scalar quantities with ordinary neural networks. For example, a single vector can be predicted by slicing the output features y of the last layer appropriately, i.e. y[..., 1, 1:4, 0] to produce an array of shape (..., 3) that transforms as a vector under rotations/reflections of the input to your neural network (see the section on slicing irrep features for reference). Further, E3x contains convenience functions for converting from irreps to traceless symmetric tensors and vice versa (see irreps_to_tensor and tensor_to_irreps), which is a more typical format for many quantities of interest in physics.
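
For instance, assuming output features y with maximum degree \(L=2\) and \(F=8\) feature channels (random values stand in for the output of the last layer):

import jax

# Stand-in for the output features of the last layer: a batch of 10 feature arrays
# with shape (2, (L+1)**2, F) for L = 2 and F = 8.
y = jax.random.normal(jax.random.PRNGKey(0), (10, 2, 9, 8))

# The degree-1, odd-parity components of feature channel 0 transform as a vector
# under rotations/reflections of the model inputs.
vectors = y[..., 1, 1:4, 0]  # shape (10, 3)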