e3x.nn.modules.FusedTensor
- class e3x.nn.modules.FusedTensor(max_degree=None, include_pseudotensors=True, cartesian_order=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function fused_tensor_normal>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module

Fused tensor product of two equivariant feature representations.
This module performs a similar function to Tensor, but has a lower computational complexity and fewer learnable parameters. Given two inputs \(\mathbf{x} \in \mathbb{R}^{P_1\times (L_1+1)^2 \times F}\) and \(\mathbf{y} \in \mathbb{R}^{P_2\times (L_2+1)^2 \times F}\), the output is \(\mathbf{z} \in \mathbb{R}^{P_3\times (L_3+1)^2 \times F}\). Here, \(P_1\), \(P_2\), and \(P_3\) are either \(1\) or \(2\) (depending on whether the inputs/output contain pseudotensors or not) and \(L_1\), \(L_2\), and \(L_3\) are non-negative integers (\(L_3\) = max_degree). The computation consists of the following steps:
1a. The constituent irreps \(\mathbf{x}^{(a_\alpha)} \in \mathbb{R}^{1\times (2a+1) \times F}\) and \(\mathbf{y}^{(b_\beta)} \in \mathbb{R}^{1\times (2b+1) \times F}\) of the features \(\mathbf{x}\) and \(\mathbf{y}\) are transformed via a change of basis (“vectors” to “matrices”) to \(\mathbf{\tilde{x}}^{(a_\alpha)}, \mathbf{\tilde{y}}^{(b_\beta)} \in \mathbb{R}^{(2\tilde{l}+1) \times (2\tilde{l}+1) \times F}\) with \(\tilde{l} = \left\lceil\frac{\mathrm{max}(L_1, L_2, L_3)}{2}\right\rceil\). This is the smallest \(\tilde{l}\) for which a \((2\tilde{l}+1) \times (2\tilde{l}+1)\) matrix can hold irreps of every degree up to \(\mathrm{max}(L_1, L_2, L_3)\), since coupling two degree-\(\tilde{l}\) representations yields exactly the degrees \(0,\dots,2\tilde{l}\).
1b. The individual “matrix irreps” with equal parity are multiplied by (separate) learnable weights \(\mathbf{w} \in \mathbb{R}^{1 \times 1 \times F}\) and summed to form the matrices \(\mathbf{X}^{(+)},\mathbf{X}^{(-)},\mathbf{Y}^{(+)},\mathbf{Y}^{(-)} \in \mathbb{R}^{(2\tilde{l}+1) \times (2\tilde{l}+1) \times F}\) (the element-wise product ‘\(\circ\)’ implies broadcasting over dimensions):
\[ \begin{align}\begin{aligned}\mathbf{X}^{(+)} &= \sum_{a=0}^{L_1} \mathbf{w}_{\mathbf{x}^{(a_+)}} \circ \mathbf{\tilde{x}}^{(a_+)}\\\mathbf{X}^{(-)} &= \sum_{a=0}^{L_1} \mathbf{w}_{\mathbf{x}^{(a_-)}} \circ \mathbf{\tilde{x}}^{(a_-)}\\\mathbf{Y}^{(+)} &= \sum_{b=0}^{L_2} \mathbf{w}_{\mathbf{y}^{(b_+)}} \circ \mathbf{\tilde{y}}^{(b_+)}\\\mathbf{Y}^{(-)} &= \sum_{b=0}^{L_2} \mathbf{w}_{\mathbf{y}^{(b_-)}} \circ \mathbf{\tilde{y}}^{(b_-)}\end{aligned}\end{align} \]
Any potentially “missing” irreps, e.g. \(\mathbf{\tilde{x}}^{(1_+)}\) if \(\mathbf{x}\) does not contain pseudotensors, are assumed to be zero.
2. The so-formed matrices are coupled by matrix multiplication to produce new matrices as follows (“batch matrix multiplication” over the last dimension with size \(F\) is implied):
\[ \begin{align}\begin{aligned}\mathbf{Z}^{(+,+)} &= \mathbf{X}^{(+)}\mathbf{Y}^{(+)}\\\mathbf{Z}^{(-,-)} &= \mathbf{X}^{(-)}\mathbf{Y}^{(-)}\\\mathbf{Z}^{(+,-)} &= \mathbf{X}^{(+)}\mathbf{Y}^{(-)}\\\mathbf{Z}^{(-,+)} &= \mathbf{X}^{(-)}\mathbf{Y}^{(+)}\end{aligned}\end{align} \]
Note: \(\mathbf{Z}^{(+,+)}\) and \(\mathbf{Z}^{(-,-)}\) have even parity and \(\mathbf{Z}^{(+,-)}\) and \(\mathbf{Z}^{(-,+)}\) have odd parity.
3a. The matrices \(\mathbf{Z}^{(+,+)}, \mathbf{Z}^{(-,-)}, \mathbf{Z}^{(+,-)}, \mathbf{Z}^{(-,+)}\) are “decomposed” into their constituent “matrix irreps” \(\mathbf{\tilde{z}}^{(+,+,c)}, \mathbf{\tilde{z}}^{(-,-,c)},\mathbf{\tilde{z}}^{(+,-,c)}, \mathbf{\tilde{z}}^{(-,+,c)} \in \mathbb{R}^{(2\tilde{l}+1) \times (2\tilde{l}+1)\times F}\) with \(c = 0,\dots,L_3\). During this decomposition, the individual matrix irreps are multiplied by (separate) learnable weights \(\mathbf{w} \in \mathbb{R}^{1 \times 1\times F}\). This step can be thought of as performing the inverse of the operation in step 1b.
3b. The so-obtained “matrix irreps” are transformed via a change of basis (“matrices” to “vectors”) to obtain the irreps \(\mathbf{z}^{(+,+,c)},\mathbf{z}^{(-,-,c)},\mathbf{z}^{(+,-,c)}, \mathbf{z}^{(-,+,c)} \in \mathbb{R}^{1 \times (2c+1) \times F}\) for each value of \(c\). This step can be thought of as performing the inverse of the operation in step 1a.
4. Finally, the output irreps of degree \(c\) are obtained by summing matching parities:
\[ \begin{align}\begin{aligned}\mathbf{z}^{(c_+)} &= \mathbf{z}^{(+,+,c)} + \mathbf{z}^{(-,-,c)}\\\mathbf{z}^{(c_-)} &= \mathbf{z}^{(+,-,c)} + \mathbf{z}^{(-,+,c)}\end{aligned}\end{align} \]
If include_pseudotensors = False, all irreps that correspond to pseudotensors are discarded (set to zero).

The following diagram shows a visualization of the computation for the example \(P_1=P_2=P_3=2\), \(L_1=L_2=L_3=2\).
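Steps 1b and 2 reduce to a weighted sum over matrix irreps and a matrix product batched over the feature axis. The NumPy sketch below shows only this bookkeeping, assuming the matrix irreps from step 1a are already available; here they are random placeholders, and the helper names (matrix_size, combine, couple) are hypothetical rather than part of e3x. Steps 3 and 4 (decomposition with per-path weights, then summing matching parities) are not shown.

```python
import math
import numpy as np

def matrix_size(L1, L2, L3):
    """Side length 2*l_tilde + 1 with l_tilde = ceil(max(L1, L2, L3) / 2)."""
    l_tilde = math.ceil(max(L1, L2, L3) / 2)
    return 2 * l_tilde + 1

def combine(matrix_irreps, weights):
    """Step 1b: sum_a w_a o x~_a, broadcasting (F,) weights over both matrix axes."""
    # matrix_irreps: (num_irreps, n, n, F), weights: (num_irreps, F) -> (n, n, F)
    return np.einsum('anmf,af->nmf', matrix_irreps, weights)

def couple(X, Y):
    """Step 2: one matrix product per feature channel f (batched over the last axis)."""
    return np.einsum('ijf,jkf->ikf', X, Y)

rng = np.random.default_rng(0)
F = 4
n = matrix_size(2, 2, 2)  # L1 = L2 = L3 = 2 -> l_tilde = 1 -> 3x3 matrices

# Placeholder matrix irreps for degrees 0..2 of each parity (random values, shapes only).
x_plus, x_minus, y_plus, y_minus = (rng.normal(size=(3, n, n, F)) for _ in range(4))
w = {k: rng.normal(size=(3, F)) for k in ("x+", "x-", "y+", "y-")}

X_p, X_m = combine(x_plus, w["x+"]), combine(x_minus, w["x-"])
Y_p, Y_m = combine(y_plus, w["y+"]), combine(y_minus, w["y-"])

Z = {
    "++": couple(X_p, Y_p),  # even parity
    "--": couple(X_m, Y_m),  # even parity
    "+-": couple(X_p, Y_m),  # odd parity
    "-+": couple(X_m, Y_p),  # odd parity
}
```

Writing the feature axis last lets a single einsum express the per-channel products, which is what “batch matrix multiplication over the last dimension” in step 2 refers to.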
While it is helpful to think of the change of basis and the weighting of individual components in steps 1a/1b and 3a/3b separately, in practice these steps are performed jointly using Clebsch-Gordan coefficients \(C^{l_3,m_3}_{l_1,m_1,l_2,m_2}\). The actual computation performed for the “vectors to matrices” change of basis plus multiplication by weights is:
\[X_{\tilde{m},\tilde{m}'} = \sum_{l=0}^{L} w_{l}\sum_{m=-l}^{l} C^{l,m}_{\tilde{l},\tilde{m},\tilde{l},\tilde{m}'} x_{l}^{m}\]
whereas the computation performed for the “matrices to vectors” change of basis plus multiplication by weights is:
\[x_{l}^{m} = \tilde{w}_{l}\sum_{\tilde{m}=-\tilde{l}}^{\tilde{l}} \sum_{\tilde{m}'=-\tilde{l}}^{\tilde{l}} C^{l,m}_{\tilde{l},\tilde{m},\tilde{l},\tilde{m}'} X_{\tilde{m},\tilde{m}'}\]
Here, \(X_{\tilde{m},\tilde{m}'}\) are the elements of the matrix \(\mathbf{X} \in \mathbb{R}^{(2\tilde{l}+1)\times(2\tilde{l}+1)}\)
\[\begin{split}\mathbf{X} = \begin{bmatrix} X_{-\tilde{l},-\tilde{l}} & X_{-\tilde{l},-\tilde{l}+1} & \cdots & X_{-\tilde{l},\tilde{l}} \\ X_{-\tilde{l}+1,-\tilde{l}} & X_{-\tilde{l}+1,-\tilde{l}+1} & \cdots & X_{-\tilde{l}+1,\tilde{l}} \\ \vdots & \vdots & \ddots & \vdots \\ X_{\tilde{l},-\tilde{l}} & X_{\tilde{l},-\tilde{l}+1} & \cdots & X_{\tilde{l},\tilde{l}} \\ \end{bmatrix}\end{split}\]
and \(x_{l}^{m}\) are the individual entries of the irreps \(\mathbf{x} \in \mathbb{R}^{1\times(L+1)^2}\)
\[\mathbf{x} = [\underbrace{x_{0}^{0}}_{\mathbf{x}^{(0)}} \quad \underbrace{x_{1}^{-1} \;\; x_{1}^{0} \;\; x_{1}^{1}}_{\mathbf{x}^{(1)}} \quad \cdots \quad \underbrace{x_{L}^{-L} \;\; \cdots \;\; x_{L}^{L}}_{\mathbf{x}^{(L)}} ]\]
(the feature dimension with size \(F\) and the parity indicators \(+/-\) are omitted for clarity).
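As a concrete illustration of the “vectors to matrices” change of basis and its inverse for \(\tilde{l}=1\): a \(3\times 3\) matrix holds exactly \(1+3+5 = 9\) numbers, matching irreps of degrees 0, 1, and 2 (in general \((2\tilde{l}+1)^2 = \sum_{l=0}^{2\tilde{l}}(2l+1)\)). The NumPy sketch below uses the familiar Cartesian split into trace, antisymmetric part, and symmetric traceless part instead of Clebsch-Gordan coefficients, so its components are assumed to differ from the spherical ordering and normalization used by e3x by a fixed linear map; the helper names are hypothetical.

```python
import numpy as np

def matrix_to_irreps(m):
    """Split a 3x3 matrix into degree-0, degree-1 and degree-2 parts (Cartesian basis)."""
    scalar = np.trace(m) / 3.0                               # degree 0: 1 number
    anti = 0.5 * (m - m.T)                                   # antisymmetric part
    vector = np.array([anti[2, 1], anti[0, 2], anti[1, 0]])  # degree 1: 3 numbers
    sym = 0.5 * (m + m.T) - scalar * np.eye(3)               # symmetric, traceless
    # degree 2: 5 independent numbers (the zz entry is fixed by tracelessness)
    quad = np.array([sym[0, 1], sym[1, 2], sym[0, 2], sym[0, 0], sym[1, 1]])
    return scalar, vector, quad

def irreps_to_matrix(scalar, vector, quad):
    """Inverse of matrix_to_irreps: reassemble the 3x3 matrix from its parts."""
    xy, yz, xz, xx, yy = quad
    sym = np.array([[xx, xy, xz], [xy, yy, yz], [xz, yz, -(xx + yy)]])
    vx, vy, vz = vector
    anti = np.array([[0.0, -vz, vy], [vz, 0.0, -vx], [-vy, vx, 0.0]])
    return scalar * np.eye(3) + anti + sym

# Round trip, and behaviour under a proper rotation about the z axis.
m = np.arange(9.0).reshape(3, 3) + np.eye(3)
c, s = np.cos(0.3), np.sin(0.3)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
scalar, vector, quad = matrix_to_irreps(m)
scalar_r, vector_r, quad_r = matrix_to_irreps(R @ m @ R.T)
print(np.allclose(irreps_to_matrix(scalar, vector, quad), m))          # True
print(np.isclose(scalar_r, scalar), np.allclose(vector_r, R @ vector))  # True True
```

Rotating the matrix as \(R\mathbf{X}R^\top\) rotates each part independently (the trace is invariant and the antisymmetric part rotates as a vector), which is why the two changes of basis can be applied around step 2 without breaking equivariance.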
- max_degree
Maximum degree of the output. If not given, max_degree is chosen as the maximum of the maximum degrees of inputs1 and inputs2.
- include_pseudotensors
If False, pseudotensors are omitted in the output.
- cartesian_order
If True, Cartesian order is assumed.
- dtype
The dtype of the computation.
- param_dtype
The dtype passed to parameter initializers.
- precision
Numerical precision of the computation, see jax.lax.Precision for details.
- kernel_init
Initializer function for the weight kernel.
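The default for max_degree can be read off the input shapes, because axis -2 of each input has size \((L+1)^2\). A small sketch of that rule, with hypothetical helper names and no claim to mirror the e3x source:

```python
import math

def degree_from_shape(shape):
    """Return max_degree for an input shaped (..., P, (max_degree+1)**2, F)."""
    if len(shape) < 3 or shape[-3] not in (1, 2):
        raise ValueError(f"expected (..., 1 or 2, (L+1)**2, F), got {shape}")
    num_components = shape[-2]
    root = math.isqrt(num_components)
    if num_components < 1 or root * root != num_components:
        raise ValueError(f"axis -2 has size {num_components}, not (L+1)**2")
    return root - 1

def default_max_degree(shape1, shape2, max_degree=None):
    """Documented default: the larger of the two input maximum degrees."""
    if max_degree is not None:
        return max_degree
    return max(degree_from_shape(shape1), degree_from_shape(shape2))

print(default_max_degree((10, 2, 9, 8), (10, 1, 4, 8)))  # 2
```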
- __call__(inputs1, inputs2)
Computes the fused tensor product of inputs1 and inputs2.
- Parameters:
inputs1 (Union[Float[Array, '... 1 (max_degree1+1)**2 num_features'], Float[Array, '... 2 (max_degree1+1)**2 num_features']]) – The first factor of the tensor product.
inputs2 (Union[Float[Array, '... 1 (max_degree2+1)**2 num_features'], Float[Array, '... 2 (max_degree2+1)**2 num_features']]) – The second factor of the tensor product.
- Return type:
- Returns:
The tensor product of inputs1 and inputs2, where each output irrep is a weighted linear combination of all valid coupling paths, with (partially shared) learnable weights.
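To see which coupling paths exist, the sketch below enumerates the paths permitted by the usual selection rules for coupling irreps (triangle inequality \(|a-b| \le c \le a+b\), output parity equal to the product of input parities), counting a degree-\(l\) irrep as a pseudotensor when its parity differs from \((-1)^l\). It is an illustrative sketch under these assumptions: it does not check which paths the fused computation weights separately or whether any coefficient vanishes, and the function names are hypothetical.

```python
from itertools import product

def is_pseudotensor(degree, parity):
    """A degree-l irrep with parity +1/-1 is a pseudotensor if parity != (-1)**l."""
    return parity != (-1) ** degree

def coupling_paths(L1, L2, L3, pseudo1=True, pseudo2=True, include_pseudotensors=True):
    """List ((a, p1), (b, p2), (c, p3)) paths allowed by the usual selection rules."""
    paths = []
    for a, b, c in product(range(L1 + 1), range(L2 + 1), range(L3 + 1)):
        if not abs(a - b) <= c <= a + b:  # triangle inequality
            continue
        for p1, p2 in product((+1, -1), repeat=2):
            if (not pseudo1 and is_pseudotensor(a, p1)) or (
                not pseudo2 and is_pseudotensor(b, p2)
            ):
                continue  # this input irrep is not present
            p3 = p1 * p2  # parities multiply
            if not include_pseudotensors and is_pseudotensor(c, p3):
                continue  # output irrep is discarded
            paths.append(((a, p1), (b, p2), (c, p3)))
    return paths

print(len(coupling_paths(1, 1, 1)))  # 20
print(len(coupling_paths(1, 1, 1, pseudo1=False, pseudo2=False,
                         include_pseudotensors=False)))  # 4
```

The second call covers two ordinary vectors: their degree-1 coupling (the cross product) is a pseudovector and is therefore dropped.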