e3x.nn.modules.TensorDense
- class e3x.nn.modules.TensorDense(features=None, max_degree=None, use_bias=True, include_pseudotensors=True, cartesian_order=True, use_fused_tensor=False, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, dense_kernel_init=<function variance_scaling.<locals>.init>, dense_bias_init=<function zeros>, tensor_kernel_init=None, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Linear projection followed by a tensor product.
This module first applies a Dense layer to linearly combine the input features into two different projections, which are then coupled across the degree dimension with a tensor product. The transformation can be written as
\[\begin{split}\mathbf{a} &= \mathrm{dense}_1(\mathbf{x}) \\ \mathbf{b} &= \mathrm{dense}_2(\mathbf{x}) \\ \mathbf{y} &= \mathrm{tensor}(\mathbf{a}, \mathbf{b})\end{split}\]
where \(\mathbf{x} \in \mathbb{R}^{P_{\mathrm{in}}\times (L_{\mathrm{in}}+1)^2 \times F_{\mathrm{in}}}\) is the input and \(\mathbf{y} \in \mathbb{R}^{P_{\mathrm{out}}\times (L_{\mathrm{out}}+1)^2 \times F_{\mathrm{out}}}\) is the output. The \(\mathrm{tensor}\) transformation corresponds to either a Tensor (use_fused_tensor=False) or a FusedTensor (use_fused_tensor=True) layer. \(P_{\mathrm{in}}\) is either \(1\) or \(2\), depending on the input, and \(P_{\mathrm{out}}\) is either \(1\) (include_pseudotensors=False) or \(2\) (include_pseudotensors=True).
- features
The number of output features \(F_{\mathrm{out}}\). If not given, keeps the same number of features as the input \(F_{\mathrm{in}}\).
- max_degree
Maximum degree \(L_{\mathrm{out}}\) of the output. If not given, keeps the same max_degree \(L_{\mathrm{in}}\) as the input.
- include_pseudotensors
If False, all coupling paths that produce pseudotensors are omitted.
- cartesian_order
If True, Cartesian order is assumed.
- use_fused_tensor
If True, FusedTensor is used instead of Tensor for computing the tensor product.
- dtype
The dtype of the computation.
- param_dtype
The dtype passed to parameter initializers.
- precision
Numerical precision of the computation, see jax.lax.Precision for details.
- dense_kernel_init
Initializer function for the weight matrix of the Dense layer.
- dense_bias_init
Initializer function for the bias of the Dense layer.
- tensor_kernel_init
Initializer function for the weight matrix of the Tensor layer.
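The three-step transformation \(\mathbf{a} = \mathrm{dense}_1(\mathbf{x})\), \(\mathbf{b} = \mathrm{dense}_2(\mathbf{x})\), \(\mathbf{y} = \mathrm{tensor}(\mathbf{a}, \mathbf{b})\) can be illustrated with a minimal JAX sketch. It is not the e3x implementation and rests on several simplifying assumptions: maximum degree 1, a single input parity slot (\(P_{\mathrm{in}} = 1\)), components ordered as (scalar, x, y, z), one feature-mixing matrix per projection shared by all components, a bias on the scalar component only, and one learnable weight per coupling path and feature. The names init_params, dense, tensor and tensor_dense are hypothetical, lecun_normal is an arbitrary initializer choice, and the output parity axis follows the sketch's own convention (index 0 for proper tensors, index 1 for pseudotensors), which may differ from how e3x lays out parities.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch of "dense projections followed by a tensor product"
# (not the e3x implementation). Feature arrays have shape (P, 4, F):
# component 0 is the scalar (degree 0), components 1:4 are a vector (x, y, z).
# Inputs use P = 1; outputs use index 0 = proper tensors, 1 = pseudotensors.


def init_params(key, f_in, f_out):
    k1, k2, k3 = jax.random.split(key, 3)
    kernel_init = jax.nn.initializers.lecun_normal()  # arbitrary choice for the sketch
    return {
        "w1": kernel_init(k1, (f_in, f_out)),
        "w2": kernel_init(k2, (f_in, f_out)),
        "b1": jnp.zeros((f_out,)),
        "b2": jnp.zeros((f_out,)),
        # One learnable weight per coupling path (5 paths) and output feature.
        "paths": jax.random.normal(k3, (5, f_out)),
    }


def dense(x, w, b, precision=None):
    # Mix features only; the same matrix acts on every component, so
    # rotating the (x, y, z) components commutes with this projection.
    y = jnp.einsum("pcf,fg->pcg", x, w, precision=precision)
    # Bias only on the scalar component; a bias on vector components
    # would break rotational equivariance.
    return y.at[:, 0, :].add(b)


def tensor(a, b, paths, include_pseudotensors=True):
    a0, a1 = a[0, 0], a[0, 1:]  # scalar (F,), vector (3, F)
    b0, b1 = b[0, 0], b[0, 1:]
    # Degree-0 output: paths 0 x 0 -> 0 and 1 x 1 -> 0 (dot product).
    y0 = paths[0] * a0 * b0 + paths[1] * jnp.sum(a1 * b1, axis=0)
    # Degree-1 output: paths 0 x 1 -> 1 and 1 x 0 -> 1 (proper vectors).
    y1 = paths[2] * a0 * b1 + paths[3] * a1 * b0
    proper = jnp.concatenate([y0[None], y1], axis=0)  # (4, F)
    if not include_pseudotensors:
        return proper[None]  # (1, 4, F)
    # Path 1 x 1 -> 1 (cross product) yields a pseudovector; no
    # pseudoscalar is produced from these inputs, so that slot stays zero.
    cross = paths[4] * jnp.cross(a1, b1, axis=0)
    pseudo = jnp.concatenate([jnp.zeros_like(y0)[None], cross], axis=0)
    return jnp.stack([proper, pseudo], axis=0)  # (2, 4, F)


def tensor_dense(params, x, include_pseudotensors=True, precision=None):
    a = dense(x, params["w1"], params["b1"], precision)  # a = dense_1(x)
    b = dense(x, params["w2"], params["b2"], precision)  # b = dense_2(x)
    return tensor(a, b, params["paths"], include_pseudotensors)  # y = tensor(a, b)


params = init_params(jax.random.PRNGKey(0), f_in=3, f_out=8)
x = jax.random.normal(jax.random.PRNGKey(1), (1, 4, 3))
print(tensor_dense(params, x).shape)  # (2, 4, 8)
print(tensor_dense(params, x, include_pseudotensors=False).shape)  # (1, 4, 8)
```

In this sketch the cross-product path \(1 \otimes 1 \to 1\) is the only one that produces a pseudotensor: under a proper rotation it transforms like a vector, but under inversion \(\mathbf{r} \to -\mathbf{r}\) it keeps its sign while the proper vector output flips. Passing include_pseudotensors=False drops that path and the second parity slot, in line with the \(P_{\mathrm{out}} = 1\) case described above.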
- __call__(inputs)
Computes the tensor product of two linear projections of inputs.
- Parameters:
inputs (Union[Float[Array, '... 1 (in_max_degree+1)**2 in_features'], Float[Array, '... 2 (in_max_degree+1)**2 in_features']]) – The input features to be transformed.
- Return type:
- Returns:
The tensor product of two different linear projections of the input features.
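The shape bookkeeping described above (\(F_{\mathrm{out}}\) defaults to \(F_{\mathrm{in}}\), \(L_{\mathrm{out}}\) defaults to \(L_{\mathrm{in}}\), and \(P_{\mathrm{out}}\) is set by include_pseudotensors) can be summarized in a small helper. tensor_dense_output_shape is hypothetical and not part of e3x; it only restates the documented shapes and checks that the component axis has size \((L+1)^2\) and the parity axis has size 1 or 2.

```python
import math


def tensor_dense_output_shape(input_shape, features=None, max_degree=None,
                              include_pseudotensors=True):
    """Hypothetical helper restating the documented TensorDense output shape."""
    *batch, p_in, n_components, f_in = input_shape
    if p_in not in (1, 2):
        raise ValueError(f"parity axis must have size 1 or 2, got {p_in}")
    l_in = math.isqrt(n_components) - 1
    if n_components < 1 or (l_in + 1) ** 2 != n_components:
        raise ValueError(f"{n_components} is not of the form (L+1)**2")
    l_out = l_in if max_degree is None else max_degree  # L_out defaults to L_in
    f_out = f_in if features is None else features      # F_out defaults to F_in
    p_out = 2 if include_pseudotensors else 1           # P_out from include_pseudotensors
    return (*batch, p_out, (l_out + 1) ** 2, f_out)


print(tensor_dense_output_shape((10, 1, 4, 16)))  # (10, 2, 4, 16)
print(tensor_dense_output_shape((10, 2, 9, 16), features=32, max_degree=1,
                                include_pseudotensors=False))  # (10, 1, 4, 32)
```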