e3x.nn.modules.Dense

class e3x.nn.modules.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A linear transformation applied over the last dimension of the input.

The transformation can be written as

\[\begin{split}\mathbf{y}^{(\ell_p)} = \begin{cases} \mathbf{x}^{(\ell_p)}\mathbf{W}_{(\ell_p)} + \mathbf{b} & \ell_p = 0_+ \\ \mathbf{x}^{(\ell_p)}\mathbf{W}_{(\ell_p)} & \ell_p \neq 0_+ \end{cases}\end{split}\]

where \(\mathbf{x} \in \mathbb{R}^{P\times (L+1)^2 \times F_{\mathrm{in}}}\) and \(\mathbf{y} \in \mathbb{R}^{P\times (L+1)^2 \times F_{\mathrm{out}}}\) are the inputs and outputs, respectively. Here, \(P\) is either \(1\) or \(2\) (\(P=2\) if the inputs contain pseudotensors, \(P=1\) otherwise), \(L\) is the maximum degree of the input features, and \(F_{\mathrm{in}}\) and \(F_{\mathrm{out}}\) (given by features) are the numbers of input and output features. Every combination of degree \(\ell\) and parity \(p\) has its own weight matrix \(\mathbf{W}_{(\ell_p)} \in \mathbb{R}^{F_{\mathrm{in}} \times F_{\mathrm{out}}}\), shared by all \(2\ell+1\) components of that degree. The bias term \(\mathbf{b} \in \mathbb{R}^{1\times 1 \times F_{\mathrm{out}}}\) is applied only to the scalar channel (\(\ell_p = 0_+\)), and only when use_bias=True.
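The equation above can be illustrated with a minimal NumPy sketch (jax.numpy could be substituted). This is not e3x's implementation: the helper names component_degrees and dense_sketch are hypothetical, the weights are stacked into an assumed array of shape (P, L+1, F_in, F_out), and the scalar channel \(0_+\) is assumed to sit at parity index 0, component index 0.

```python
import numpy as np


def component_degrees(max_degree):
    # Degree l owns the 2l+1 consecutive indices l**2, ..., (l+1)**2 - 1
    # along the (L+1)**2 axis, so repeat each degree 2l+1 times.
    degrees = np.arange(max_degree + 1)
    return np.repeat(degrees, 2 * degrees + 1)


def dense_sketch(x, weights, bias=None):
    """Hypothetical helper: y^(l_p) = x^(l_p) W_(l_p), plus b for l_p = 0_+.

    x:       (..., P, (L+1)**2, F_in)
    weights: (P, L+1, F_in, F_out), one matrix per (parity, degree) pair
    bias:    (F_out,) or None
    """
    parity_dim, num_degrees, f_in, f_out = weights.shape
    max_degree = num_degrees - 1
    if x.shape[-3:] != (parity_dim, (max_degree + 1) ** 2, f_in):
        raise ValueError(f"unexpected input shape {x.shape}")
    # Share each degree's matrix across all 2l+1 components of that degree.
    w = weights[:, component_degrees(max_degree)]  # (P, (L+1)**2, F_in, F_out)
    y = np.einsum("...pmi,pmio->...pmo", x, w)
    if bias is not None:
        # Assumed layout: parity index 0, component 0 is the 0_+ scalar channel.
        y[..., 0, 0, :] += bias
    return y


# Usage: 5 samples, P=2, L=2 (so 9 components), F_in=4, F_out=8.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2, 9, 4))
W = rng.normal(size=(2, 3, 4, 8))
b = rng.normal(size=(8,))
y = dense_sketch(x, W, b)
print(y.shape)  # (5, 2, 9, 8)
```

Because one matrix is shared by all \(2\ell+1\) components of a degree, any linear map acting on the component axis of a degree (such as a rotation) commutes with the layer; this weight sharing is what keeps the transformation equivariant.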

features

The number of output features \(F_{\mathrm{out}}\).

use_bias

Whether to add a bias to the scalar channel of the output.

dtype

The dtype of the computation.

param_dtype

The dtype passed to parameter initializers.

precision

Numerical precision of the computation, see jax.lax.Precision for details.

kernel_init

Initializer function for the weight matrices \(\mathbf{W}_{(\ell_p)}\).

bias_init

Initializer function for the bias.
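Counting parameters follows directly from the formula: one \(F_{\mathrm{in}} \times F_{\mathrm{out}}\) matrix per (degree, parity) combination, i.e. \(P(L+1)\) matrices, plus a single bias vector of length \(F_{\mathrm{out}}\) when use_bias=True. The helper below (dense_param_count is a hypothetical name) is derived from that description, not read from e3x's source.

```python
def dense_param_count(parity_dim, max_degree, in_features, features, use_bias=True):
    """Hypothetical helper: parameter count implied by the Dense formula."""
    # One (F_in x F_out) weight matrix per (degree, parity) combination.
    num_matrices = parity_dim * (max_degree + 1)
    count = num_matrices * in_features * features
    if use_bias:
        # A single bias vector, added only to the 0_+ scalar channel.
        count += features
    return count


print(dense_param_count(2, 2, 16, 32))                  # 2*3*16*32 + 32 = 3104
print(dense_param_count(1, 2, 16, 32, use_bias=False))  # 1*3*16*32 = 1536
```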

__call__(inputs)

Applies a linear transformation to the inputs along the last dimension.

Parameters:

inputs (Union[Float[Array, '... 1 (max_degree+1)**2 in_features'], Float[Array, '... 2 (max_degree+1)**2 in_features']]) – The nd-array to be transformed.

Return type:

Union[Float[Array, '... 1 (max_degree+1)**2 out_features'], Float[Array, '... 2 (max_degree+1)**2 out_features']]

Returns:

The transformed input.
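The type annotations above imply a simple shape contract: the leading axes and the parity and degree axes pass through unchanged, and only the last axis changes from in_features to features. The hypothetical helper dense_output_shape below sketches that contract, inferring max_degree from the \((L+1)^2\) axis; it is an illustration, not part of the e3x API.

```python
import math


def dense_output_shape(input_shape, features):
    """Hypothetical helper: output shape and max_degree implied by the annotations."""
    if len(input_shape) < 3:
        raise ValueError("expected shape (..., P, (max_degree+1)**2, in_features)")
    *batch, parity_dim, num_components, _in_features = input_shape
    if parity_dim not in (1, 2):
        raise ValueError(f"parity axis must have size 1 or 2, got {parity_dim}")
    # The component axis must be a positive perfect square (max_degree + 1)**2.
    root = math.isqrt(num_components)
    if num_components < 1 or root * root != num_components:
        raise ValueError(f"{num_components} is not of the form (max_degree+1)**2")
    return (*batch, parity_dim, num_components, features), root - 1


print(dense_output_shape((10, 2, 9, 16), features=32))  # ((10, 2, 9, 32), 2)
```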