e3x.nn.modules.MessagePass

class e3x.nn.modules.MessagePass(max_degree=None, use_basis_bias=False, include_pseudotensors=True, cartesian_order=True, use_fused_tensor=False, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, dense_kernel_init=<function variance_scaling.<locals>.init>, dense_bias_init=<function zeros>, tensor_kernel_init=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: _Conv

Equivariant message-passing step.

max_degree

Maximum degree of the output. If not given, max_degree is chosen as the larger of the max_degree of the inputs and the max_degree of the basis.
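The default rule can be illustrated with a small sketch. It assumes only what the return type below states, namely that features up to degree L occupy an axis of size (L+1)**2, and recovers each degree from that axis size. The helper names are hypothetical and are not part of e3x.

```python
import math

def degree_from_components(num_components: int) -> int:
    """Hypothetical helper: recover max_degree from an axis of size (max_degree+1)**2."""
    d = math.isqrt(num_components) - 1
    if d < 0 or (d + 1) ** 2 != num_components:
        raise ValueError(f"{num_components} is not of the form (max_degree+1)**2")
    return d

def default_out_max_degree(input_components: int, basis_components: int) -> int:
    """Sketch of the documented default: the larger of the input and basis max_degree."""
    return max(degree_from_components(input_components),
               degree_from_components(basis_components))

# Inputs up to degree 1 (4 components), basis up to degree 2 (9 components).
print(default_out_max_degree(4, 9))  # 2
```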

use_basis_bias

Whether to add a bias to the linear combination of basis functions.

include_pseudotensors

If False, all coupling paths that produce pseudotensors are omitted.

cartesian_order

If True, Cartesian order is assumed.

use_fused_tensor

If True, FusedTensor is used instead of Tensor for computing the tensor product.

dtype

The dtype of the computation.

param_dtype

The dtype passed to parameter initializers.

precision

Numerical precision of the computation; see jax.lax.Precision for details.

dense_kernel_init

Initializer function for the weight matrix of the Dense layer.

dense_bias_init

Initializer function for the bias of the Dense layer.

tensor_kernel_init

Initializer function for the weight matrix of the Tensor layer.

__call__(inputs, basis, weights=None, *, adj_idx=None, where=None, dst_idx=None, src_idx=None, num_segments=None, indices_are_sorted=False)

Applies a single message-passing step.

This layer computes “messages” \(\mathbf{m}\) as

\[\begin{split}\mathbf{f} &= \mathrm{dense}(\mathbf{b})\\ \mathbf{m}[i] &= \sum_{j \in \mathcal{N}[i]} \mathrm{tensor}(\mathbf{x}[j],\mathbf{f}[ij])\end{split}\]

where \(\mathbf{x}[1],\dots,\mathbf{x}[N]\) are the \(N\) input features, and \(\mathbf{b}\) are basis functions for all relevant interactions between pairs \(i\) and \(j\) from the \(N\) inputs. The relevant interactions for index \(i\) are given by the set of “neighborhood indices” \(\mathcal{N}[i]\) specified by either a dense (adj_idx) or sparse (dst_idx and src_idx) index list. The \(\mathrm{tensor}\) transformation corresponds to either a Tensor (use_fused_tensor=False) or a FusedTensor (use_fused_tensor=True) layer. If the weights argument is not None, \(\mathbf{m}\) is computed as
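The aggregation over \(\mathcal{N}[i]\) with a sparse index list can be sketched in NumPy as follows. This is only an illustration of the summation structure under simplifying assumptions, not the e3x implementation. The features are plain scalars (no degree or parity axes), \(\mathrm{dense}\) is a bias-free matrix product with a hypothetical `kernel`, and an elementwise product stands in for the equivariant \(\mathrm{tensor}\) layer. Pair `p` connects `src_idx[p]` (\(j\)) to `dst_idx[p]` (\(i\)), and each pair has its own basis row.

```python
import numpy as np

def message_pass(x, b, kernel, dst_idx, src_idx, num_segments):
    """Sketch: m[i] = sum over pairs p with dst_idx[p] == i of tensor(x[src_idx[p]], f[p])."""
    f = b @ kernel                      # f = dense(b): (num_pairs, B) @ (B, F) -> (num_pairs, F)
    products = x[src_idx] * f           # stand-in for tensor(x[j], f[ij]); scalar features only
    m = np.zeros((num_segments, x.shape[-1]), dtype=products.dtype)
    np.add.at(m, dst_idx, products)     # scatter-sum; repeated destination indices accumulate
    return m

rng = np.random.default_rng(0)
num_nodes, num_features, num_basis = 4, 3, 5
x = rng.normal(size=(num_nodes, num_features))
dst_idx = np.array([0, 0, 1, 1, 2])     # i for each pair
src_idx = np.array([1, 2, 0, 3, 0])     # j for each pair
b = rng.normal(size=(dst_idx.shape[0], num_basis))         # one basis row per pair ij
kernel = rng.normal(size=(num_basis, num_features))        # hypothetical dense weights

m = message_pass(x, b, kernel, dst_idx, src_idx, num_segments=num_nodes)
print(m.shape)  # (4, 3); node 3 receives no messages, so m[3] is all zeros
```

`np.add.at` is used instead of `m[dst_idx] += products` because the latter would silently drop contributions when a destination index repeats.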

\[\mathbf{m}[i] = \sum_{j \in \mathcal{N}[i]} w[ij] \cdot \mathrm{tensor}(\mathbf{x}[j],\mathbf{f}[ij])\]

instead, where \(w[ij]\) is the entry of weights corresponding to the interaction between pairs \(i\) and \(j\).
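A minimal NumPy sketch of the weighted sum, assuming the per-pair products \(\mathrm{tensor}(\mathbf{x}[j],\mathbf{f}[ij])\) are already available as a `(num_pairs, F)` array. The handling of weight shapes is a hypothetical convention chosen for illustration: trailing singleton axes are appended so that a per-pair weight vector lines up with the feature axis, and a ValueError is raised when the weights cannot be broadcast, mirroring the error documented below. e3x's actual shape rules may differ.

```python
import numpy as np

def weighted_sum(products, weights, dst_idx, num_segments):
    """Sketch: m[i] = sum over pairs p with dst_idx[p] == i of w[p] * products[p]."""
    w = np.asarray(weights)
    # Assumption: append trailing singleton axes, e.g. (num_pairs,) -> (num_pairs, 1).
    if w.ndim < products.ndim:
        w = w.reshape(w.shape + (1,) * (products.ndim - w.ndim))
    try:
        shape = np.broadcast_shapes(w.shape, products.shape)
    except ValueError:
        shape = None
    if shape != products.shape:
        raise ValueError(
            f"weights of shape {np.shape(weights)} cannot be broadcast to {products.shape}")
    m = np.zeros((num_segments,) + products.shape[1:], dtype=np.result_type(w, products))
    np.add.at(m, dst_idx, w * products)  # scatter-sum the weighted products per destination
    return m

products = np.ones((3, 2))
dst_idx = np.array([0, 0, 1])
m = weighted_sum(products, np.array([0.5, 2.0, 3.0]), dst_idx, num_segments=2)
# node 0 receives 0.5 + 2.0 = 2.5 per feature, node 1 receives 3.0 per feature
```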

Return type:

Union[Float[Array, '... N 1 (out_max_degree+1)**2 num_features'], Float[Array, '... N 2 (out_max_degree+1)**2 num_features']]

Returns:

The result of the message passing step.

Raises:

ValueError – If weights is not None and cannot be broadcast to the gathered inputs.