e3x.nn.modules.SelfAttention

class e3x.nn.modules.SelfAttention(max_degree=None, use_basis_bias=False, include_pseudotensors=True, cartesian_order=True, use_fused_tensor=False, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, dense_kernel_init=<function variance_scaling.<locals>.init>, dense_bias_init=<function zeros>, tensor_kernel_init=None, num_heads=1, qkv_features=None, out_features=None, use_relative_positional_encoding_qk=True, use_relative_positional_encoding_v=True, query_kernel_init=<function variance_scaling.<locals>.init>, query_bias_init=<function zeros>, query_use_bias=False, key_kernel_init=<function variance_scaling.<locals>.init>, key_bias_init=<function zeros>, key_use_bias=False, value_kernel_init=<function variance_scaling.<locals>.init>, value_bias_init=<function zeros>, value_use_bias=False, output_kernel_init=<function variance_scaling.<locals>.init>, output_bias_init=<function zeros>, output_use_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: MultiHeadAttention

Equivariant self-attention.

max_degree

Maximum degree of the output. If not given, it is set to the larger of the maximum degrees of inputs and basis.
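A minimal sketch of this default rule, assuming inputs and basis both use the (..., P, (max_degree+1)**2, features) layout shown in the return type of __call__, so that the degree can be read off the second-to-last axis. The helpers degree_from_shape and resolve_max_degree are hypothetical and not part of e3x.

```python
import math


def degree_from_shape(shape):
  """Read max_degree from a (..., P, (max_degree+1)**2, features) shape."""
  num_components = shape[-2]
  max_degree = math.isqrt(num_components) - 1
  if (max_degree + 1) ** 2 != num_components:
    raise ValueError(f"axis -2 has size {num_components}, not a perfect square")
  return max_degree


def resolve_max_degree(max_degree, inputs_shape, basis_shape=None):
  """Return max_degree if given, else the larger degree of inputs and basis."""
  if max_degree is not None:
    return max_degree
  degrees = [degree_from_shape(inputs_shape)]
  if basis_shape is not None:  # basis is optional in __call__
    degrees.append(degree_from_shape(basis_shape))
  return max(degrees)


# Inputs with max_degree 1 (4 components), basis with max_degree 2 (9 components).
print(resolve_max_degree(None, (5, 2, 4, 16), (20, 1, 9, 8)))  # 2
```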

use_basis_bias

Whether to add a bias to the linear combination of basis functions.

include_pseudotensors

If False, all coupling paths that produce pseudotensors are omitted.

cartesian_order

If True, Cartesian order is assumed.

use_fused_tensor

If True, FusedTensor is used instead of Tensor for computing the tensor product.

dtype

The dtype of the computation.

param_dtype

The dtype passed to parameter initializers.

precision

Numerical precision of the computation; see jax.lax.Precision for details.
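The precision argument takes jax.lax.Precision values. To illustrate what these values control, the sketch below passes them to jnp.dot directly; that SelfAttention forwards the setting to its internal matrix products in the same way is an assumption, not something this page spells out.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((4, 8), dtype=jnp.float32)
b = jnp.ones((8, 3), dtype=jnp.float32)

# DEFAULT lets the backend pick its fastest mode (on TPUs this may use
# reduced-precision passes for float32); HIGHEST requests full precision.
out_default = jnp.dot(a, b, precision=jax.lax.Precision.DEFAULT)
out_highest = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```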

dense_kernel_init

Initializer function for the weight matrix of the Dense layer.

dense_bias_init

Initializer function for the bias of the Dense layer.
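Any JAX initializer with the signature init(key, shape, dtype) can be used for these arguments. The rendered default is a variance_scaling initializer whose exact scale, mode and distribution are not shown in the signature; the settings below (equivalent to LeCun normal) are an assumption chosen for illustration.

```python
import jax
import jax.numpy as jnp

# One possible variance_scaling configuration: variance = 1 / fan_in,
# sampled from a truncated normal distribution.
kernel_init = jax.nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
bias_init = jax.nn.initializers.zeros

key = jax.random.PRNGKey(0)
kernel = kernel_init(key, (16, 32), jnp.float32)  # fan_in = 16, so std is about 0.25
bias = bias_init(key, (32,), jnp.float32)         # all zeros; the key is ignored
```

Functions like these could then be passed as SelfAttention(dense_kernel_init=kernel_init, dense_bias_init=bias_init).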

tensor_kernel_init

Initializer function for the weight matrix of the Tensor layer.

num_heads

Number of attention heads.

qkv_features

Number of features used for queries, keys and values. If this is None, the same number of features as in the inputs (called inputs_q in MultiHeadAttention) is used.

out_features

Number of features for the output. If this is None, the same number of features as in the inputs (called inputs_q in MultiHeadAttention) is used.
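How num_heads and qkv_features interact is not specified on this page. Assuming the convention of flax.linen.MultiHeadAttention, where the feature dimension is divided evenly among the heads, projected queries, keys and values would be split as sketched below; split_heads is a hypothetical helper.

```python
import jax.numpy as jnp


def split_heads(x, num_heads):
  """Split the trailing feature axis of x into (num_heads, head_dim)."""
  features = x.shape[-1]
  if features % num_heads != 0:
    raise ValueError(
        f"qkv_features={features} is not divisible by num_heads={num_heads}")
  return x.reshape(*x.shape[:-1], num_heads, features // num_heads)


# N=5 nodes, P=1, max_degree=1 (4 components), qkv_features=32.
q = jnp.zeros((5, 1, 4, 32))
heads = split_heads(q, num_heads=4)  # shape (5, 1, 4, 4, 8)
```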

use_relative_positional_encoding_qk

If this is True, relative positional encodings are used for computing the dot product between queries and keys.

use_relative_positional_encoding_v

If this is True, a relative positional encoding (with respect to the queries) is used for computing the values.

query_kernel_init

Initializer function for the weight matrix of the Dense layer for computing queries.

query_bias_init

Initializer function for the bias terms of the Dense layer for computing queries.

query_use_bias

Whether to use bias terms in the Dense layer for computing queries.

key_kernel_init

Initializer function for the weight matrix of the Dense layer for computing keys.

key_bias_init

Initializer function for the bias terms of the Dense layer for computing keys.

key_use_bias

Whether to use bias terms in the Dense layer for computing keys.

value_kernel_init

Initializer function for the weight matrix of the Dense layer for computing values.

value_bias_init

Initializer function for the bias terms of the Dense layer for computing values.

value_use_bias

Whether to use bias terms in the Dense layer for computing values.

output_kernel_init

Initializer function for the weight matrix of the Dense layer for computing outputs.

output_bias_init

Initializer function for the bias terms of the Dense layer for computing outputs.

output_use_bias

Whether to use bias terms in the Dense layer for computing outputs.
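The query, key, value and output Dense layers have to preserve equivariance. A common way to achieve this, sketched here as an assumption rather than e3x's exact code, is to mix only the feature axis and to add a bias only to the degree-0 component, since a constant offset on components with l > 0 would not rotate with the input. The sketch ignores the parity axis and assumes Cartesian (x, y, z) order for the l = 1 block; equivariant_dense and rotate_degree1 are hypothetical helpers.

```python
import jax.numpy as jnp


def equivariant_dense(x, kernel, bias=None):
  """Feature-mixing layer for x of shape (..., (L+1)**2, F_in).

  The kernel acts only on the feature axis, so every degree-l block is
  transformed identically and commutes with rotations, which act on the
  component axis. The optional bias touches only the l=0 component.
  """
  y = jnp.einsum("...cf,fg->...cg", x, kernel)
  if bias is not None:
    y = y.at[..., 0, :].add(bias)
  return y


def rotate_degree1(x, rot):
  """Rotate the l=1 block (components 1..3) of an array with max_degree 1."""
  rotated = jnp.einsum("ij,...jf->...if", rot, x[..., 1:4, :])
  return x.at[..., 1:4, :].set(rotated)
```

Rotating the input and then applying the layer gives the same result as applying the layer and then rotating the output, with or without the scalar-only bias.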

__call__(inputs, basis=None, cutoff_value=None, *, adj_idx=None, where=None, dst_idx=None, src_idx=None, num_segments=None, indices_are_sorted=False)

Applies self-attention.

In principle, self-attention is very similar to message passing, except that each summand carries an additional weight factor and these weights sum to 1. In ordinary message passing, by contrast, every summand has an implicit weight of 1.
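Written as formulas, message passing computes y_i = sum_j m_ij over the incoming messages m_ij of node i, whereas self-attention computes y_i = sum_j a_ij m_ij with a_ij = softmax_j(s_ij), so that sum_j a_ij = 1 for every i. The sketch below contrasts the two aggregations using jax.ops segment operations. It takes the scores s_ij as given (in SelfAttention they come from query-key dot products), uses plain feature vectors instead of equivariant features, and treats dst_idx as the index of the receiving node, an assumption borrowed from the __call__ signature.

```python
import jax
import jax.numpy as jnp


def message_passing(messages, dst_idx, num_nodes):
  """Plain message passing: y_i = sum_j m_ij (every summand has weight 1)."""
  return jax.ops.segment_sum(messages, dst_idx, num_segments=num_nodes)


def attention_aggregate(messages, scores, dst_idx, num_nodes):
  """Attention-weighted aggregation: y_i = sum_j a_ij m_ij with sum_j a_ij = 1.

  messages: (num_edges, features), scores: (num_edges,), dst_idx: (num_edges,)
  """
  # Softmax over all edges that share the same destination node; subtracting
  # the per-node maximum keeps exp() from overflowing.
  max_per_node = jax.ops.segment_max(scores, dst_idx, num_segments=num_nodes)
  exp_scores = jnp.exp(scores - max_per_node[dst_idx])
  norm = jax.ops.segment_sum(exp_scores, dst_idx, num_segments=num_nodes)
  weights = exp_scores / norm[dst_idx]
  out = jax.ops.segment_sum(
      weights[:, None] * messages, dst_idx, num_segments=num_nodes)
  return out, weights
```

With all scores equal, attention reduces to averaging the incoming messages, that is, the message-passing result divided by the number of neighbors.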

Parameters:
Return type:

Union[Float[Array, '... N 1 (max_degree+1)**2 num_features'], Float[Array, '... N 2 (max_degree+1)**2 num_features']]

Returns:

The output of self-attention.