e3x.nn.modules.Tensor

class e3x.nn.modules.Tensor(max_degree=None, include_pseudotensors=True, cartesian_order=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function tensor_variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Tensor product of two equivariant feature representations.

Computes linear combinations (with learnable coefficients) of the direct sum representation of all possible tensor products of irreps in the input features. If the inputs are \(\mathbf{x} \in \mathbb{R}^{P_1\times (L_1+1)^2 \times F}\) and \(\mathbf{y} \in \mathbb{R}^{P_2\times (L_2+1)^2 \times F}\), the output is \(\mathbf{z} \in \mathbb{R}^{P_3\times (L_3+1)^2 \times F}\). Here, \(P_1\), \(P_2\), and \(P_3\) are either \(1\) or \(2\) (depending on whether the inputs and output contain pseudotensors), and \(L_1\), \(L_2\), and \(L_3\) are nonnegative integers (\(L_3\) = max_degree). The entries of \(\mathbf{z}\) are given by

\[\mathbf{z}^{(c_\gamma)} = \sum_{(a_\alpha,b_\beta)\in V} \mathbf{w}_{(a_\alpha,b_\beta,c_\gamma)} \circ \left( \mathbf{x}^{(a_\alpha)} \otimes^{(c_\gamma)}\mathbf{y}^{(b_\beta)} \right)\,,\]

where the sum runs over all \((a_\alpha,b_\beta)\) in the set of valid combinations \(V\) and \(\mathbf{w}_{(a_\alpha,b_\beta,c_\gamma)} \in \mathbb{R}^{1\times 1\times F}\) are learnable (feature-wise) weight parameters. Each combination \((a_\alpha,b_\beta,c_\gamma)\) has its own parameters, and the element-wise product ‘\(\circ\)’ implies broadcasting over the size-1 dimensions of \(\mathbf{w}\), so the same weight is applied to all components of an irrep. The set \(V\) contains all \((a_\alpha,b_\beta)\) for which the condition

\[\lvert a - b \rvert \leq c \leq a + b \enspace \land \enspace \left( \left( \gamma = +1 \enspace \land \enspace \alpha = \beta \right) \enspace \lor \enspace \left( \gamma = -1 \enspace \land \enspace \alpha \neq \beta \right) \right)\]

is true. If include_pseudotensors = False, coupling paths that lead to pseudotensors are not computed. This means that all entries that do not satisfy

\[\left(c \in \{2n+1 : n\in \mathbb{N}_0\} \enspace \land \enspace \gamma = -1 \right) \enspace \lor \enspace \left(c \in \{2n : n\in \mathbb{N}_0\} \enspace \land \enspace \gamma = +1 \right)\]

(either \(c\) is odd and the parity is odd, or \(c\) is even and the parity is even) are omitted. See also the e3x documentation for more details on this notation and on the coupling of irreps in general. The following diagram visualizes the computation for the example \(P_1=P_2=P_3=2\), \(L_1=L_2=L_3=1\). For clarity, weights are labelled for only 2 of the 20 possible coupling paths.

[Figure: visualization of the coupling paths for \(P_1=P_2=P_3=2\), \(L_1=L_2=L_3=1\)]
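The selection rules for \(V\) can be made concrete by enumerating the coupling paths directly. The following is a minimal plain-Python sketch, not the e3x implementation: the helper names `irreps` and `valid_coupling_paths` are hypothetical, each irrep is represented as a (degree, parity) tuple with parity \(\pm 1\), and it assumes that an input with \(P=1\) contains only proper tensors, i.e. a degree-\(l\) irrep with parity \((-1)^l\). It reproduces the count of 20 paths for the example in the diagram.

```python
from itertools import product


def irreps(max_degree, has_pseudotensors):
  """(degree, parity) pairs stored in an input with P = 2 or P = 1."""
  if has_pseudotensors:  # P = 2: every degree appears with both parities
    return [(l, p) for l in range(max_degree + 1) for p in (+1, -1)]
  # P = 1 (assumption): proper tensors only, so degree l has parity (-1)**l
  return [(l, (-1) ** l) for l in range(max_degree + 1)]


def valid_coupling_paths(max_degree1, max_degree2, max_degree3,
                         pseudotensors1=True, pseudotensors2=True):
  """Lists all ((a, alpha), (b, beta), (c, gamma)) allowed by the condition."""
  paths = []
  for (a, alpha), (b, beta) in product(irreps(max_degree1, pseudotensors1),
                                       irreps(max_degree2, pseudotensors2)):
    # gamma = +1 if alpha == beta and gamma = -1 otherwise.
    gamma = alpha * beta
    # Triangle condition |a - b| <= c <= a + b, truncated at max_degree3.
    for c in range(abs(a - b), min(a + b, max_degree3) + 1):
      paths.append(((a, alpha), (b, beta), (c, gamma)))
  return paths


# Example from the diagram: P_1 = P_2 = 2 and L_1 = L_2 = L_3 = 1.
print(len(valid_coupling_paths(1, 1, 1)))  # 20
```

With `pseudotensors1=False` and `pseudotensors2=False` (both inputs with \(P=1\)), the same degrees give only 5 paths, because each degree then appears with a single parity.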

max_degree

Maximum degree \(L_3\) of the output. If not given, it defaults to the larger of the maximum degrees of inputs1 and inputs2, i.e. \(L_3 = \max(L_1, L_2)\).
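A minimal sketch of this default, assuming only what is stated here: `output_max_degree` and `num_components` are hypothetical helpers, not e3x functions. The size of the irrep axis follows from stacking one \((2l+1)\)-dimensional block per degree \(l = 0, \dots, L\), which gives \((L+1)^2\) entries.

```python
def output_max_degree(max_degree1, max_degree2, max_degree=None):
  """Maximum output degree L_3: max(L_1, L_2) unless max_degree is given."""
  if max_degree is None:
    return max(max_degree1, max_degree2)
  return max_degree


def num_components(max_degree):
  """Size of the irrep axis: one (2l+1)-sized block per degree l."""
  return sum(2 * l + 1 for l in range(max_degree + 1))


L1, L2 = 1, 2
L3 = output_max_degree(L1, L2)
print(L3, num_components(L3))  # 2 9
```

By the triangle condition \(c \leq a + b\), no coupling path reaches a degree above \(L_1 + L_2\). The default \(\max(L_1, L_2)\) therefore drops the paths with \(\max(L_1, L_2) < c \leq L_1 + L_2\); setting max_degree to \(L_1 + L_2\) keeps all of them.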

include_pseudotensors

If False, all coupling paths that produce pseudotensors are omitted.
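The effect of this flag can be sketched by adding a pseudotensor check to a plain-Python path enumeration. This is an illustrative sketch, not e3x code: `is_pseudotensor` and `coupling_paths` are hypothetical names, both inputs are assumed to contain pseudotensors (\(P_1 = P_2 = 2\)), and a degree-\(c\) irrep with parity \(\gamma\) counts as a proper tensor exactly when \(\gamma = (-1)^c\), which restates the parity condition given above.

```python
from itertools import product


def is_pseudotensor(degree, parity):
  """True if a (degree, parity) irrep is a pseudotensor.

  Proper tensors have parity (-1)**degree: odd degree with odd parity (-1),
  or even degree with even parity (+1).
  """
  return parity != (-1) ** degree


def coupling_paths(max_degree1, max_degree2, max_degree3,
                   include_pseudotensors=True):
  """Valid paths ((a, alpha), (b, beta), (c, gamma)) for P_1 = P_2 = 2."""
  irreps1 = list(product(range(max_degree1 + 1), (+1, -1)))
  irreps2 = list(product(range(max_degree2 + 1), (+1, -1)))
  paths = []
  for (a, alpha), (b, beta) in product(irreps1, irreps2):
    gamma = alpha * beta  # +1 if the input parities agree, else -1
    for c in range(abs(a - b), min(a + b, max_degree3) + 1):
      if not include_pseudotensors and is_pseudotensor(c, gamma):
        continue  # this path would produce a pseudotensor: omit it
      paths.append(((a, alpha), (b, beta), (c, gamma)))
  return paths


print(len(coupling_paths(1, 1, 1)))                               # 20
print(len(coupling_paths(1, 1, 1, include_pseudotensors=False)))  # 10
```

For \(L_1 = L_2 = L_3 = 1\), half of the 20 paths survive, and their outputs are only \((0, +1)\) and \((1, -1)\), which are exactly the proper tensors of degree 0 and 1.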

cartesian_order

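If True, the components of each irrep are assumed to be in Cartesian order (for degree 1: x, y, z) instead of the standard order of the spherical harmonics.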

dtype

The dtype of the computation.

param_dtype

The dtype passed to parameter initializers.

precision

Numerical precision of the computation; see jax.lax.Precision for details.

kernel_init

Initializer function for the weight matrix.

__call__(inputs1, inputs2)

Computes the tensor product of inputs1 and inputs2.
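From the shapes stated above, both inputs are arrays of shape (..., P, (L+1)**2, F) with \(P \in \{1, 2\}\) and a shared feature dimension F. The sketch below uses hypothetical helper names that are not part of e3x; it recovers \(P\), \(L\), and \(F\) from such a shape and rejects mismatched inputs, but it does not check the leading batch dimensions.

```python
import math


def parse_irreps_shape(shape):
  """Splits a shape (..., P, (L+1)**2, F) into (P, L, F)."""
  if len(shape) < 3:
    raise ValueError(f"expected at least 3 dimensions, got {shape}")
  parity, components, features = shape[-3:]
  if parity not in (1, 2):
    raise ValueError(f"parity axis must have size 1 or 2, got {parity}")
  max_degree = math.isqrt(components) - 1
  if components < 1 or (max_degree + 1) ** 2 != components:
    raise ValueError(f"irrep axis size {components} is not of the form (L+1)**2")
  return parity, max_degree, features


def check_tensor_inputs(shape1, shape2):
  """Returns (P, L, F) for both inputs after checking that they share F."""
  p1, l1, f1 = parse_irreps_shape(shape1)
  p2, l2, f2 = parse_irreps_shape(shape2)
  if f1 != f2:
    raise ValueError(f"feature dimensions differ: {f1} != {f2}")
  return (p1, l1, f1), (p2, l2, f2)


# A batch of 10 inputs: x with P_1 = 2, L_1 = 1 and y with P_2 = 1, L_2 = 2.
print(check_tensor_inputs((10, 2, 4, 16), (10, 1, 9, 16)))
# ((2, 1, 16), (1, 2, 16))
```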

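Parameters:

inputs1 – First factor \(\mathbf{x}\) of the tensor product, of shape (..., \(P_1\), \((L_1+1)^2\), F).

inputs2 – Second factor \(\mathbf{y}\) of the tensor product, of shape (..., \(P_2\), \((L_2+1)^2\), F).
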
Return type:

Union[Float[Array, '... 1 (max_degree3+1)**2 num_features'], Float[Array, '... 2 (max_degree3+1)**2 num_features']]

Returns:

The tensor product of inputs1 and inputs2, in which each output irrep is a linear combination, with learnable weights, of all valid coupling paths.
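The final weighted combination can be illustrated for a single output irrep \(c_\gamma\). This NumPy sketch assumes the coupled terms \(\mathbf{x}^{(a_\alpha)} \otimes^{(c_\gamma)} \mathbf{y}^{(b_\beta)}\) have already been computed (it does not compute any Clebsch-Gordan coefficients) and stacks them along a leading path axis; `combine_paths` is an illustrative name, not an e3x function. It shows how one weight of shape \(1\times 1\times F\) per path broadcasts over the parity and component axes; `jax.numpy` can be used in place of `numpy` for these operations.

```python
import numpy as np


def combine_paths(coupled_terms, weights):
  """Computes z^(c_gamma) as the sum over paths of w_path * coupled_term.

  coupled_terms: shape (num_paths, 1, 2c+1, F), one coupled term per path.
  weights: shape (num_paths, 1, 1, F), one 1 x 1 x F weight per path; the
    size-1 axes broadcast over the parity and component axes.
  """
  return np.sum(weights * coupled_terms, axis=0)


num_paths, c, F = 2, 1, 4
coupled_terms = np.ones((num_paths, 1, 2 * c + 1, F))
weights = np.arange(num_paths * F, dtype=float).reshape(num_paths, 1, 1, F)
z = combine_paths(coupled_terms, weights)
print(z.shape)  # (1, 3, 4)
```

Every component of the degree-\(c\) output receives the same per-feature weights, which is why the weights alone cannot break rotational equivariance.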