e3x.nn.modules.MultiHeadAttention
- class e3x.nn.modules.MultiHeadAttention(max_degree=None, use_basis_bias=False, include_pseudotensors=True, cartesian_order=True, use_fused_tensor=False, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, dense_kernel_init=<function variance_scaling.<locals>.init>, dense_bias_init=<function zeros>, tensor_kernel_init=None, num_heads=1, qkv_features=None, out_features=None, use_relative_positional_encoding_qk=True, use_relative_positional_encoding_v=True, query_kernel_init=<function variance_scaling.<locals>.init>, query_bias_init=<function zeros>, query_use_bias=False, key_kernel_init=<function variance_scaling.<locals>.init>, key_bias_init=<function zeros>, key_use_bias=False, value_kernel_init=<function variance_scaling.<locals>.init>, value_bias_init=<function zeros>, value_use_bias=False, output_kernel_init=<function variance_scaling.<locals>.init>, output_bias_init=<function zeros>, output_use_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: _Conv
Equivariant multi-head attention.
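For orientation, the NumPy sketch below shows ordinary multi-head dot-product attention on plain scalar features. It is a simplified, non-equivariant illustration, not e3x's implementation: the real module works on irreps features with a parity axis and (max_degree+1)**2 components and can couple in basis functions. The function name, the explicit weight matrices, and the 1/sqrt(head_dim) scaling are assumptions of this sketch.

```python
import numpy as np


def multi_head_attention(x_q, x_kv, w_q, w_k, w_v, w_o, num_heads):
    """Scaled dot-product multi-head attention on plain scalar features.

    Shapes: x_q (N, q_features), x_kv (M, kv_features),
    w_q (q_features, qkv_features), w_k and w_v (kv_features, qkv_features),
    w_o (qkv_features, out_features). Returns (N, out_features).
    """
    num_q = x_q.shape[0]
    qkv_features = w_q.shape[1]
    head_dim, remainder = divmod(qkv_features, num_heads)
    if remainder:
        raise ValueError("qkv_features must be divisible by num_heads")
    # Project, then split the feature axis into (num_heads, head_dim).
    q = (x_q @ w_q).reshape(num_q, num_heads, head_dim)
    k = (x_kv @ w_k).reshape(-1, num_heads, head_dim)
    v = (x_kv @ w_v).reshape(-1, num_heads, head_dim)
    # One (N, M) logit matrix per head, scaled by 1/sqrt(head_dim).
    logits = np.einsum("nhd,mhd->hnm", q, k) / np.sqrt(head_dim)
    # Numerically stable softmax over the key axis M.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Attention-weighted sum of values, then merge heads and project.
    out = np.einsum("hnm,mhd->nhd", weights, v).reshape(num_q, qkv_features)
    return out @ w_o


rng = np.random.default_rng(0)
x_q = rng.normal(size=(3, 8))
x_kv = rng.normal(size=(5, 6))
w_q, w_k, w_v = rng.normal(size=(8, 4)), rng.normal(size=(6, 4)), rng.normal(size=(6, 4))
w_o = rng.normal(size=(4, 8))
out = multi_head_attention(x_q, x_kv, w_q, w_k, w_v, w_o, num_heads=2)
print(out.shape)  # (3, 8)
```

Each head sees only its own head_dim slice of the projected features, which is why qkv_features has to split evenly across num_heads.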
- max_degree
Maximum degree of the output. If not given, max_degree is set to the maximum of the max_degree of the inputs and of the basis.
- use_basis_bias
Whether to add a bias to the linear combination of basis functions.
- include_pseudotensors
If False, all coupling paths that produce pseudotensors are omitted.
- cartesian_order
If True, Cartesian order is assumed.
- use_fused_tensor
If True, FusedTensor is used instead of Tensor for computing the tensor product.
- dtype
The dtype of the computation.
- param_dtype
The dtype passed to parameter initializers.
- precision
Numerical precision of the computation, see jax.lax.Precision for details.
- dense_kernel_init
Initializer function for the weight matrix of the Dense layer.
- dense_bias_init
Initializer function for the bias of the Dense layer.
- tensor_kernel_init
Initializer function for the weight matrix of the Tensor layer.
- num_heads
Number of attention heads.
- qkv_features
Number of features used for queries, keys and values. If this is None, the same number of features as in inputs_q is used.
- out_features
Number of features for the output. If this is None, the same number of features as in inputs_q is used.
- use_relative_positional_encoding_qk
If this is True, relative positional encodings are used for computing the dot product between queries and keys.
- use_relative_positional_encoding_v
If this is True, a relative positional encoding (with respect to the queries) is used for computing the values.
- query_kernel_init
Initializer function for the weight matrix of the Dense layer for computing queries.
- value_kernel_init
Initializer function for the weight matrix of the Dense layer for computing values.
- output_kernel_init
Initializer function for the weight matrix of the Dense layer for computing outputs.
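The shape conventions and feature-size defaults described above can be restated as a small sketch. The helpers (irreps_shape, degree_from_shape, resolve_defaults) are hypothetical and not part of e3x; they assume features are laid out as (N, parity, (max_degree+1)**2, features) with a parity axis of size 1 or 2, as in the type annotations of __call__, and they read "the max_degree of inputs" as covering both inputs_q and inputs_kv.

```python
import math


def irreps_shape(num_nodes, max_degree, num_features, parity_size=2):
    """Shape (N, parity, (max_degree+1)**2, features) of a feature array."""
    if parity_size not in (1, 2):
        raise ValueError("the parity axis has size 1 or 2")
    return (num_nodes, parity_size, (max_degree + 1) ** 2, num_features)


def degree_from_shape(shape):
    """Recovers max_degree from the (max_degree+1)**2 axis (second to last)."""
    size = shape[-2]
    root = math.isqrt(size)
    if root * root != size:
        raise ValueError(f"axis size {size} is not a perfect square")
    return root - 1


def resolve_defaults(q_shape, kv_shape, basis_shape=None, *, max_degree=None,
                     qkv_features=None, out_features=None):
    """Fills in max_degree, qkv_features and out_features as documented.

    Assumption: "the max_degree of inputs" covers inputs_q and inputs_kv.
    """
    q_features = q_shape[-1]
    if max_degree is None:
        degrees = [degree_from_shape(q_shape), degree_from_shape(kv_shape)]
        if basis_shape is not None:
            degrees.append(degree_from_shape(basis_shape))
        max_degree = max(degrees)
    return {
        "max_degree": max_degree,
        "qkv_features": q_features if qkv_features is None else qkv_features,
        "out_features": q_features if out_features is None else out_features,
    }


q_shape = irreps_shape(4, max_degree=2, num_features=16)  # (4, 2, 9, 16)
kv_shape = irreps_shape(6, max_degree=1, num_features=8)  # (6, 2, 4, 8)
basis_shape = (4, 6, 1, 16, 32)  # dense basis with basis_max_degree = 3
print(resolve_defaults(q_shape, kv_shape, basis_shape))
# {'max_degree': 3, 'qkv_features': 16, 'out_features': 16}
```

Note that both qkv_features and out_features default to the feature count of inputs_q, not inputs_kv, even when the two inputs have different feature counts.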
- __call__(inputs_q, inputs_kv, basis=None, cutoff_value=None, *, adj_idx=None, where=None, dst_idx=None, src_idx=None, num_segments=None, indices_are_sorted=False)
Applies multi-head attention.
- Parameters:
inputs_q (Union[Float[Array, '... N 1 (max_degree+1)**2 q_features'], Float[Array, '... N 2 (max_degree+1)**2 q_features']]) – Input features that are used to compute queries.
inputs_kv (Union[Float[Array, '... M 1 (max_degree+1)**2 kv_features'], Float[Array, '... M 2 (max_degree+1)**2 kv_features']]) – Input features that are used to compute keys and values.
basis (Union[Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'], Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis'], NoneType], default: None) – Basis functions for all relevant interactions between queries and keys (either in dense or sparse indexed format).
cutoff_value (Union[Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'], Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis'], NoneType], default: None) – Multiplicative cutoff values that are applied to the “raw” softmax values (before normalization), can be used for smooth cutoffs.
adj_idx (Optional[Integer[Array, '... N M']], default: None) – Adjacency indices (dense index list), or None.
where (Optional[Bool[Array, '... N M']], default: None) – Mask to specify which values to sum over (only for dense index lists). If this is None, the where mask is auto-determined from inputs_kv.
dst_idx (Optional[Integer[Array, '... P']], default: None) – Destination indices (sparse index list), or None.
src_idx (Optional[Integer[Array, '... P']], default: None) – Source indices (sparse index list), or None.
num_segments (Optional[int], default: None) – Number of segments after summation (only for sparse index lists). If this is None, num_segments is auto-determined from inputs_q.
indices_are_sorted (bool, default: False) – If True, dst_idx is assumed to be sorted, which may increase performance (only used for sparse index lists).
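To make the two index formats and the role of cutoff_value more concrete, here is a NumPy sketch under simplifying assumptions: a single head, plain vector features, and no basis functions. It assumes that adj_idx[n, j] holds the index of the j-th key/value entry for query n, that dst_idx indexes queries and src_idx indexes keys/values, and that softmax normalization runs over all entries sharing a destination. The function names are hypothetical, and e3x's own implementation may differ in details such as numerical stabilization.

```python
import numpy as np


def dense_to_sparse(adj_idx, where):
    """Flattens a dense (N, K) index list into sparse (dst_idx, src_idx).

    Assumes adj_idx[n, j] is the index of the j-th key/value entry for query n
    and where[n, j] marks valid entries. Row-major order keeps dst_idx sorted.
    """
    rows = np.broadcast_to(np.arange(adj_idx.shape[0])[:, None], adj_idx.shape)
    return rows[where], adj_idx[where]


def segment_softmax(logits, dst_idx, num_segments, cutoff_value=None):
    """Softmax over all entries that share a destination index.

    cutoff_value multiplies the raw values exp(logit) before normalization,
    so an entry with cutoff 0 gets zero weight. A segment whose cutoffs are
    all 0 would divide by zero; this sketch does not guard against that.
    """
    # Subtracting the per-segment maximum cancels in the ratio below.
    seg_max = np.full(num_segments, -np.inf)
    np.maximum.at(seg_max, dst_idx, logits)
    raw = np.exp(logits - seg_max[dst_idx])
    if cutoff_value is not None:
        raw = raw * cutoff_value
    denom = np.zeros(num_segments)
    np.add.at(denom, dst_idx, raw)
    return raw / denom[dst_idx]


def sparse_attention(q, k, v, dst_idx, src_idx, num_segments, cutoff_value=None):
    """Single-head, unscaled attention over the (dst_idx, src_idx) pairs only."""
    logits = np.sum(q[dst_idx] * k[src_idx], axis=-1)
    weights = segment_softmax(logits, dst_idx, num_segments, cutoff_value)
    out = np.zeros((num_segments, v.shape[-1]))
    np.add.at(out, dst_idx, weights[:, None] * v[src_idx])
    return out


adj_idx = np.array([[1, 2, 0], [0, 0, 0]])
where = np.array([[True, True, False], [True, False, False]])
dst_idx, src_idx = dense_to_sparse(adj_idx, where)  # [0, 0, 1] and [1, 2, 0]
weights = segment_softmax(np.array([0.0, 0.0, 1.0]), dst_idx, 2,
                          cutoff_value=np.array([1.0, 0.5, 1.0]))
# weights is approximately [2/3, 1/3, 1.0]
```

Because the dense list is flattened row by row, the resulting dst_idx is already sorted, which is the situation indices_are_sorted=True describes. Halving the cutoff of the second entry shifts weight to the first one instead of simply scaling down the output, since the cutoff acts before normalization.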
- Return type:
- Returns:
The result of the multi-head attention computation.
- Raises:
ValueError – If inputs_q and inputs_kv have incompatible shapes, or if qkv_features is not divisible by num_heads.
TypeError – If relative positional encodings are requested but no input for basis is provided.
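The two argument-level error conditions listed above can be mirrored by a small pre-flight check. This is a hypothetical helper, not an e3x function; it assumes that either use_relative_positional_encoding_* flag being True counts as "requesting" relative positional encodings, and it does not reproduce the shape-compatibility check, whose exact rule is not spelled out here.

```python
def check_attention_args(qkv_features, num_heads, basis,
                         use_relative_positional_encoding_qk=True,
                         use_relative_positional_encoding_v=True):
    """Re-creates two of the documented error conditions (hypothetical helper).

    The shape-compatibility ValueError is not reproduced here because the
    exact rule is not spelled out in this documentation.
    """
    if qkv_features % num_heads != 0:
        raise ValueError(
            f"qkv_features ({qkv_features}) is not divisible by "
            f"num_heads ({num_heads})")
    # Assumption: either flag being True counts as "requested".
    if basis is None and (use_relative_positional_encoding_qk
                          or use_relative_positional_encoding_v):
        raise TypeError("relative positional encodings require a basis")


check_attention_args(16, 4, basis=object())  # passes
check_attention_args(16, 4, basis=None,
                     use_relative_positional_encoding_qk=False,
                     use_relative_positional_encoding_v=False)  # passes
```

Under this reading, since both flags default to True, calling the module without basis while keeping the default settings would raise the TypeError; either pass basis or set both flags to False.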