e3x.nn.modules.MessagePass
- class e3x.nn.modules.MessagePass(max_degree=None, use_basis_bias=False, include_pseudotensors=True, cartesian_order=True, use_fused_tensor=False, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, dense_kernel_init=<function variance_scaling.<locals>.init>, dense_bias_init=<function zeros>, tensor_kernel_init=None, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: _Conv
Equivariant message-passing step.
- max_degree
Maximum degree of the output. If not given, max_degree is set to the larger of the maximum degrees of inputs and basis.
- use_basis_bias
Whether to add a bias to the linear combination of basis functions.
- include_pseudotensors
If False, all coupling paths that produce pseudotensors are omitted.
- cartesian_order
If True, Cartesian order is assumed.
- use_fused_tensor
If True, FusedTensor is used instead of Tensor for computing the tensor product.
- dtype
The dtype of the computation.
- param_dtype
The dtype passed to parameter initializers.
- precision
Numerical precision of the computation, see jax.lax.Precision for details.
- dense_kernel_init
Initializer function for the weight matrix of the Dense layer.
- dense_bias_init
Initializer function for the bias of the Dense layer.
- tensor_kernel_init
Initializer function for the weight matrix of the Tensor layer.
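The default rule for max_degree can be made concrete using the shape convention from the __call__ annotations, where the second-to-last axis of a feature or basis array has size (max_degree+1)**2. The following is a minimal sketch under that assumption; infer_max_degree and resolve_output_max_degree are hypothetical helpers written for illustration, not part of e3x.

```python
import math

import numpy as np


def infer_max_degree(array: np.ndarray) -> int:
    """Recover the maximum degree L from axis -2, whose size is (L+1)**2.

    Hypothetical helper: it only mirrors the documented shape convention
    '... (max_degree+1)**2 num_features'.
    """
    size = array.shape[-2]
    max_degree = math.isqrt(size) - 1
    if (max_degree + 1) ** 2 != size:
        raise ValueError(f"axis -2 has size {size}, which is not a perfect square")
    return max_degree


def resolve_output_max_degree(inputs, basis, max_degree=None) -> int:
    """Hypothetical helper: explicit value wins, else max of input and basis degrees."""
    if max_degree is not None:
        return max_degree
    return max(infer_max_degree(inputs), infer_max_degree(basis))


# N=5 inputs of degree 1 (4 components), 16 features, parity axis of size 1.
inputs = np.zeros((5, 1, 4, 16))
# Dense-format basis: N=5, M=4 neighbors, degree 2 (9 components), 8 basis functions.
basis = np.zeros((5, 4, 1, 9, 8))

print(resolve_output_max_degree(inputs, basis))                # 2
print(resolve_output_max_degree(inputs, basis, max_degree=0))  # 0
```

The helper only reads shapes, so it works the same on arrays of any dtype or batch prefix.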
- __call__(inputs, basis, weights=None, *, adj_idx=None, where=None, dst_idx=None, src_idx=None, num_segments=None, indices_are_sorted=False)
Applies a single message-passing step.
This layer computes “messages” \(\mathbf{m}\) as
\[\begin{split}\mathbf{f} &= \mathrm{dense}(\mathbf{b})\\ \mathbf{m}[i] &= \sum_{j \in \mathcal{N}[i]} \mathrm{tensor}(\mathbf{x}[j],\mathbf{f}[ij])\end{split}\]
where \(\mathbf{x}[1],\dots,\mathbf{x}[N]\) are the \(N\) input features and \(\mathbf{b}\) are basis functions for all relevant interactions between pairs \(i\) and \(j\) from the \(N\) inputs. The relevant interactions for index \(i\) are given by the set of “neighborhood indices” \(\mathcal{N}[i]\), specified either by a dense (adj_idx) or a sparse (dst_idx and src_idx) index list. The \(\mathrm{tensor}\) transformation corresponds to either a Tensor (use_fused_tensor=False) or a FusedTensor (use_fused_tensor=True) layer. If the weights argument is not None, \(\mathbf{m}\) is instead computed as
\[\mathbf{m}[i] = \sum_{j \in \mathcal{N}[i]} w[ij] \cdot \mathrm{tensor}(\mathbf{x}[j],\mathbf{f}[ij])\]
where \(w[ij]\) is the entry of weights corresponding to the interaction between pairs \(i\) and \(j\).
- Parameters:
  - inputs (Union[Float[Array, '... N 1 (in_max_degree+1)**2 num_features'], Float[Array, '... N 2 (in_max_degree+1)**2 num_features']]) – A set of \(N\) feature representations.
  - basis (Union[Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'], Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis']]) – Basis functions for all relevant interactions between pairs \(i\) and \(j\) from the \(N\) inputs (in either dense or sparse indexed format).
  - weights (Optional[Float[Array, '_*broadcastable_to_gathered_inputs']], default: None) – Optional weights for the interactions between pairs \(i\) and \(j\).
  - adj_idx (Optional[Integer[Array, '... N M']], default: None) – Adjacency indices (dense index list), or None.
  - where (Optional[Bool[Array, '... N M']], default: None) – Mask specifying which values to sum over (only for dense index lists). If None, the where mask is determined automatically from the inputs.
  - dst_idx (Optional[Integer[Array, '... P']], default: None) – Destination indices (sparse index list), or None.
  - src_idx (Optional[Integer[Array, '... P']], default: None) – Source indices (sparse index list), or None.
  - num_segments (Optional[int], default: None) – Number of segments after summation (only for sparse index lists). If None, num_segments is determined automatically from the inputs.
  - indices_are_sorted (bool, default: False) – If True, dst_idx is assumed to be sorted, which may increase performance (only used for sparse index lists).
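The message formula above, and the equivalence of the dense (adj_idx, where) and sparse (dst_idx, src_idx) index formats, can be checked with a small NumPy reference. This is a sketch, not the e3x implementation: it drops the parity and degree axes, models dense(b) as a plain matrix product with a hypothetical kernel, and replaces the equivariant Tensor/FusedTensor layer with a hypothetical toy_tensor (an elementwise product that is not equivariant) so that only the indexing and summation are illustrated.

```python
import numpy as np


def toy_tensor(x_j, f_ij):
    # Hypothetical stand-in for Tensor/FusedTensor: a plain elementwise product.
    # It is NOT equivariant; it only illustrates how pairs are combined.
    return x_j * f_ij


def message_pass_sparse(x, b, kernel, dst_idx, src_idx, num_segments, weights=None):
    f = b @ kernel                        # f = dense(b), shape (P, F)
    msg = toy_tensor(x[src_idx], f)       # tensor(x[j], f[ij]) for each pair
    if weights is not None:
        msg = weights[:, None] * msg      # w[ij] * tensor(...)
    out = np.zeros((num_segments, x.shape[-1]), dtype=msg.dtype)
    np.add.at(out, dst_idx, msg)          # segment sum over j in N[i]
    return out


def message_pass_dense(x, b, kernel, adj_idx, where, weights=None):
    f = b @ kernel                        # shape (N, M, F)
    msg = toy_tensor(x[adj_idx], f)       # gathered neighbors, shape (N, M, F)
    if weights is not None:
        msg = weights[..., None] * msg
    msg = np.where(where[..., None], msg, 0.0)  # zero out padded neighbor slots
    return msg.sum(axis=1)                # sum over the M neighbor slots


rng = np.random.default_rng(0)
N, M, K, F = 4, 3, 5, 2                   # nodes, neighbor slots, basis size, features
x = rng.normal(size=(N, F))
kernel = rng.normal(size=(K, F))          # hypothetical dense-layer weights
# Padded slots hold a valid index (0) here and are masked out by `where`.
adj_idx = np.array([[1, 2, 0], [0, 2, 3], [3, 0, 0], [0, 0, 0]])
where = np.array([[True, True, False], [True, True, True],
                  [True, False, False], [False, False, False]])
b_dense = rng.normal(size=(N, M, K))
w_dense = rng.normal(size=(N, M))

# Convert the dense index list to a sparse one; np.nonzero yields sorted dst_idx.
dst_idx, slot = np.nonzero(where)
src_idx = adj_idx[dst_idx, slot]

m_dense = message_pass_dense(x, b_dense, kernel, adj_idx, where, w_dense)
m_sparse = message_pass_sparse(x, b_dense[dst_idx, slot], kernel, dst_idx, src_idx,
                               num_segments=N, weights=w_dense[dst_idx, slot])
print(np.allclose(m_dense, m_sparse))     # True
```

Node 3 has no neighbors, so its message is zero in both formats. Because np.nonzero returns indices in row-major order, the resulting dst_idx is sorted, which is the situation indices_are_sorted=True describes.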
- Returns:
The result of the message passing step.
- Raises:
ValueError – If weights is not None and cannot be broadcast to the gathered inputs.
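The broadcasting condition behind this ValueError can be illustrated with NumPy's broadcasting rules. This sketch assumes that, in the sparse case, the gathered inputs have shape (P, parity, (max_degree+1)**2, num_features), i.e. inputs indexed by src_idx; check_weights is a hypothetical helper, not the e3x check.

```python
import numpy as np


def check_weights(weights, gathered_shape):
    """Hypothetical check: return the broadcast shape or raise ValueError."""
    try:
        return np.broadcast_shapes(np.shape(weights), gathered_shape)
    except ValueError as err:
        raise ValueError(
            f"weights with shape {np.shape(weights)} cannot be broadcast to "
            f"gathered inputs with shape {gathered_shape}"
        ) from err


# Assumed sparse case: P=6 pairs, parity axis 1, degree 1 (4 components), 16 features.
gathered_shape = (6, 1, 4, 16)

print(check_weights(np.ones((6, 1, 1, 1)), gathered_shape))  # (6, 1, 4, 16)
try:
    # Shape (6,) is right-aligned with the feature axis (16), so it fails.
    check_weights(np.ones(6), gathered_shape)
except ValueError as err:
    print(err)
```

Under these rules, per-pair weights need trailing singleton axes (here (P, 1, 1, 1)) so that they align with the pair axis rather than the feature axis.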