# Copyright 2024 The e3x Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Equivariant neural network modules.
.. _Modules:
"""
import dataclasses
import functools
import itertools
import math
from typing import Any, List, Optional, Sequence, Tuple, Union
from e3x import ops
from e3x import so3
from e3x import util
from e3x.config import Config
from flax import linen as nn
from flax.linen.dtypes import promote_dtype
import jax
import jax.numpy as jnp
import jaxtyping
from . import initializers
from .features import _extract_max_degree_and_check_shape
from .features import change_max_degree_or_type
# Aliases for the initializer function types declared in the local
# `initializers` module.
FusedTensorInitializerFn = initializers.FusedTensorInitializerFn
InitializerFn = initializers.InitializerFn
# Short names for the jaxtyping annotation helpers used in the signatures
# below.
Array = jaxtyping.Array
Bool = jaxtyping.Bool
Float = jaxtyping.Float
Integer = jaxtyping.Integer
UInt32 = jaxtyping.UInt32
# A shape is a sequence whose entries are ints (or anything else).
Shape = Sequence[Union[int, Any]]
Dtype = Any  # This could be a real type if support for that is added.
# Annotation for PRNG keys: a uint32 array annotated with a single axis of
# size 2.
PRNGKey = UInt32[Array, '2']
PrecisionLike = jax.lax.PrecisionLike
# Default initializer for the embedding table of `Embed`: variance scaling
# (scale 1.0, 'fan_in' mode, normal distribution) with axis 0 declared as the
# output axis.
default_embed_init = jax.nn.initializers.variance_scaling(
    1.0, 'fan_in', 'normal', out_axis=0
)
[docs]
class Embed(nn.Module):
  """Learnable lookup table from integers to scalar features.

  Every integer in :math:`[0, n)` is associated with a trainable
  :math:`d`-dimensional vector of scalar features.

  Attributes:
    num_embeddings: Number of embeddings :math:`n`.
    features: Dimension :math:`d` of the feature space.
    dtype: The :class:`dtype <jax.numpy.dtype>` of the embedding vectors.
    param_dtype: The dtype passed to parameter initializers.
    embedding_init: Embedding initializer.
  """

  num_embeddings: int
  features: int
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  embedding_init: InitializerFn = default_embed_init

  # Assigned in setup(); not a constructor argument.
  embedding: Float[Array, 'num_embeddings 1 1 features'] = dataclasses.field(
      init=False
  )

  def setup(self):
    """Creates the ``embedding`` parameter."""
    # The two singleton axes are the parity and degree axes of a purely
    # scalar feature.
    table_shape = (self.num_embeddings, 1, 1, self.features)
    self.embedding = self.param(
        'embedding', self.embedding_init, table_shape, self.param_dtype
    )

  def __call__(
      self, inputs: Integer[Array, '...']
  ) -> Float[Array, '... 1 1 F']:
    """Looks up the embedding vector for every entry of ``inputs``.

    The returned scalar features carry the trailing ``1, 1, features`` axes
    used by the other equivariant operations in this module.

    Args:
      inputs: Integer indices; every axis is treated as a batch axis.

    Returns:
      The embedded inputs, with shape ``inputs.shape + (1, 1, features)``.

    Raises:
      ValueError: If ``inputs`` does not have an integer dtype.
    """
    is_integer_input = jnp.issubdtype(inputs.dtype, jnp.integer)
    if not is_integer_input:
      raise ValueError('input type must be an integer or unsigned integer')
    (table,) = promote_dtype(self.embedding, dtype=self.dtype, inexact=False)
    return jnp.take(table, inputs, axis=0)
# Default weight initializer used by `Dense` (LeCun normal).
default_kernel_init = jax.nn.initializers.lecun_normal()
[docs]
class Dense(nn.Module):
  r"""A linear transformation applied over the last dimension of the input.

  The transformation can be written as

  .. math::

    \mathbf{y}^{(\ell_p)} = \begin{cases}
      \mathbf{x}^{(\ell_p)}\mathbf{W}_{(\ell_p)} + \mathbf{b} & \ell_p = 0_+ \\
      \mathbf{x}^{(\ell_p)}\mathbf{W}_{(\ell_p)} & \ell_p \neq 0_+
    \end{cases}

  where
  :math:`\mathbf{x} \in \mathbb{R}^{P\times (L+1)^2 \times F_{\mathrm{in}}}` and
  :math:`\mathbf{y} \in \mathbb{R}^{P\times (L+1)^2 \times F_{\mathrm{out}}}`
  are the inputs and outputs, respectively. Here, :math:`P` is either :math:`1`
  or :math:`2` (depending on whether the inputs contain pseudotensors or not),
  :math:`L` is the maximum degree of the input features, and
  :math:`F_{\mathrm{in}}` and :math:`F_{\mathrm{out}}` = ``features`` are the
  number of input and output features. Every combination of degree :math:`\ell`
  and parity :math:`p` has separate weight matrices
  :math:`\mathbf{W}_{(\ell_p)}`. Note that a bias term
  :math:`\mathbf{b} \in \mathbb{R}^{1\times 1 \times F_{\mathrm{out}}}` is only
  applied to the scalar channel (:math:`\ell_p= 0_+`) when ``use_bias=True``.

  Attributes:
    features: The number of output features :math:`F_{\mathrm{out}}`.
    use_bias: Whether to add a bias to the scalar channel of the output.
    dtype: The dtype of the computation.
    param_dtype: The dtype passed to parameter initializers.
    precision: Numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: Initializer function for the weight matrix.
    bias_init: Initializer function for the bias.
  """

  features: int
  use_bias: bool = True
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: InitializerFn = default_kernel_init
  bias_init: InitializerFn = jax.nn.initializers.zeros

  @nn.nowrap
  def _make_dense_for_each_degree(
      self, max_degree: int, use_bias: bool, name_suffix: Optional[str] = None
  ) -> List[nn.Dense]:
    """Creates one ``nn.Dense`` submodule for every degree up to max_degree.

    Args:
      max_degree: Largest degree for which a submodule is created.
      use_bias: Whether the degree-0 submodule has a bias. Submodules for
        higher degrees never have one.
      name_suffix: Suffix appended to the degree to form the submodule name.
        If ``None``, ``'+'`` is used for even and ``'-'`` for odd degrees.

    Returns:
      A list with ``max_degree + 1`` submodules, indexed by degree.
    """

    def _suffix(l: int) -> str:
      # Without an explicit suffix, the parity follows the degree.
      return '+-'[l % 2] if name_suffix is None else name_suffix

    return [
        nn.Dense(
            features=self.features,
            use_bias=use_bias and l == 0,  # Apply bias only for scalars!
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name=f'{l}{_suffix(l)}',
        )
        for l in range(max_degree + 1)
    ]

  @nn.compact
  def __call__(
      self,
      inputs: Union[
          Float[Array, '... 1 (max_degree+1)**2 in_features'],
          Float[Array, '... 2 (max_degree+1)**2 in_features'],
      ],
  ) -> Union[
      Float[Array, '... 1 (max_degree+1)**2 out_features'],
      Float[Array, '... 2 (max_degree+1)**2 out_features'],
  ]:
    """Applies a linear transformation to the inputs along the last dimension.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.

    Raises:
      AssertionError: If the parity axis (axis -3) has a size other than 1 or
        2 even though the shape check passed.
    """
    max_degree = _extract_max_degree_and_check_shape(inputs.shape)
    num_parity = inputs.shape[-3]

    def _apply_degree_wise(layers: List[nn.Dense], x: Array) -> Array:
      """Applies layers[l] to the degree-l slice of axis -2 and re-joins."""
      return jnp.concatenate(
          [
              layers[l](x[..., l**2 : (l + 1) ** 2, :])
              for l in range(max_degree + 1)
          ],
          axis=-2,
      )

    if num_parity == 2:  # Has pseudotensors.
      dense_e = self._make_dense_for_each_degree(max_degree, self.use_bias, '+')
      dense_o = self._make_dense_for_each_degree(max_degree, False, '-')
      return jnp.stack(
          [
              # Even parity (tensors).
              _apply_degree_wise(dense_e, inputs[..., 0, :, :]),
              # Odd parity (pseudotensors).
              _apply_degree_wise(dense_o, inputs[..., 1, :, :]),
          ],
          axis=-3,
      )
    if num_parity == 1:  # Has no pseudotensors.
      dense = self._make_dense_for_each_degree(max_degree, self.use_bias)
      return _apply_degree_wise(dense, inputs)
    # Raised explicitly (instead of `assert False`) so that the check is not
    # stripped when Python runs with -O.
    raise AssertionError('Shape has passed checks even though it is invalid!')
def _duplication_indices_for_max_degree(
    max_degree: int,
) -> Integer[Array, '(max_degree+1)**2']:
  """Builds indices that repeat each degree once per order of that degree.

  Arrays that hold one value per degree often have to be expanded so that the
  value for degree l appears 2*l+1 times (once for every order). Passing the
  returned indices to jnp.take performs this expansion. For max_degree=2, the
  indices are [0, 1, 1, 1, 2, 2, 2, 2, 2].

  Args:
    max_degree: The maximum degree for which to construct indices.

  Returns:
    The corresponding indices for use in jnp.take.
  """
  num_degrees = max_degree + 1
  degrees = jnp.arange(num_degrees)  # [0, 1, 2, ..., max_degree]
  orders_per_degree = 2 * degrees + 1
  return jnp.repeat(
      degrees, orders_per_degree, total_repeat_length=num_degrees**2
  )
def _make_tensor_product_mask(
    shape: Tuple[int, int, int, int, int, int],
    dtype: Dtype = jnp.float32,
) -> Float[Array, 'S1 L1 S2 L2 S3 L3 1']:
  """Helper function for generating the tensor product mask.

  Can be multiplied with a parameter matrix to zero out all forbidden (parity
  violating) coupling paths. A coupling path is forbidden whenever the sum of
  the parities of the coupled irreps (modulo 2) differs from the parity of the
  output irrep.

  The input shape tuple must have the form (S1, L1, S2, L2, S3, L3), where each
  S? is either 1 or 2 and each L? is the number of degrees (max_degree?+1).
  Index l along an L? axis is interpreted as degree l. If S? is 2, index 0
  along the S? axis holds even and index 1 holds odd irreps. If S? is 1, the
  parity of an entry follows its degree (even degrees are even, odd degrees
  are odd).

  Args:
    shape: The input shape tuple (see above).
    dtype: The dtype of the returned mask.

  Returns:
    A mask array containing only 0s and 1s with the same shape as specified by
    the input shape tuple with an appended size 1 dimension at position -1.

  Raises:
    AssertionError: If any S? is neither 1 nor 2.
  """

  def _entry_parities(num_parity: int, num_degrees: int) -> Array:
    """Returns parities (0 = even, 1 = odd) broadcastable to (S?, L?)."""
    if num_parity not in (1, 2):
      raise AssertionError('number of parity channels must be 1 or 2')
    if num_parity == 2:
      # All entries for p=0 are even and for p=1 odd (independent of l).
      return jnp.arange(2).reshape(2, 1)
    # Entries are even if l is even and odd if l is odd.
    return (jnp.arange(num_degrees) % 2).reshape(1, num_degrees)

  s1, n1, s2, n2, s3, n3 = shape
  # Place each parity table on its own pair of axes so that broadcasting
  # enumerates every (p1, l1, p2, l2, p3, l3) combination at once.
  d1 = _entry_parities(s1, n1).reshape(-1, n1 if s1 == 1 else 1, 1, 1, 1, 1)
  d2 = _entry_parities(s2, n2).reshape(1, 1, -1, n2 if s2 == 1 else 1, 1, 1)
  d3 = _entry_parities(s3, n3).reshape(1, 1, 1, 1, -1, n3 if s3 == 1 else 1)
  allowed = (d1 + d2) % 2 == d3

  # Multiplying by the boolean keeps the requested dtype and expands the
  # comparison result to the full mask shape.
  return jnp.ones((*shape, 1), dtype=dtype) * allowed[..., None]
# Default kernel initializer used by `Tensor`.
default_tensor_kernel_init = initializers.tensor_lecun_normal()
[docs]
class Tensor(nn.Module):
  r"""Tensor product of two equivariant feature representations.

  Computes linear combinations (with learnable coefficients) of the direct sum
  representation of all possible tensor products of irreps in the input
  features. If the inputs are
  :math:`\mathbf{x} \in \mathbb{R}^{P_1\times (L_1+1)^2 \times F}` and
  :math:`\mathbf{y} \in \mathbb{R}^{P_2\times (L_2+1)^2 \times F}`, the output
  is :math:`\mathbf{z} \in \mathbb{R}^{P_3\times (L_3+1)^2 \times F}`. Here,
  :math:`P_1`, :math:`P_2`, and :math:`P_3` are either :math:`1` or
  :math:`2` (depending on whether the inputs/output contain pseudotensors or
  not) and :math:`L_1`, :math:`L_2`, and :math:`L_3` nonnegative integers
  (:math:`L_3` = ``max_degree``). The entries of :math:`\mathbf{z}` are
  given by

  .. math::

    \mathbf{z}^{(c_\gamma)} = \sum_{(a_\alpha,b_\beta)\in V}
      \mathbf{w}_{(a_\alpha,b_\beta,c_\gamma)} \circ \left(
      \mathbf{x}^{(a_\alpha)} \otimes^{(c_\gamma)}\mathbf{y}^{(b_\beta)}
      \right)\,,

  where the sum runs over all :math:`(a_\alpha,b_\beta)` in the set of valid
  combinations :math:`V` and :math:`\mathbf{w}_{(a_\alpha,b_\beta,c_\gamma)} \in
  \mathbb{R}^{1\times 1\times F}` are learnable (feature-wise) weight
  parameters. Each combination :math:`(a_\alpha,b_\beta,c_\gamma)` has separate
  parameters and the element-wise product ':math:`\circ`' implies broadcasting
  over dimensions. The set :math:`V` contains all :math:`(a_\alpha,b_\beta)` for
  which the condition

  .. math::

    \lvert a - b \rvert \leq c \leq a + b \enspace \land \enspace
    \left(
    \left( \gamma = +1 \enspace \land \enspace \alpha = \beta \right)
    \enspace \lor \enspace
    \left( \gamma = -1 \enspace \land \enspace \alpha \neq \beta \right)
    \right)

  is true. If ``include_pseudotensors = False``, coupling paths that lead to
  pseudotensors are not computed. This means that all entries that do not
  satisfy

  .. math::

    \left(c \in \{2n+1 : n\in \mathbb{N}_0\} \enspace \land \enspace
    \gamma = -1 \right) \enspace \lor \enspace
    \left(c \in \{2n : n\in \mathbb{N}_0\} \enspace \land \enspace
    \gamma = +1 \right)

  (either :math:`c` is odd *and* the parity is odd, *or* :math:`c` is
  even *and* the parity is even) are omitted. See also
  :ref:`here <CouplingIrreps>` for more details on the notation used here and
  the coupling of irreps in general. The following diagram shows a visualization
  of the computation for the example :math:`P_1=P_2=P_3=2`,
  :math:`L_1=L_2=L_3=1`. For better clarity, weights are only labelled for 2 out
  of the 20 possible coupling paths.

  .. image:: ../_static/tensor_product_visualization.svg
   :scale: 100 %
   :align: center
   :alt: visualization

  |

  Attributes:
    max_degree: Maximum degree of the output. If not given, ``max_degree`` is
      chosen as the maximum of the maximum degrees of inputs1 and inputs2.
    include_pseudotensors: If ``False``, all coupling paths that produce
      pseudotensors are omitted.
    cartesian_order: If ``True``, Cartesian order is assumed.
    dtype: The dtype of the computation.
    param_dtype: The dtype passed to parameter initializers.
    precision: Numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: Initializer function for the weight matrix.
  """

  max_degree: Optional[int] = None
  include_pseudotensors: bool = True
  cartesian_order: bool = Config.cartesian_order
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: InitializerFn = default_tensor_kernel_init

  @nn.compact
  def __call__(
      self,
      inputs1: Union[
          Float[Array, '... 1 (max_degree1+1)**2 num_features'],
          Float[Array, '... 2 (max_degree1+1)**2 num_features'],
      ],
      inputs2: Union[
          Float[Array, '... 1 (max_degree2+1)**2 num_features'],
          Float[Array, '... 2 (max_degree2+1)**2 num_features'],
      ],
  ) -> Union[
      Float[Array, '... 1 (max_degree3+1)**2 num_features'],
      Float[Array, '... 2 (max_degree3+1)**2 num_features'],
  ]:
    """Computes the tensor product of inputs1 and inputs2.

    Args:
      inputs1: The first factor of the tensor product.
      inputs2: The second factor of the tensor product.

    Returns:
      The tensor product of inputs1 and inputs2, where each output irrep is a
      weighted linear combination with learnable weights of all valid coupling
      paths.

    Raises:
      ValueError: If the output ``max_degree`` is larger than the sum of the
        maximum degrees of the inputs, or if axis -1 of inputs1 and inputs2
        differ in size.
    """
    # Determine max_degree of inputs and output.
    max_degree1 = _extract_max_degree_and_check_shape(inputs1.shape)
    max_degree2 = _extract_max_degree_and_check_shape(inputs2.shape)
    max_degree3 = (
        max(max_degree1, max_degree2)
        if self.max_degree is None
        else self.max_degree
    )

    # Check that max_degree3 is not larger than is sensible.
    if max_degree3 > max_degree1 + max_degree2:
      raise ValueError(
          'max_degree for the tensor product of inputs with max_degree'
          f' {max_degree1} and {max_degree2} can be at most'
          f' {max_degree1 + max_degree2}, received max_degree={max_degree3}'
      )

    # Check that axis -1 (number of features) of both inputs matches in size.
    if inputs1.shape[-1] != inputs2.shape[-1]:
      raise ValueError(
          'axis -1 of inputs1 and input2 must have the same size, '
          f'received shapes {inputs1.shape} and {inputs2.shape}'
      )

    # Extract number of features from size of axis -1.
    features = inputs1.shape[-1]

    # If both inputs contain no pseudotensors and at least one input or the
    # output has max_degree == 0, the tensor product will not produce
    # pseudotensors, in this case, the output will be returned with no
    # pseudotensor channel, regardless of whether self.include_pseudotensors is
    # True or False.
    if (inputs1.shape[-3] == inputs2.shape[-3] == 1) and (
        max_degree1 == 0 or max_degree2 == 0 or max_degree3 == 0
    ):
      include_pseudotensors = False
    else:
      include_pseudotensors = self.include_pseudotensors

    # Determine number of parity channels.
    num_parity1 = inputs1.shape[-3]
    num_parity2 = inputs2.shape[-3]
    num_parity3 = 2 if include_pseudotensors else 1

    # Initialize parameters (one weight per degree/parity combination and
    # feature; degrees are expanded to orders further below).
    kernel_shape = (
        num_parity1,
        max_degree1 + 1,
        num_parity2,
        max_degree2 + 1,
        num_parity3,
        max_degree3 + 1,
        features,
    )
    kernel = self.param(
        'kernel', self.kernel_init, kernel_shape, self.param_dtype
    )
    (kernel,) = promote_dtype(kernel, dtype=self.dtype)

    # If any of the two inputs or the output do not contain pseudotensors, the
    # forbidden coupling paths correspond to "mixed entries within array
    # slices". However, if all inputs and the output contain pseudotensors, the
    # forbidden coupling paths all correspond to "whole slices" of the arrays.
    # Instead of masking specific entries, it is then more efficient to slice
    # the arrays and compute the allowed paths separately, effectively cutting
    # the number of necessary computations in half.
    mixed_coupling_paths = not num_parity1 == num_parity2 == num_parity3 == 2

    # Initialize constants.
    with jax.ensure_compile_time_eval():
      # Clebsch-Gordan tensor.
      cg = so3.clebsch_gordan(
          max_degree1,
          max_degree2,
          max_degree3,
          cartesian_order=self.cartesian_order,
      )

      # Mask for zeroing out forbidden (parity violating) coupling paths. It
      # is created in the dtype of the (promoted) kernel, so that masking
      # does not silently promote a lower-precision kernel to float32.
      if mixed_coupling_paths:
        mask = _make_tensor_product_mask(
            kernel_shape[:-1], dtype=kernel.dtype
        )

      # Indices for expanding shape of kernel.
      idx1 = _duplication_indices_for_max_degree(max_degree1)
      idx2 = _duplication_indices_for_max_degree(max_degree2)
      idx3 = _duplication_indices_for_max_degree(max_degree3)

    # Mask kernel (only necessary for mixed coupling paths)
    if mixed_coupling_paths:
      kernel *= mask

    # Expand shape (necessary for correct broadcasting).
    kernel = jnp.take(kernel, idx1, axis=1, indices_are_sorted=True)
    kernel = jnp.take(kernel, idx2, axis=3, indices_are_sorted=True)
    kernel = jnp.take(kernel, idx3, axis=5, indices_are_sorted=True)

    if mixed_coupling_paths:
      return jnp.einsum(
          '...plf,...qmf,plqmrnf,lmn->...rnf',
          inputs1,
          inputs2,
          kernel,
          cg,
          precision=self.precision,
          optimize='optimal',
      )

    # Compute all allowed even/odd + even/odd -> even/odd coupling paths.
    def _couple_slices(
        i: int, j: int, k: int
    ) -> Float[Array, '... (max_degree3+1)**2 num_features']:
      """Helper function for coupling slice (i, j, k)."""
      return jnp.einsum(
          '...lf,...mf,lmnf,lmn->...nf',
          inputs1[..., i, :, :],
          inputs2[..., j, :, :],
          kernel[i, :, j, :, k, :, :],
          cg,
          precision=self.precision,
          optimize='optimal',
      )

    eee = _couple_slices(0, 0, 0)  # even + even -> even
    ooe = _couple_slices(1, 1, 0)  # odd + odd -> even
    eoo = _couple_slices(0, 1, 1)  # even + odd -> odd
    oeo = _couple_slices(1, 0, 1)  # odd + even -> odd

    # Combine same parities and return stacked features.
    return jnp.stack((eee + ooe, eoo + oeo), axis=-3)
# Default kernel initializer factory used by `FusedTensor`. Unlike the other
# defaults in this module it is stored uncalled: `FusedTensor` calls it with a
# variance scale (and optionally a mask) to obtain the actual initializer for
# each parameter.
default_fused_tensor_kernel_init = initializers.fused_tensor_normal
[docs]
class FusedTensor(nn.Module):
r"""Fused tensor product of two equivariant feature representations.
This module performs a similar function as
:class:`Tensor <e3x.nn.modules.Tensor>`, but has a lower computational
complexity and fewer learnable parameters. Given two inputs
:math:`\mathbf{x} \in \mathbb{R}^{P_1\times (L_1+1)^2 \times F}` and
:math:`\mathbf{y} \in \mathbb{R}^{P_2\times (L_2+1)^2 \times F}` the output
is :math:`\mathbf{z} \in \mathbb{R}^{P_3\times (L_3+1)^2 \times F}`. Here,
:math:`P_1`, :math:`P_2`, and :math:`P_3` are either :math:`1` or
:math:`2` (depending on whether the inputs/output contain pseudotensors or
not) and :math:`L_1`, :math:`L_2`, and :math:`L_3` are positive integers or
zero (:math:`L_3` = ``max_degree``). The computation consists of the following
steps:
1a. The constituent irreps
:math:`\mathbf{x}^{(a_\alpha)} \in \mathbb{R}^{1\times (2a+1) \times F}` and
:math:`\mathbf{y}^{(b_\beta)} \in \mathbb{R}^{1\times (2b+1) \times F}` of the
features :math:`\mathbf{x}` and :math:`\mathbf{y}` are transformed via a
change of basis ("vectors" to "matrices") to
:math:`\mathbf{\tilde{x}}^{(a_\alpha)}, \mathbf{\tilde{y}}^{(b_\beta)} \in
\mathbb{R}^{(2\tilde{l}+1) \times (2\tilde{l}+1) \times F}` with
:math:`\tilde{l} =
\left\lceil\frac{\mathrm{max}(L_1, L_2, L_3)}{2}\right\rceil`.
1b. The individual "matrix irreps" with equal parities are multiplied with
(separate) learnable weights :math:`\mathbf{w} \in
\mathbb{R}^{1 \times 1 \times F}` and added to form the matrices
:math:`\mathbf{X}^{(+)},\mathbf{X}^{(-)},\mathbf{Y}^{(+)},\mathbf{Y}^{(-)}
\in \mathbb{R}^{(2\tilde{l}+1) \times (2\tilde{l}+1) \times F}` (the
element-wise product ':math:`\circ`' implies broadcasting over
dimensions):
.. math::
\mathbf{X}^{(+)} &= \sum_{a=0}^{L_1} \mathbf{w}_{\mathbf{x}^{(a_+)}}
\circ \mathbf{\tilde{x}}^{(a_+)}
\mathbf{X}^{(-)} &= \sum_{a=0}^{L_1} \mathbf{w}_{\mathbf{x}^{(a_-)}}
\circ \mathbf{\tilde{x}}^{(a_-)}
\mathbf{Y}^{(+)} &= \sum_{b=0}^{L_2} \mathbf{w}_{\mathbf{y}^{(b_+)}}
\circ \mathbf{\tilde{y}}^{(b_+)}
\mathbf{Y}^{(-)} &= \sum_{b=0}^{L_2} \mathbf{w}_{\mathbf{y}^{(b_-)}}
\circ \mathbf{\tilde{y}}^{(b_-)}
Any potentially "missing" irreps, e.g. :math:`\mathbf{\tilde{x}}^{(1_+)}` if
:math:`\mathbf{x}` does not contain pseudotensors, are assumed to be zero.
2. The so-formed matrices are coupled by matrix multiplication to produce new
matrices as follows ("batch matrix multiplication" over the last dimension
with size :math:`F` is implied):
.. math::
\mathbf{Z}^{(+,+)} &= \mathbf{X}^{(+)}\mathbf{Y}^{(+)}
\mathbf{Z}^{(-,-)} &= \mathbf{X}^{(-)}\mathbf{Y}^{(-)}
\mathbf{Z}^{(+,-)} &= \mathbf{X}^{(+)}\mathbf{Y}^{(-)}
\mathbf{Z}^{(-,+)} &= \mathbf{X}^{(-)}\mathbf{Y}^{(+)}
Note: :math:`\mathbf{Z}^{(+,+)}` and :math:`\mathbf{Z}^{(-,-)}` have even
parity and :math:`\mathbf{Z}^{(+,-)}` and :math:`\mathbf{Z}^{(-,+)}` have odd
parity.
3a. The matrices :math:`\mathbf{Z}^{(+,+)}, \mathbf{Z}^{(-,-)},
\mathbf{Z}^{(+,-)}, \mathbf{Z}^{(-,+)}` are "decomposed" into their
constituent "matrix irreps" :math:`\mathbf{\tilde{z}}^{(+,+,c)},
\mathbf{\tilde{z}}^{(-,-,c)},\mathbf{\tilde{z}}^{(+,-,c)},
\mathbf{\tilde{z}}^{(-,+,c)}
\in \mathbb{R}^{(2\tilde{l}+1) \times (2\tilde{l}+1)\times F}` with
:math:`c = 0,\dots,L_3`. During this decomposition, the individual matrix
irreps are multiplied with (separate) learnable weights
:math:`\mathbf{w} \in \mathbb{R}^{1 \times 1 \times F}`. This step can
be thought of as performing the inverse of the operation in step 1b.
3b. The so-obtained "matrix irreps" are transformed via a change of basis
("matrices" to "vectors") to obtain the irreps
:math:`\mathbf{z}^{(+,+,c)},\mathbf{z}^{(-,-,c)},\mathbf{z}^{(+,-,c)},
\mathbf{z}^{(-,+,c)} \in \mathbb{R}^{1 \times (2c+1) \times F}` for each value
of :math:`c`. This step can be thought of as performing the inverse of the
operation in step 1a.
4. Finally, the output irreps of degree :math:`c` are obtained by summing
matching parities:
.. math::
\mathbf{z}^{(c_+)} &= \mathbf{z}^{(+,+,c)} + \mathbf{z}^{(-,-,c)}
\mathbf{z}^{(c_-)} &= \mathbf{z}^{(+,-,c)} + \mathbf{z}^{(-,+,c)}
If ``include_pseudotensors = False``, all irreps that correspond to
pseudotensors are discarded (set to zero).
The following diagram shows a visualization of the computation for the example
:math:`P_1=P_2=P_3=2`, :math:`L_1=L_2=L_3=2`.
.. image:: ../_static/fused_tensor_product_visualization.svg
:scale: 100 %
:align: center
:alt: visualization
|
While it is helpful to think of the change of basis and weighting of
individual components performed in steps 1a,b and 3a,b separately, these steps
are really performed concurrently using Clebsch-Gordan coefficients
:math:`C^{l_3,m_3}_{l_1,m_1,l_2,m_2}`. The actual computation performed for
the "vectors to matrices" change of basis plus multiplication by weights is:
.. math::
X_{\tilde{m},\tilde{m}'} = \sum_{l=0}^{L} w_{l}\sum_{m=-l}^{l}
C^{l,m}_{\tilde{l},\tilde{m},\tilde{l},\tilde{m}'} x_{l}^{m}
whereas the computation performed for the "matrices to vectors" change of
basis plus multiplication by weights is:
.. math::
x_{l}^{m} = \tilde{w}_{l}\sum_{\tilde{m}=-\tilde{l}}^{\tilde{l}}
\sum_{\tilde{m}'=-\tilde{l}}^{\tilde{l}}
C^{l,m}_{\tilde{l},\tilde{m},\tilde{l},\tilde{m}'} X_{\tilde{m},\tilde{m}'}
Here, :math:`X_{\tilde{m},\tilde{m}'}` are the elements of matrix
:math:`\mathbf{X} \in \mathbb{R}^{(2\tilde{l}+1)\times(2\tilde{l}+1)}`
.. math::
\mathbf{X} = \begin{bmatrix}
X_{-\tilde{l},-\tilde{l}} & X_{-\tilde{l},-\tilde{l}+1} &
\cdots & X_{-\tilde{l},\tilde{l}} \\
X_{-\tilde{l}+1,-\tilde{l}} & X_{-\tilde{l}+1,-\tilde{l}+1} &
\cdots & X_{-\tilde{l}+1,\tilde{l}} \\
\vdots & \vdots & \ddots & \vdots \\
X_{\tilde{l},-\tilde{l}} & X_{\tilde{l},-\tilde{l}+1} &
\cdots & X_{\tilde{l},\tilde{l}} \\ \end{bmatrix}
and :math:`x_{l}^{m}` are the individual entries of irreps :math:`\mathbf{x}
\in \mathbb{R}^{1\times(L+1)^2}`
.. math::
\mathbf{x} = [\underbrace{x_{0}^{0}}_{\mathbf{x}^{(0)}} \quad
\underbrace{x_{1}^{-1} \;\; x_{1}^{0} \;\; x_{1}^{1}}_{\mathbf{x}^{(1)}}
\quad \cdots \quad \underbrace{\quad x_{L}^{-L} \;\; \cdots \;\;
x_{L}^{L}}_{\mathbf{x}^{(L)}} ]
(the feature dimension with size :math:`F` and parity indicators :math:`+/-`
are omitted for clarity).
Attributes:
max_degree: Maximum degree of the output. If not given, ``max_degree`` is
chosen as the maximum of the maximum degrees of inputs1 and inputs2.
include_pseudotensors: If ``False``, pseudotensors are omitted in the
output.
cartesian_order: If ``True``, Cartesian order is assumed.
dtype: The dtype of the computation.
param_dtype: The dtype passed to parameter initializers.
precision: Numerical precision of the computation, see `jax.lax.Precision`
for details.
kernel_init: Initializer function for the weight kernel.
"""
max_degree: Optional[int] = None
include_pseudotensors: bool = True
cartesian_order: bool = Config.cartesian_order
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
precision: PrecisionLike = None
kernel_init: FusedTensorInitializerFn = default_fused_tensor_kernel_init
[docs]
@nn.compact
def __call__(
self,
inputs1: Union[
Float[Array, '... 1 (max_degree1+1)**2 num_features'],
Float[Array, '... 2 (max_degree1+1)**2 num_features'],
],
inputs2: Union[
Float[Array, '... 1 (max_degree2+1)**2 num_features'],
Float[Array, '... 2 (max_degree2+1)**2 num_features'],
],
) -> Union[
Float[Array, '... 1 (max_degree3+1)**2 num_features'],
Float[Array, '... 2 (max_degree3+1)**2 num_features'],
]:
"""Computes the fused tensor product of inputs1 and inputs2.
Args:
inputs1: The first factor of the tensor product.
inputs2: The second factor of the tensor product.
Returns:
The tensor product of inputs1 and inputs2, where each output irrep is a
weighted linear combination with (partially shared) learnable weights of
all valid coupling paths.
"""
# Determine max_degree of inputs and output.
max_degree1 = _extract_max_degree_and_check_shape(inputs1.shape)
max_degree2 = _extract_max_degree_and_check_shape(inputs2.shape)
max_degree3 = (
max(max_degree1, max_degree2)
if self.max_degree is None
else self.max_degree
)
# Check that max_degree3 is not larger than is sensible.
if max_degree3 > max_degree1 + max_degree2:
raise ValueError(
'max_degree for the tensor product of inputs with max_degree'
f' {max_degree1} and {max_degree2} can be at most'
f' {max_degree1 + max_degree2}, received max_degree={max_degree3}'
)
# Check that axis -1 (number of features) of both inputs matches in size.
if inputs1.shape[-1] != inputs2.shape[-1]:
raise ValueError(
'axis -1 of inputs1 and input2 must have the same size, '
f'received shapes {inputs1.shape} and {inputs2.shape}'
)
# Extract number of features from size of axis -1.
features = inputs1.shape[-1]
# Determine number of parity channels for inputs.
num_parity1 = inputs1.shape[-3]
num_parity2 = inputs2.shape[-3]
# If both inputs contain no pseudotensors and at least one input or the
# output has max_degree == 0, the tensor product will not produce
# pseudotensors, in this case, the output will be returned with no
# pseudotensor channel, regardless of whether self.include_pseudotensors is
# True or False.
if (num_parity1 == num_parity2 == 1) and (
max_degree1 == 0 or max_degree2 == 0 or max_degree3 == 0
):
include_pseudotensors = False
else:
include_pseudotensors = self.include_pseudotensors
# Initialize constants.
with jax.ensure_compile_time_eval():
max_max_degree = max(max_degree1, max_degree2, max_degree3)
# Create Clebsch-Gordan tensor and extract relevant slice.
matrix_degree = math.ceil(max_max_degree / 2)
cg = so3.clebsch_gordan(
matrix_degree,
matrix_degree,
max_max_degree,
cartesian_order=self.cartesian_order,
)
i = matrix_degree**2
j = (matrix_degree + 1) ** 2
cg = cg[i:j, i:j, :]
# Create masks for even/odd degrees.
degrees = jnp.arange(max_max_degree + 1)
repeats = 2 * degrees + 1
even = (degrees + 1) % 2
odd = degrees % 2
max_length = (max_max_degree + 1) ** 2
even_mask = jnp.repeat(even, repeats, total_repeat_length=max_length)
even_mask = jnp.expand_dims(even_mask, axis=-1)
odd_mask = jnp.repeat(odd, repeats, total_repeat_length=max_length)
odd_mask = jnp.expand_dims(odd_mask, axis=-1)
# Masks for initialization of parameters (so unused parameters are zeros).
mask_e1 = True if num_parity1 == 2 else even[: max_degree1 + 1, None]
mask_o1 = True if num_parity1 == 2 else odd[: max_degree1 + 1, None]
mask_e2 = True if num_parity2 == 2 else even[: max_degree2 + 1, None]
mask_o2 = True if num_parity2 == 2 else odd[: max_degree2 + 1, None]
if num_parity1 == num_parity2 == 1: # Output has no pseudoscalars.
mask_o3 = jnp.ones((max_degree3 + 1, 1)).at[0].set(0)
else:
mask_o3 = True
# Variance scaling factor for inputs.
num_mat = 2 * matrix_degree + 1
var_in = 1.0 / math.sqrt(num_mat) # Normalization from matrix mult.
var_in *= num_mat / min(max_degree1 + 1, max_degree2 + 1)
# Variance scaling factor for outputs.
if num_parity1 == num_parity2 == 2:
var_out = 1.0 / 2.0
elif num_parity1 == num_parity2 == 1:
if max_degree1 == 0 or max_degree2 == 0:
var_out = 1.0
else:
var_out = (
jnp.full((max_degree3 + 1, 1), fill_value=2.0).at[0].set(1.0)
)
else:
var_out = 1.0
# Initialize parameters.
shape1 = (max_degree1 + 1, features)
kernel_e1 = self.param(
'kernel_e1',
self.kernel_init(var_in, mask_e1),
shape1,
self.param_dtype,
)
kernel_o1 = self.param(
'kernel_o1',
self.kernel_init(var_in, mask_o1),
shape1,
self.param_dtype,
)
shape2 = (max_degree2 + 1, features)
kernel_e2 = self.param(
'kernel_e2',
self.kernel_init(var_in, mask_e2),
shape2,
self.param_dtype,
)
kernel_o2 = self.param(
'kernel_o2',
self.kernel_init(var_in, mask_o2),
shape2,
self.param_dtype,
)
shape3 = (max_degree3 + 1, features)
kernel_eee = self.param(
'kernel_eee', self.kernel_init(var_out), shape3, self.param_dtype
)
kernel_ooe = self.param(
'kernel_ooe', self.kernel_init(var_out), shape3, self.param_dtype
)
kernel_eoo = self.param(
'kernel_eoo',
self.kernel_init(var_out, mask_o3),
shape3,
self.param_dtype,
)
kernel_oeo = self.param(
'kernel_oeo',
self.kernel_init(var_out, mask_o3),
shape3,
self.param_dtype,
)
# Promote parameters to desired dtype.
(
kernel_e1,
kernel_o1,
kernel_e2,
kernel_o2,
kernel_eee,
kernel_ooe,
kernel_eoo,
kernel_oeo,
) = promote_dtype(
kernel_e1,
kernel_o1,
kernel_e2,
kernel_o2,
kernel_eee,
kernel_ooe,
kernel_eoo,
kernel_oeo,
dtype=self.dtype,
)
# Compute "stop indices" for slicing CG tensor and other arrays.
l1 = (max_degree1 + 1) ** 2
l2 = (max_degree2 + 1) ** 2
l3 = (max_degree3 + 1) ** 2
# Expand shape of parameters (repeat degree channels).
repeats1 = repeats[: max_degree1 + 1]
kernel_e1 = jnp.repeat(kernel_e1, repeats1, axis=0, total_repeat_length=l1)
kernel_o1 = jnp.repeat(kernel_o1, repeats1, axis=0, total_repeat_length=l1)
repeats2 = repeats[: max_degree2 + 1]
kernel_e2 = jnp.repeat(kernel_e2, repeats2, axis=0, total_repeat_length=l2)
kernel_o2 = jnp.repeat(kernel_o2, repeats2, axis=0, total_repeat_length=l2)
repeats3 = repeats[: max_degree3 + 1]
kernel_eee = jnp.repeat(
kernel_eee, repeats3, axis=0, total_repeat_length=l3
)
kernel_ooe = jnp.repeat(
kernel_ooe, repeats3, axis=0, total_repeat_length=l3
)
kernel_eoo = jnp.repeat(
kernel_eoo, repeats3, axis=0, total_repeat_length=l3
)
kernel_oeo = jnp.repeat(
kernel_oeo, repeats3, axis=0, total_repeat_length=l3
)
def _split_into_even_and_odd_components(
x: Union[
Float[Array, '... 1 (max_degree+1)**2 num_features'],
Float[Array, '... 2 (max_degree+1)**2 num_features'],
],
l: int, # l = (desired_max_degree+1)**2
) -> Tuple[
Float[Array, '... l num_features'],
Float[Array, '... l num_features'],
]:
if x.shape[-3] == 2: # Different parities are already nicely separated.
return x[..., 0, :, :], x[..., 1, :, :]
else: # Extract even and odd components with masking.
x = jnp.squeeze(x, axis=-3) # Squeeze parity channel.
return x * even_mask[:l, :], x * odd_mask[:l, :]
# Split inputs into even and odd components.
e1, o1 = _split_into_even_and_odd_components(inputs1, l1)
e2, o2 = _split_into_even_and_odd_components(inputs2, l2)
# Convert inputs into "matrix basis".
def _to_matrix(
x: Float[Array, '... (max_degree+1)**2 num_features'],
kernel: Float[Array, '(max_degree+1)**2 num_features'],
ls: int,
) -> Float[Array, '... 2*matrix_degree+1 2*matrix_degree+1 num_features']:
"""Helper function for converting to matrix basis."""
return jnp.einsum(
'...nf,nf,lmn->...lmf',
x,
kernel,
cg[..., :ls],
precision=self.precision,
optimize='optimal',
)
e1 = _to_matrix(e1, kernel_e1, l1)
o1 = _to_matrix(o1, kernel_o1, l1)
e2 = _to_matrix(e2, kernel_e2, l2)
o2 = _to_matrix(o2, kernel_o2, l2)
# Compute the different coupling paths (matrix multiplication).
def _couple(
x1: Float[
Array, '... 2*matrix_degree+1 2*matrix_degree+1 num_features'
],
x2: Float[
Array, '... 2*matrix_degree+1 2*matrix_degree+1 num_features'
],
) -> Float[Array, '... 2*matrix_degree+1 2*matrix_degree+1 num_features']:
"""Helper function for computing coupling paths."""
return jnp.einsum(
'...lmf,...mnf->...lnf',
x1,
x2,
precision=self.precision,
optimize='optimal',
)
eee = _couple(e1, e2)
ooe = _couple(o1, o2)
eoo = _couple(e1, o2)
oeo = _couple(o1, e2)
# Convert results back into "vector basis".
def _to_vector(
x: Float[Array, '... 2*matrix_degree+1 2*matrix_degree+1 num_features'],
kernel: Float[Array, '(max_degree+1)**2 num_features'],
) -> Float[Array, '... (max_degree+1)**2 num_features']:
"""Helper function for converting to vector basis."""
return jnp.einsum(
'...lmf,nf,lmn->...nf',
x,
kernel,
cg[..., :l3],
precision=self.precision,
optimize='optimal',
)
eee = _to_vector(eee, kernel_eee)
ooe = _to_vector(ooe, kernel_ooe)
eoo = _to_vector(eoo, kernel_eoo)
oeo = _to_vector(oeo, kernel_oeo)
# Combine same parities (even/odd).
e3 = eee + ooe
o3 = eoo + oeo
# Combine even and odd output features (usual feature shape conventions).
if include_pseudotensors:
return jnp.stack((e3, o3), axis=-3)
else:
return jnp.expand_dims(
e3 * even_mask[:l3, :] + o3 * odd_mask[:l3, :], axis=-3
)
def _create_tensor(
    use_fused_tensor: bool,
    tensor_kernel_init: Optional[
        Union[InitializerFn, FusedTensorInitializerFn]
    ] = None,
) -> Any:
  """Returns a partially applied constructor for the tensor product module.

  Args:
    use_fused_tensor: If ``True``, the constructor builds a
      :class:`FusedTensor` (named ``'fused_tensor'``), otherwise a
      :class:`Tensor` (named ``'tensor'``).
    tensor_kernel_init: Kernel initializer to bind. If ``None``, the default
      initializer matching the selected module type is bound instead.

  Returns:
    A ``functools.partial`` of the selected module class with ``name`` and
    ``kernel_init`` already bound; all remaining attributes are supplied by the
    caller.
  """
  # Select the module class together with its name and fallback initializer.
  if use_fused_tensor:
    module_cls, module_name = FusedTensor, 'fused_tensor'
    fallback_init = default_fused_tensor_kernel_init
  else:
    module_cls, module_name = Tensor, 'tensor'
    fallback_init = default_tensor_kernel_init

  # An explicitly given initializer always wins over the fallback.
  kernel_init = (
      fallback_init if tensor_kernel_init is None else tensor_kernel_init
  )
  return functools.partial(
      module_cls, name=module_name, kernel_init=kernel_init
  )
[docs]
class TensorDense(nn.Module):
  r"""Dense projection whose two halves are coupled by a tensor product.

  A single :class:`Dense` layer with twice the requested number of features is
  applied to the input and its output is split into two projections of equal
  size. The two projections are then coupled across the degree dimension with a
  tensor product, so that the overall transformation reads

  .. math::

    \mathbf{a} &= \mathrm{dense}_1(\mathbf{x}) \\
    \mathbf{b} &= \mathrm{dense}_2(\mathbf{x}) \\
    \mathbf{y} &= \mathrm{tensor}(\mathbf{a}, \mathbf{b}) \\

  where
  :math:`\mathbf{x} \in \mathbb{R}^{P_{\mathrm{in}}\times (L_{\mathrm{in}}+1)^2 \times F_{\mathrm{in}}}`
  is the input and
  :math:`\mathbf{y} \in \mathbb{R}^{P_{\mathrm{out}}\times (L_{\mathrm{out}}+1)^2 \times F_{\mathrm{out}}}`
  is the output. The :math:`\mathrm{tensor}` operation is a :class:`Tensor`
  layer for ``use_fused_tensor=False`` and a :class:`FusedTensor` layer for
  ``use_fused_tensor=True``. :math:`P_{\mathrm{out}}` equals :math:`1` for
  ``include_pseudotensors=False`` and :math:`2` for
  ``include_pseudotensors=True``.

  Attributes:
    features: Number of output features :math:`F_{\mathrm{out}}`. If ``None``,
      the number of input features :math:`F_{\mathrm{in}}` is kept.
    max_degree: Maximum degree :math:`L_{\mathrm{out}}` of the output. If
      ``None``, the max_degree :math:`L_{\mathrm{in}}` of the input is kept.
    use_bias: Whether the :class:`Dense` layer uses a bias.
    include_pseudotensors: If ``False``, coupling paths that produce
      pseudotensors are omitted.
    cartesian_order: If ``True``, Cartesian order is assumed.
    use_fused_tensor: If ``True``, the tensor product is computed with
      :class:`FusedTensor` instead of :class:`Tensor`.
    dtype: The dtype of the computation.
    param_dtype: The dtype passed to parameter initializers.
    precision: Numerical precision of the computation, see
      `jax.lax.Precision` for details.
    dense_kernel_init: Initializer for the weight matrix of the Dense layer.
    dense_bias_init: Initializer for the bias of the Dense layer.
    tensor_kernel_init: Initializer for the weight matrix of the tensor
      product layer.
  """

  features: Optional[int] = None
  max_degree: Optional[int] = None
  use_bias: bool = True
  include_pseudotensors: bool = True
  cartesian_order: bool = Config.cartesian_order
  use_fused_tensor: bool = Config.use_fused_tensor
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  dense_kernel_init: InitializerFn = default_kernel_init
  dense_bias_init: InitializerFn = jax.nn.initializers.zeros
  tensor_kernel_init: Optional[
      Union[InitializerFn, FusedTensorInitializerFn]
  ] = None

  @nn.compact
  def __call__(
      self,
      inputs: Union[
          Float[Array, '... 1 (in_max_degree+1)**2 in_features'],
          Float[Array, '... 2 (in_max_degree+1)**2 in_features'],
      ],
  ) -> Union[
      Float[Array, '... 1 (out_max_degree+1)**2 out_features'],
      Float[Array, '... 2 (out_max_degree+1)**2 out_features'],
  ]:
    """Couples two linear projections of the inputs with a tensor product.

    Args:
      inputs: The input features to be transformed.

    Returns:
      The tensor product of two different linear projections of ``inputs``.
    """
    # Numerical settings shared by the projection and the tensor product.
    numerics = dict(
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
    )

    # Fall back to the size of the input feature axis if features is unset.
    num_features = self.features
    if num_features is None:
      num_features = inputs.shape[-1]

    # One Dense layer of double width yields both projections at once.
    projected = Dense(
        features=2 * num_features,
        use_bias=self.use_bias,
        kernel_init=self.dense_kernel_init,
        bias_init=self.dense_bias_init,
        name='dense',
        **numerics,
    )(inputs)
    first, second = jnp.split(projected, indices_or_sections=2, axis=-1)

    # Couple both halves with the selected tensor product module.
    make_tensor = _create_tensor(
        use_fused_tensor=self.use_fused_tensor,
        tensor_kernel_init=self.tensor_kernel_init,
    )
    tensor_product = make_tensor(
        max_degree=self.max_degree,
        include_pseudotensors=self.include_pseudotensors,
        cartesian_order=self.cartesian_order,
        **numerics,
    )
    return tensor_product(first, second)
class _Conv(nn.Module):
  r"""Shared building block for "continuous convolution" style layers.

  This module is not intended to be used on its own; higher-level modules
  derive from it and reuse its ``__call__`` implementation.

  Attributes:
    max_degree: Maximum degree of the output. If ``None``, the max_degree is
      chosen as the maximum of the max_degree of inputs and basis.
    use_basis_bias: Whether a bias is added to the linear combination of basis
      functions.
    include_pseudotensors: If False, coupling paths that produce pseudotensors
      are omitted.
    cartesian_order: If True, Cartesian order is assumed.
    use_fused_tensor: If True, the tensor product is computed with
      :class:`FusedTensor` instead of :class:`Tensor`.
    dtype: The dtype of the computation.
    param_dtype: The dtype passed to parameter initializers.
    precision: Numerical precision of the computation, see
      `jax.lax.Precision` for details.
    dense_kernel_init: Initializer for the weight matrix of the Dense layer.
    dense_bias_init: Initializer for the bias of the Dense layer.
    tensor_kernel_init: Initializer for the weight matrix of the tensor
      product layer.
  """

  max_degree: Optional[int] = None
  use_basis_bias: bool = False
  include_pseudotensors: bool = True
  cartesian_order: bool = Config.cartesian_order
  use_fused_tensor: bool = Config.use_fused_tensor
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  dense_kernel_init: InitializerFn = default_kernel_init
  dense_bias_init: InitializerFn = jax.nn.initializers.zeros
  tensor_kernel_init: Optional[
      Union[InitializerFn, FusedTensorInitializerFn]
  ] = None

  @nn.compact
  def __call__(
      self,
      inputs: Union[
          Union[
              Float[Array, '... N M 1 (in_max_degree+1)**2 num_features'],
              Float[Array, '... N M 2 (in_max_degree+1)**2 num_features'],
          ],
          Union[
              Float[Array, '... P 1 (in_max_degree+1)**2 num_features'],
              Float[Array, '... P 2 (in_max_degree+1)**2 num_features'],
          ],
      ],
      basis: Optional[
          Union[
              Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'],
              Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis'],
          ]
      ] = None,
      *,
      adj_idx: Optional[Integer[Array, '... N M']] = None,
      where: Optional[Bool[Array, '... N M']] = None,
      dst_idx: Optional[Integer[Array, '... P']] = None,
      num_segments: Optional[int] = None,
      indices_are_sorted: bool = False,
  ) -> Union[
      Float[Array, '... N 1 (out_max_degree+1)**2 num_features'],
      Float[Array, '... N 2 (out_max_degree+1)**2 num_features'],
  ]:
    """Applies a basic "continuous convolution".

    Either a dense or a sparse index list is needed to decide which entries of
    the inputs are summed into which output entry.

    Args:
      inputs: "Expanded" inputs that are summed according to the index list.
      basis: Optional basis functions. If given, they are linearly combined to
        a learned "convolutional filter", which is tensor-multiplied with the
        inputs before the summation.
      adj_idx: Adjacency indices (dense index list), or None.
      where: Mask that specifies which values to sum over, required for dense
        index lists.
      dst_idx: Destination indices (sparse index list), or None.
      num_segments: Number of segments after summation, required for sparse
        index lists.
      indices_are_sorted: If True, dst_idx is assumed to be sorted, which may
        increase performance (only used for sparse index lists).

    Returns:
      The indexed sum over the inputs (if basis is None), or over the tensor
      product of the learned filters with the inputs (if basis is not None).

    Raises:
      ValueError: If basis is given and the shapes of inputs and basis differ
        in any but their last three dimensions.
      RuntimeError: If neither dense nor sparse index lists are provided, or if
        both are provided.
    """
    # Without a basis, the inputs themselves are summed.
    summands = inputs

    if basis is not None:
      # All dimensions in front of (parity, degree, feature) have to agree.
      if basis.shape[:-3] != inputs.shape[:-3]:
        raise ValueError('inputs and basis have incompatible shapes')

      numerics = dict(
          dtype=self.dtype,
          param_dtype=self.param_dtype,
          precision=self.precision,
      )

      # Learned filters are a linear combination of the basis functions.
      filters = Dense(
          features=inputs.shape[-1],
          use_bias=self.use_basis_bias,
          kernel_init=self.dense_kernel_init,
          bias_init=self.dense_bias_init,
          name='filter',
          **numerics,
      )(basis)

      # Tensor product between the filters and the inputs.
      make_tensor = _create_tensor(
          use_fused_tensor=self.use_fused_tensor,
          tensor_kernel_init=self.tensor_kernel_init,
      )
      tensor_product = make_tensor(
          max_degree=self.max_degree,
          include_pseudotensors=self.include_pseudotensors,
          cartesian_order=self.cartesian_order,
          **numerics,
      )
      summands = tensor_product(filters, inputs)

    return ops.indexed_sum(
        inputs=summands,
        adj_idx=adj_idx,
        where=where,
        dst_idx=dst_idx,
        num_segments=num_segments,
        indices_are_sorted=indices_are_sorted,
    )
[docs]
class MessagePass(_Conv):
  r"""Equivariant message-passing step.

  Attributes:
    max_degree: Maximum degree of the output. If ``None``, the max_degree is
      chosen as the maximum of the max_degree of inputs and basis.
    use_basis_bias: Whether a bias is added to the linear combination of basis
      functions.
    include_pseudotensors: If False, coupling paths that produce pseudotensors
      are omitted.
    cartesian_order: If True, Cartesian order is assumed.
    use_fused_tensor: If True, the tensor product is computed with
      :class:`FusedTensor` instead of :class:`Tensor`.
    dtype: The dtype of the computation.
    param_dtype: The dtype passed to parameter initializers.
    precision: Numerical precision of the computation, see
      `jax.lax.Precision` for details.
    dense_kernel_init: Initializer for the weight matrix of the Dense layer.
    dense_bias_init: Initializer for the bias of the Dense layer.
    tensor_kernel_init: Initializer for the weight matrix of the tensor
      product layer.
  """

  @nn.compact
  def __call__(  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
      self,
      inputs: Union[
          Float[Array, '... N 1 (in_max_degree+1)**2 num_features'],
          Float[Array, '... N 2 (in_max_degree+1)**2 num_features'],
      ],
      basis: Union[
          Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'],
          Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis'],
      ],
      weights: Optional[
          Float[Array, '_*broadcastable_to_gathered_inputs']
      ] = None,
      *,
      adj_idx: Optional[Integer[Array, '... N M']] = None,
      where: Optional[Bool[Array, '... N M']] = None,
      dst_idx: Optional[Integer[Array, '... P']] = None,
      src_idx: Optional[Integer[Array, '... P']] = None,
      num_segments: Optional[int] = None,
      indices_are_sorted: bool = False,
  ) -> Union[
      Float[Array, '... N 1 (out_max_degree+1)**2 num_features'],
      Float[Array, '... N 2 (out_max_degree+1)**2 num_features'],
  ]:
    r"""Applies a single message-passing step.

    The layer computes "messages" :math:`\mathbf{m}` as

    .. math::

      \mathbf{f} &= \mathrm{dense}(\mathbf{b})\\
      \mathbf{m}[i] &= \sum_{j \in \mathcal{N}[i]}
        \mathrm{tensor}(\mathbf{x}[j],\mathbf{f}[ij])

    where :math:`\mathbf{x}[1],\dots,\mathbf{x}[N]` are the :math:`N` input
    features, and :math:`\mathbf{b}` are
    `basis functions <../basis_functions.html>`_ for all relevant interactions
    between pairs :math:`i` and :math:`j` from the :math:`N` inputs. Which
    interactions are relevant for index :math:`i` is given by the set of
    "neighborhood indices" :math:`\mathcal{N}[i]`, specified by either a dense
    (``adj_idx``) or a sparse (``dst_idx`` and ``src_idx``)
    `index list <../neighbor_lists.html>`_. The :math:`\mathrm{tensor}`
    operation is a :class:`Tensor` layer for ``use_fused_tensor=False`` and a
    :class:`FusedTensor` layer for ``use_fused_tensor=True``. If ``weights`` is
    not ``None``, :math:`\mathbf{m}` is instead computed as

    .. math::

      \mathbf{m}[i] = \sum_{j \in \mathcal{N}[i]} w[ij] \cdot
        \mathrm{tensor}(\mathbf{x}[j],\mathbf{f}[ij])

    where :math:`w[ij]` is the entry of ``weights`` that corresponds to the
    interaction between pairs :math:`i` and :math:`j`.

    Args:
      inputs: A set of :math:`N` feature representations.
      basis: Basis functions for all relevant interactions between pairs
        :math:`i` and :math:`j` from the :math:`N` inputs (either in dense or
        sparse indexed format).
      weights: Optional weights for interactions between pairs :math:`i` and
        :math:`j`. They are multiplied onto the gathered inputs.
      adj_idx: Adjacency indices (dense index list), or `None`.
      where: Mask that specifies which values to sum over (only for dense index
        lists). If this is `None`, the `where` mask is auto-determined from
        `inputs`.
      dst_idx: Destination indices (sparse index list), or `None`.
      src_idx: Source indices (sparse index list), or `None`.
      num_segments: Number of segments after summation (only for sparse index
        lists). If this is `None`, `num_segments` is auto-determined from
        `inputs`.
      indices_are_sorted: If `True`, `dst_idx` is assumed to be sorted, which
        may increase performance (only used for sparse index lists).

    Returns:
      The result of the message passing step.

    Raises:
      ValueError: If weights are not `None` and cannot be broadcasted to the
        gathered inputs.
    """
    # Collect the features of the source entry of every interaction.
    neighbor_features = ops.gather_src(
        inputs=inputs, adj_idx=adj_idx, src_idx=src_idx
    )

    # Apply the optional per-interaction weights.
    if weights is not None:
      compatible = util.is_broadcastable(
          neighbor_features.shape, weights.shape
      )
      if not compatible:
        raise ValueError(
            f'weights with shape {weights.shape} cannot be broadcasted to '
            f'gathered_inputs with shape {neighbor_features.shape}'
        )
      neighbor_features *= weights

    # Fill in defaults for the indexed summation from the shape of inputs.
    if num_segments is None:
      num_segments = inputs.shape[-4]
    if adj_idx is not None and where is None:
      # Entries of adj_idx that are not smaller than N are masked out.
      where = adj_idx < inputs.shape[-4]

    return super().__call__(
        inputs=neighbor_features,
        basis=basis,
        adj_idx=adj_idx,
        where=where,
        dst_idx=dst_idx,
        num_segments=num_segments,
        indices_are_sorted=indices_are_sorted,
    )
[docs]
class MultiHeadAttention(_Conv):
  r"""Equivariant multi-head attention.

  Attributes:
    max_degree: Maximum degree of the output. If ``None``, the max_degree is
      chosen as the maximum of the max_degree of inputs and basis.
    use_basis_bias: Whether a bias is added to the linear combination of basis
      functions.
    include_pseudotensors: If False, coupling paths that produce pseudotensors
      are omitted.
    cartesian_order: If True, Cartesian order is assumed.
    use_fused_tensor: If True, the tensor product is computed with
      :class:`FusedTensor` instead of :class:`Tensor`.
    dtype: The dtype of the computation.
    param_dtype: The dtype passed to parameter initializers.
    precision: Numerical precision of the computation, see
      `jax.lax.Precision` for details.
    dense_kernel_init: Initializer for the weight matrix of the Dense layer.
    dense_bias_init: Initializer for the bias of the Dense layer.
    tensor_kernel_init: Initializer for the weight matrix of the tensor
      product layer.
    num_heads: Number of attention heads.
    qkv_features: Number of features used for queries, keys and values. If
      `None`, the number of features of `inputs_q` is used.
    out_features: Number of features of the output. If `None`, the number of
      features of `inputs_q` is used.
    use_relative_positional_encoding_qk: If `True`, relative positional
      encodings enter the dot product between queries and keys.
    use_relative_positional_encoding_v: If `True`, a relative positional
      encoding (with respect to the queries) is used for computing the values.
    query_kernel_init: Initializer for the weight matrix of the
      :class:`Dense` layer that computes queries.
    query_bias_init: Initializer for the bias terms of the :class:`Dense`
      layer that computes queries.
    query_use_bias: Whether the :class:`Dense` layer that computes queries
      uses bias terms.
    key_kernel_init: Initializer for the weight matrix of the :class:`Dense`
      layer that computes keys.
    key_bias_init: Initializer for the bias terms of the :class:`Dense` layer
      that computes keys.
    key_use_bias: Whether the :class:`Dense` layer that computes keys uses
      bias terms.
    value_kernel_init: Initializer for the weight matrix of the
      :class:`Dense` layer that computes values.
    value_bias_init: Initializer for the bias terms of the :class:`Dense`
      layer that computes values.
    value_use_bias: Whether the :class:`Dense` layer that computes values
      uses bias terms.
    output_kernel_init: Initializer for the weight matrix of the
      :class:`Dense` layer that computes outputs.
    output_bias_init: Initializer for the bias terms of the :class:`Dense`
      layer that computes outputs.
    output_use_bias: Whether the :class:`Dense` layer that computes outputs
      uses bias terms.
  """

  num_heads: Optional[int] = 1
  qkv_features: Optional[int] = None
  out_features: Optional[int] = None
  use_relative_positional_encoding_qk: bool = True
  use_relative_positional_encoding_v: bool = True
  query_kernel_init: InitializerFn = default_kernel_init
  query_bias_init: InitializerFn = jax.nn.initializers.zeros
  query_use_bias: bool = False
  key_kernel_init: InitializerFn = default_kernel_init
  key_bias_init: InitializerFn = jax.nn.initializers.zeros
  key_use_bias: bool = False
  value_kernel_init: InitializerFn = default_kernel_init
  value_bias_init: InitializerFn = jax.nn.initializers.zeros
  value_use_bias: bool = False
  output_kernel_init: InitializerFn = default_kernel_init
  output_bias_init: InitializerFn = jax.nn.initializers.zeros
  output_use_bias: bool = True

  @nn.compact
  def __call__(  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
      self,
      inputs_q: Union[
          Float[Array, '... N 1 (max_degree+1)**2 q_features'],
          Float[Array, '... N 2 (max_degree+1)**2 q_features'],
      ],
      inputs_kv: Union[
          Float[Array, '... M 1 (max_degree+1)**2 kv_features'],
          Float[Array, '... M 2 (max_degree+1)**2 kv_features'],
      ],
      basis: Optional[
          Union[
              Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'],
              Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis'],
          ]
      ] = None,
      cutoff_value: Optional[
          Union[
              Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'],
              Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis'],
          ]
      ] = None,
      *,
      adj_idx: Optional[Integer[Array, '... N M']] = None,
      where: Optional[Bool[Array, '... N M']] = None,
      dst_idx: Optional[Integer[Array, '... P']] = None,
      src_idx: Optional[Integer[Array, '... P']] = None,
      num_segments: Optional[int] = None,
      indices_are_sorted: bool = False,
  ) -> Union[
      Float[Array, '... N 1 (max_degree+1)**2 out_features'],
      Float[Array, '... N 2 (max_degree+1)**2 out_features'],
  ]:
    """Applies multi-head attention.

    Args:
      inputs_q: Input features from which the queries are computed.
      inputs_kv: Input features from which the keys and values are computed.
      basis: Basis functions for all relevant interactions between queries and
        keys (either in dense or sparse indexed format).
      cutoff_value: Multiplicative cutoff values that are applied to the "raw"
        softmax values (before normalization), can be used for smooth cutoffs.
      adj_idx: Adjacency indices (dense index list), or `None`.
      where: Mask that specifies which values to sum over (only for dense index
        lists). If this is `None`, the `where` mask is auto-determined from
        `inputs_kv`.
      dst_idx: Destination indices (sparse index list), or `None`.
      src_idx: Source indices (sparse index list), or `None`.
      num_segments: Number of segments after summation (only for sparse index
        lists). If this is `None`, `num_segments` is auto-determined from
        `inputs_q`.
      indices_are_sorted: If `True`, `dst_idx` is assumed to be sorted, which
        may increase performance (only used for sparse index lists).

    Returns:
      The result of the multi-head attention computation.

    Raises:
      ValueError: If `inputs_q` and `inputs_kv` have incompatible shapes, or if
        `qkv_features` is not divisible by `num_heads`.
      TypeError: When relative positional encodings are requested, but no input
        for `basis` is provided.
    """
    # Everything in front of the last four axes has to agree.
    if inputs_q.shape[:-4] != inputs_kv.shape[:-4]:
      raise ValueError('inputs_q and inputs_kv have incompatible shapes')

    # Relative positional encodings are computed from the basis.
    needs_basis = (
        self.use_relative_positional_encoding_qk
        or self.use_relative_positional_encoding_v
    )
    if needs_basis and basis is None:
      raise TypeError(
          "when using relative positional encodings, 'basis' is "
          'a required argument, received basis=None'
      )

    # Resolve feature sizes; unset values default to the query feature size.
    out_features = self.out_features
    if out_features is None:
      out_features = inputs_q.shape[-1]
    qkv_features = self.qkv_features
    if qkv_features is None:
      qkv_features = inputs_q.shape[-1]
    if qkv_features % self.num_heads != 0:
      raise ValueError(
          f'qkv_features ({qkv_features}) must be divisible by '
          f'num_heads ({self.num_heads})'
      )

    # The dot product between queries and keys is only well-defined if both
    # have the same parity/degree channels, so both inputs are brought to their
    # common max_degree and pseudotensors are kept only if both have them.
    max_degree_q = _extract_max_degree_and_check_shape(inputs_q.shape)
    max_degree_k = _extract_max_degree_and_check_shape(inputs_kv.shape)
    max_degree_qk = min(max_degree_q, max_degree_k)
    has_pseudotensors_qk = (
        inputs_q.shape[-3] == 2 and inputs_kv.shape[-3] == 2
    )
    common_format = dict(
        max_degree=max_degree_qk,
        include_pseudotensors=has_pseudotensors_qk,
    )
    query_inputs = change_max_degree_or_type(inputs_q, **common_format)
    key_inputs = change_max_degree_or_type(inputs_kv, **common_format)

    # Numerical settings shared by all projections.
    numerics = dict(
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
    )

    def split_heads(x):
      """Reshapes the last axis of x into (features, heads)."""
      return jnp.reshape(x, (*x.shape[:-1], -1, self.num_heads))

    # Query, key and value projections.
    query = Dense(
        features=qkv_features,
        use_bias=self.query_use_bias,
        kernel_init=self.query_kernel_init,
        bias_init=self.query_bias_init,
        name='query',
        **numerics,
    )(query_inputs)
    key = Dense(
        features=qkv_features,
        use_bias=self.key_use_bias,
        kernel_init=self.key_kernel_init,
        bias_init=self.key_bias_init,
        name='key',
        **numerics,
    )(key_inputs)
    value = Dense(
        features=qkv_features,
        use_bias=self.value_use_bias,
        kernel_init=self.value_kernel_init,
        bias_init=self.value_bias_init,
        name='value',
        **numerics,
    )(inputs_kv)

    # Split heads -> shape=(..., parity, degrees, features, heads).
    query = split_heads(query)
    key = split_heads(key)

    # Normalize the dot product by scaling the query with 1/sqrt(depth), where
    # depth = parity * degrees * features.
    depth = math.prod(query.shape[-4:-1])
    query /= jnp.sqrt(depth).astype(query.dtype)

    # Expand queries and keys according to the index lists.
    query = ops.gather_dst(query, adj_idx=adj_idx, dst_idx=dst_idx)
    key = ops.gather_src(key, adj_idx=adj_idx, src_idx=src_idx)

    # Assemble the operands of the (optionally position-weighted) dot product.
    if self.use_relative_positional_encoding_qk:
      # The encoding is computed from the p=0, l=0 component of the basis.
      num_parity_channels = 2 if has_pseudotensors_qk else 1
      encoding = nn.Dense(
          features=num_parity_channels * (max_degree_qk + 1) * qkv_features,
          use_bias=self.use_basis_bias,
          kernel_init=self.dense_kernel_init,
          bias_init=self.dense_bias_init,
          name='relative_positional_encoding',
          **numerics,
      )(basis[..., 0, 0, :])
      # Reshape to (..., num_parity_channels, max_degree+1, qkv_features).
      encoding = jnp.reshape(
          encoding,
          (
              *encoding.shape[:-1],
              num_parity_channels,
              max_degree_qk + 1,
              qkv_features,
          ),
      )
      # Duplicate the entry of each degree to obtain the shape
      # (..., num_parity_channels, (max_degree+1)**2, qkv_features).
      with jax.ensure_compile_time_eval():
        idx = _duplication_indices_for_max_degree(max_degree_qk)
      encoding = jnp.take(encoding, idx, axis=-2, indices_are_sorted=True)
      # Split heads -> shape=(..., parity, degrees, features, heads).
      encoding = split_heads(encoding)
      subscripts = '...plfh,...plfh,...plfh->...h'
      operands = (query, key, encoding)
    else:
      subscripts = '...plfh,...plfh->...h'
      operands = (query, key)
    dot = jnp.einsum(
        subscripts,
        *operands,
        precision=self.precision,
        optimize='optimal',
    )

    # Fill in defaults for the indexed operations from the input shapes.
    if num_segments is None:
      num_segments = inputs_q.shape[-4]
    if adj_idx is not None and where is None:
      where = adj_idx < inputs_kv.shape[-4]

    # Attention weights: indexed softmax, mapped over the head axis.
    softmax = functools.partial(
        ops.indexed_softmax,
        multiplicative_mask=cutoff_value,
        adj_idx=adj_idx,
        where=where,
        dst_idx=dst_idx,
        num_segments=num_segments,
        indices_are_sorted=indices_are_sorted,
    )
    weight = jax.vmap(softmax, in_axes=-1, out_axes=-1)(dot)
    # Every feature that belongs to a head receives the weight of that head.
    weight = jnp.repeat(weight, qkv_features // self.num_heads, axis=-1)
    # Add parity and degree axes so that weight broadcasts against value.
    weight = jnp.expand_dims(weight, (-2, -3))

    # Expand value by gathering, so that its shape matches that of weight.
    value = ops.gather_src(inputs=value, adj_idx=adj_idx, src_idx=src_idx)

    # Sum attention-weighted values (optionally coupled with the basis).
    value_basis = basis if self.use_relative_positional_encoding_v else None
    attention = super().__call__(
        inputs=weight * value,
        basis=value_basis,
        adj_idx=adj_idx,
        where=where,
        dst_idx=dst_idx,
        num_segments=num_segments,
        indices_are_sorted=indices_are_sorted,
    )

    # Linearly combine the individual attention heads.
    return Dense(
        features=out_features,
        use_bias=self.output_use_bias,
        kernel_init=self.output_kernel_init,
        bias_init=self.output_bias_init,
        name='out',
        **numerics,
    )(attention)
[docs]
class SelfAttention(MultiHeadAttention):
r"""Equivariant self-attention.
Attributes:
max_degree: Maximum degree of the output. If not given, the max_degree is
chosen as the maximum of the max_degree of inputs and basis.
use_basis_bias: Whether to add a bias to the linear combination of basis
functions.
include_pseudotensors: If False, all coupling paths that produce
pseudotensors are omitted.
cartesian_order: If True, Cartesian order is assumed.
use_fused_tensor: If True, :class:`FusedTensor` is used instead of
:class:`Tensor` for computing the tensor product.
dtype: The dtype of the computation.
param_dtype: The dtype passed to parameter initializers.
precision: Numerical precision of the computation, see `jax.lax.Precision`
for details.
dense_kernel_init: Initializer function for the weight matrix of the Dense
layer.
dense_bias_init: Initializer function for the bias of the Dense layer.
tensor_kernel_init: Initializer function for the weight matrix of the Tensor
layer.
num_heads: Number of attention heads.
qkv_features: Number of features used for queries, keys and values. If this
is `None`, the same number of features as in `inputs_q` is used.
out_features: Number of features for the output. If this is `None`, the same
number of features as in `inputs_q` is used.
use_relative_positional_encoding_qk: If this is `True`, relative positional
encodings are used for computing the dot product between queries and keys.
use_relative_positional_encoding_v: If this is `True`, a relative positional
encoding (with respect to the queries) is used for computing the values.
query_kernel_init: Initializer function for the weight matrix of the
:class:`Dense` layer for computing queries.
query_bias_init: Initializer function for the bias terms of the
:class:`Dense` layer for computing queries.
query_use_bias: Whether to use bias terms in the :class:`Dense` layer for
computing queries.
key_kernel_init: Initializer function for the weight matrix of the
:class:`Dense` layer for computing keys.
key_bias_init: Initializer function for the bias terms of the :class:`Dense`
layer for computing keys.
key_use_bias: Whether to use bias terms in the :class:`Dense` layer for
computing keys.
value_kernel_init: Initializer function for the weight matrix of the
:class:`Dense` layer for computing values.
value_bias_init: Initializer function for the bias terms of the
:class:`Dense` layer for computing values.
value_use_bias: Whether to use bias terms in the :class:`Dense` layer for
computing values.
output_kernel_init: Initializer function for the weight matrix of the
:class:`Dense` layer for computing outputs.
output_bias_init: Initializer function for the bias terms of the
:class:`Dense` layer for computing outputs.
output_use_bias: Whether to use bias terms in the :class:`Dense` layer for
computing outputs.
"""
[docs]
@nn.compact
def __call__(
    self,
    inputs: Union[
        Float[Array, '... N 1 (max_degree+1)**2 num_features'],
        Float[Array, '... N 2 (max_degree+1)**2 num_features'],
    ],
    basis: Optional[
        Union[
            Float[Array, '... N M 1 (basis_max_degree+1)**2 num_basis'],
            Float[Array, '... P 1 (basis_max_degree+1)**2 num_basis'],
        ]
    ] = None,
    cutoff_value: Optional[
        Union[
            Float[Array, '... N M 1 #(basis_max_degree+1)**2 #num_basis'],
            Float[Array, '... P 1 #(basis_max_degree+1)**2 #num_basis'],
        ]
    ] = None,
    *,
    adj_idx: Optional[Integer[Array, '... N M']] = None,
    where: Optional[Bool[Array, '... N M']] = None,
    dst_idx: Optional[Integer[Array, '... P']] = None,
    src_idx: Optional[Integer[Array, '... P']] = None,
    num_segments: Optional[int] = None,
    indices_are_sorted: bool = False,
) -> Union[
    Float[Array, '... N 1 (max_degree+1)**2 num_features'],
    Float[Array, '... N 2 (max_degree+1)**2 num_features'],
]:
  """Applies self-attention.

  Self-attention closely resembles message-passing: the difference is that
  every summand is scaled by an additional weight factor, and these weights
  sum up to 1. Ordinary message-passing, in contrast, implicitly assigns a
  weight of 1 to every summand.

  This method forwards `inputs` as both `inputs_q` and `inputs_kv` to the
  `__call__` of the parent class; all other arguments are passed through
  unchanged.

  Args:
    inputs: A set of :math:`N` input features.
    basis: Basis functions for all relevant interactions between pairs
      :math:`i` and :math:`j` from the :math:`N` inputs (either in dense or
      sparse indexed format).
    cutoff_value: Multiplicative cutoff values that are applied to the "raw"
      softmax values (before normalization), can be used for smooth cutoffs.
    adj_idx: Adjacency indices (dense index list), or `None`.
    where: Mask to specify which values to sum over (only for dense index
      lists). If this is `None`, the `where` mask is auto-determined from
      `inputs`.
    dst_idx: Destination indices (sparse index list), or `None`.
    src_idx: Source indices (sparse index list), or `None`.
    num_segments: Number of segments after summation (only for sparse index
      lists). If this is `None`, `num_segments` is auto-determined from
      `inputs`.
    indices_are_sorted: If `True`, `dst_idx` is assumed to be sorted, which
      may increase performance (only used for sparse index lists).

  Returns:
    The output of self-attention.
  """
  # Everything except the features themselves is handed to the parent
  # implementation as-is.
  passthrough_kwargs = {
      'basis': basis,
      'cutoff_value': cutoff_value,
      'adj_idx': adj_idx,
      'where': where,
      'dst_idx': dst_idx,
      'src_idx': src_idx,
      'num_segments': num_segments,
      'indices_are_sorted': indices_are_sorted,
  }
  # Self-attention: the same features act as queries and as keys/values.
  return super().__call__(
      inputs_q=inputs, inputs_kv=inputs, **passthrough_kwargs
  )