e3x.nn.modules.Tensor
- class e3x.nn.modules.Tensor(max_degree=None, include_pseudotensors=True, cartesian_order=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function tensor_variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)[source]
Bases: Module
Tensor product of two equivariant feature representations.
Computes linear combinations (with learnable coefficients) of the direct sum representation of all possible tensor products of irreps in the input features. If the inputs are \(\mathbf{x} \in \mathbb{R}^{P_1\times (L_1+1)^2 \times F}\) and \(\mathbf{y} \in \mathbb{R}^{P_2\times (L_2+1)^2 \times F}\), the output is \(\mathbf{z} \in \mathbb{R}^{P_3\times (L_3+1)^2 \times F}\). Here, \(P_1\), \(P_2\), and \(P_3\) are either \(1\) or \(2\) (depending on whether the inputs/output contain pseudotensors or not) and \(L_1\), \(L_2\), and \(L_3\) are nonnegative integers (\(L_3\) = max_degree). The entries of \(\mathbf{z}\) are given by
\[\mathbf{z}^{(c_\gamma)} = \sum_{(a_\alpha,b_\beta)\in V} \mathbf{w}_{(a_\alpha,b_\beta,c_\gamma)} \circ \left( \mathbf{x}^{(a_\alpha)} \otimes^{(c_\gamma)}\mathbf{y}^{(b_\beta)} \right)\,,\]
where the sum runs over all \((a_\alpha,b_\beta)\) in the set of valid combinations \(V\) and \(\mathbf{w}_{(a_\alpha,b_\beta,c_\gamma)} \in \mathbb{R}^{1\times 1\times F}\) are learnable (feature-wise) weight parameters. Each combination \((a_\alpha,b_\beta,c_\gamma)\) has separate parameters and the element-wise product ‘\(\circ\)’ implies broadcasting over dimensions. The set \(V\) contains all \((a_\alpha,b_\beta)\) for which the condition
\[\lvert a - b \rvert \leq c \leq a + b \enspace \land \enspace \left( \left( \gamma = +1 \enspace \land \enspace \alpha = \beta \right) \enspace \lor \enspace \left( \gamma = -1 \enspace \land \enspace \alpha \neq \beta \right) \right)\]
is true.
If include_pseudotensors = False, coupling paths that lead to pseudotensors are not computed. This means that all entries that do not satisfy
\[\left(c \in \{2n+1 : n\in \mathbb{N}_0\} \enspace \land \enspace \gamma = -1 \right) \enspace \lor \enspace \left(c \in \{2n : n\in \mathbb{N}_0\} \enspace \land \enspace \gamma = +1 \right)\]
(either \(c\) is odd and the parity is odd, or \(c\) is even and the parity is even) are omitted.
See also here for more details on the notation and on the coupling of irreps in general.
The following diagram shows a visualization of the computation for the example \(P_1=P_2=P_3=2\), \(L_1=L_2=L_3=1\). For better clarity, weights are only labelled for 2 out of the 20 possible coupling paths.
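The validity conditions above can be checked by brute force. The following is a minimal sketch in plain Python, not the library's implementation: the function name `coupling_paths` and the tuple layout are hypothetical, and it assumes both inputs carry both parities (\(P_1=P_2=2\)). It enumerates every path that satisfies the triangle condition and the parity rule, and optionally drops the paths that would produce pseudotensors.

```python
from itertools import product


def coupling_paths(l1, l2, l3, include_pseudotensors=True):
    """Enumerate paths (a, alpha, b, beta, c, gamma) allowed by the conditions.

    Hypothetical helper for illustration only; assumes P_1 = P_2 = 2, i.e.
    every input degree is present with both parities +1 and -1.
    """
    paths = []
    for a, b, c in product(range(l1 + 1), range(l2 + 1), range(l3 + 1)):
        if not abs(a - b) <= c <= a + b:  # triangle condition on the degrees
            continue
        for alpha, beta in product((+1, -1), repeat=2):
            gamma = alpha * beta  # +1 if alpha == beta, -1 if alpha != beta
            # A proper tensor of degree c has parity (-1)**c; anything else
            # is a pseudotensor and is skipped when they are excluded.
            if not include_pseudotensors and gamma != (-1) ** c:
                continue
            paths.append((a, alpha, b, beta, c, gamma))
    return paths


all_paths = coupling_paths(1, 1, 1)
tensor_only = coupling_paths(1, 1, 1, include_pseudotensors=False)
print(len(all_paths))    # 20
print(len(tensor_only))  # 10
```

For \(L_1=L_2=L_3=1\) this reproduces the 20 coupling paths mentioned for the diagram: five degree triples \((a,b,c)\) pass the triangle condition, each with four parity combinations. Restricting the output to proper tensors keeps half of them under the stated condition.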
- max_degree
Maximum degree of the output. If not given, max_degree is chosen as the maximum of the maximum degrees of inputs1 and inputs2.
- include_pseudotensors
If False, all coupling paths that produce pseudotensors are omitted.
- cartesian_order
If True, Cartesian order is assumed.
- dtype
The dtype of the computation.
- param_dtype
The dtype passed to parameter initializers.
- precision
Numerical precision of the computation, see jax.lax.Precision for details.
- kernel_init
Initializer function for the weight matrix.
- __call__(inputs1, inputs2)[source]
Computes the tensor product of inputs1 and inputs2.
- Parameters:
inputs1 (Union[Float[Array, '... 1 (max_degree1+1)**2 num_features'], Float[Array, '... 2 (max_degree1+1)**2 num_features']]) – The first factor of the tensor product.
inputs2 (Union[Float[Array, '... 1 (max_degree2+1)**2 num_features'], Float[Array, '... 2 (max_degree2+1)**2 num_features']]) – The second factor of the tensor product.
- Return type:
- Returns:
The tensor product of inputs1 and inputs2, where each output irrep is a weighted linear combination with learnable weights of all valid coupling paths.
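The shape annotations and the default rule for max_degree can be illustrated with a small shape calculation. This is a sketch in plain Python that does not call the library; the helper names `max_degree_from_axis` and `output_irrep_axis` are hypothetical. It only relies on two facts stated above: the irrep axis has length `(max_degree+1)**2`, and an unspecified max_degree defaults to the larger of the two input degrees.

```python
import math


def max_degree_from_axis(size):
    """Recover L from an irrep axis of length (L + 1)**2 (hypothetical helper)."""
    root = math.isqrt(size)
    if root == 0 or root * root != size:
        raise ValueError(f"{size} is not of the form (L + 1)**2")
    return root - 1


def output_irrep_axis(shape1, shape2, max_degree=None):
    """Length of the output irrep axis for inputs of the given shapes.

    Shapes follow the documented layout (..., P, (L + 1)**2, num_features).
    """
    if shape1[-1] != shape2[-1]:
        raise ValueError("both inputs must have the same number of features")
    l1 = max_degree_from_axis(shape1[-2])
    l2 = max_degree_from_axis(shape2[-2])
    # Default rule: the maximum of the maximum degrees of the two inputs.
    l3 = max(l1, l2) if max_degree is None else max_degree
    return (l3 + 1) ** 2


# inputs1 has L_1 = 1 (axis length 4), inputs2 has L_2 = 2 (axis length 9).
print(output_irrep_axis((10, 2, 4, 8), (10, 1, 9, 8)))                # 9
print(output_irrep_axis((10, 2, 4, 8), (10, 1, 9, 8), max_degree=0))  # 1
```

The parity axis of the output is deliberately left out of this sketch, since it depends on include_pseudotensors and on the inputs.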