e3x.nn.modules.TensorDense

class e3x.nn.modules.TensorDense(features=None, max_degree=None, use_bias=True, include_pseudotensors=True, cartesian_order=True, use_fused_tensor=False, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, dense_kernel_init=<function variance_scaling.<locals>.init>, dense_bias_init=<function zeros>, tensor_kernel_init=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Linear projection followed by a tensor product.

This module first applies a Dense layer to compute two different linear projections of the input features, which are then coupled across the degree dimension with a tensor product. The transformation can be written as

\[\begin{split}\mathbf{a} &= \mathrm{dense}_1(\mathbf{x}) \\ \mathbf{b} &= \mathrm{dense}_2(\mathbf{x}) \\ \mathbf{y} &= \mathrm{tensor}(\mathbf{a}, \mathbf{b})\end{split}\]

where \(\mathbf{x} \in \mathbb{R}^{P_{\mathrm{in}}\times (L_{\mathrm{in}}+1)^2 \times F_{\mathrm{in}}}\) is the input and \(\mathbf{y} \in \mathbb{R}^{P_{\mathrm{out}}\times (L_{\mathrm{out}}+1)^2 \times F_{\mathrm{out}}}\) is the output. \(P_{\mathrm{in}}\) is either \(1\) or \(2\), depending on whether the input also contains pseudotensors. The \(\mathrm{tensor}\) transformation corresponds to either a Tensor (use_fused_tensor=False) or a FusedTensor (use_fused_tensor=True) layer. \(P_{\mathrm{out}}\) is either \(1\) (include_pseudotensors=False) or \(2\) (include_pseudotensors=True).
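The two projections \(\mathbf{a}\) and \(\mathbf{b}\) in the equations above can be sketched in plain JAX. This is a minimal sketch, not the e3x implementation: it assumes that the Dense layer mixes only the feature axis, using a single weight matrix of shape \((F_{\mathrm{in}}, 2F_{\mathrm{out}})\) that is shared across all parity and degree components, and that its output is split in half to give \(\mathbf{a}\) and \(\mathbf{b}\). The bias is omitted, and the function name two_projections is hypothetical.

```python
import jax
import jax.numpy as jnp


def two_projections(x, kernel):
  """Hypothetical sketch of the Dense step of TensorDense (bias omitted).

  x:      (..., P_in, (L_in+1)**2, F_in) input features.
  kernel: (F_in, 2 * F_out) weights, assumed shared across parity and degree.
  Returns a and b, each of shape (..., P_in, (L_in+1)**2, F_out).
  """
  # Contract only the last (feature) axis; parity and degree axes are left
  # untouched, so all components of an irrep see the same weights.
  y = jnp.einsum('...f,fg->...g', x, kernel)
  # Assumption: the 2 * F_out outputs are split into the two projections.
  a, b = jnp.split(y, 2, axis=-1)
  return a, b


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (5, 1, 4, 8))      # batch 5, P_in=1, L_in=1, F_in=8
kernel = jax.random.normal(key_w, (8, 2 * 16))  # F_out=16
a, b = two_projections(x, kernel)
print(a.shape, b.shape)  # (5, 1, 4, 16) (5, 1, 4, 16)
```

Because only the feature axis is contracted, rotating the degree components of \(\mathbf{x}\) and then projecting gives the same result as projecting first, which is why a linear layer of this form keeps the features equivariant.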

features

The number of output features \(F_{\mathrm{out}}\). If not given, keeps the same number of features as the input \(F_{\mathrm{in}}\).

max_degree

Maximum degree \(L_{\mathrm{out}}\) of the output. If not given, keeps the same max_degree \(L_{\mathrm{in}}\) as the input.

use_bias

Whether to use a bias for the Dense layer.

include_pseudotensors

If False, all coupling paths that produce pseudotensors are omitted.
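To make coupling paths and pseudotensors concrete, the sketch below uses ordinary Cartesian 3-vectors instead of e3x's irreps. Coupling two degree-1 inputs yields a degree-0 part (the dot product), a degree-1 part (the cross product) and a degree-2 part (the symmetric traceless matrix). The cross product does not change sign under inversion, so it is a pseudovector; it is an example of the kind of path that include_pseudotensors=False omits. The helper couple_vectors is hypothetical and ignores the normalization and learned path weights that a real tensor product layer applies.

```python
import numpy as np


def couple_vectors(a, b):
  """Split the outer product of two 3-vectors into degree 0, 1 and 2 parts.

  Hypothetical Cartesian illustration, not an e3x function.
  """
  outer = np.outer(a, b)
  # Degree 0: the trace of the outer product, i.e. the dot product.
  scalar = np.trace(outer)
  # Degree 1: the antisymmetric part, read out as the cross product.
  antisym = 0.5 * (outer - outer.T)
  cross = np.array([antisym[1, 2] - antisym[2, 1],
                    antisym[2, 0] - antisym[0, 2],
                    antisym[0, 1] - antisym[1, 0]])
  # Degree 2: the symmetric traceless part (5 independent components).
  sym_traceless = 0.5 * (outer + outer.T) - scalar / 3.0 * np.eye(3)
  return scalar, cross, sym_traceless


rng = np.random.default_rng(0)
a, b = rng.normal(size=3), rng.normal(size=3)
s, c, t = couple_vectors(a, b)

# Under inversion (a, b) -> (-a, -b) a true vector would flip sign, but the
# cross product does not: it is a degree-1 pseudotensor.
_, c_inv, _ = couple_vectors(-a, -b)
print(np.allclose(c, c_inv))  # True
```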

cartesian_order

If True, the components of the input and output irreps are assumed to be in Cartesian order.

use_fused_tensor

If True, FusedTensor is used instead of Tensor for computing the tensor product.

dtype

The dtype of the computation.

param_dtype

The dtype passed to parameter initializers.

precision

Numerical precision of the computation, see jax.lax.Precision for details.
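These three attributes follow the usual Flax pattern, and the sketch below assumes TensorDense uses them the same way internally: parameters are created in param_dtype, cast to dtype for the computation, and precision is forwarded to the underlying JAX contraction. The shapes are illustrative only.

```python
import jax
import jax.numpy as jnp

param_dtype = jnp.float32  # dtype in which parameters are created and stored
dtype = jnp.bfloat16       # dtype in which the computation runs

kernel = jax.random.normal(jax.random.PRNGKey(0), (8, 16), param_dtype)
x = jnp.ones((1, 4, 8), jnp.float32)

# Cast inputs and parameters to the computation dtype, then contract the
# feature axis with an explicitly requested precision.
y = jnp.einsum(
    '...f,fg->...g',
    x.astype(dtype),
    kernel.astype(dtype),
    precision=jax.lax.Precision.HIGHEST,
)
print(kernel.dtype, y.dtype)  # float32 bfloat16
```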

dense_kernel_init

Initializer function for the weight matrix of the Dense layer.

dense_bias_init

Initializer function for the bias of the Dense layer.
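Both initializers are ordinary JAX initializer functions with the call signature (key, shape, dtype). The default kernel initializer is a variance_scaling initializer whose exact settings are not visible in the signature above, so the configuration below is only an example of a custom choice that could be passed as dense_kernel_init; the shapes are illustrative, not the actual shapes of the TensorDense parameters.

```python
import jax
import jax.numpy as jnp

# Example configuration (an assumption, not necessarily the default).
kernel_init = jax.nn.initializers.variance_scaling(
    scale=1.0, mode='fan_in', distribution='truncated_normal')
bias_init = jax.nn.initializers.zeros

key = jax.random.PRNGKey(0)
kernel = kernel_init(key, (8, 32), jnp.float32)
bias = bias_init(key, (32,), jnp.float32)
print(kernel.shape, bias.shape)  # (8, 32) (32,)
```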

tensor_kernel_init

Initializer function for the weight matrix of the Tensor (or FusedTensor) layer.

__call__(inputs)

Computes the tensor product of two linear projections of inputs.

Parameters:

inputs (Union[Float[Array, '... 1 (in_max_degree+1)**2 in_features'], Float[Array, '... 2 (in_max_degree+1)**2 in_features']]) – The input features to be transformed.

Return type:

Union[Float[Array, '... 1 (out_max_degree+1)**2 out_features'], Float[Array, '... 2 (out_max_degree+1)**2 out_features']]

Returns:

The tensor product of two different linear projections of the input features.
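The shape rules documented on this page can be collected into a small helper that predicts the output shape for a given input shape. This is a hypothetical sketch that encodes only the stated conventions (the parity axis has size 1 or 2, the degree axis has \((L+1)^2\) entries, and features, max_degree and include_pseudotensors set \(F_{\mathrm{out}}\), \(L_{\mathrm{out}}\) and \(P_{\mathrm{out}}\)); it does not call e3x.

```python
import math


def tensor_dense_output_shape(input_shape, features=None, max_degree=None,
                              include_pseudotensors=True):
  """Hypothetical helper: predict the TensorDense output shape."""
  *batch, p_in, num_components, f_in = input_shape
  if p_in not in (1, 2):
    raise ValueError(f'parity axis must have size 1 or 2, got {p_in}')
  # The degree axis holds (L_in + 1)**2 entries.
  root = math.isqrt(num_components)
  if root * root != num_components:
    raise ValueError(f'{num_components} is not of the form (L+1)**2')
  l_in = root - 1
  # Defaults: keep the input's maximum degree and number of features.
  l_out = l_in if max_degree is None else max_degree
  f_out = f_in if features is None else features
  p_out = 2 if include_pseudotensors else 1
  return (*batch, p_out, (l_out + 1) ** 2, f_out)


print(tensor_dense_output_shape((5, 1, 4, 8)))  # (5, 2, 4, 8)
print(tensor_dense_output_shape((5, 1, 4, 8), features=16, max_degree=2,
                                include_pseudotensors=False))  # (5, 1, 9, 16)
```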