e3x.nn.wrappers.basis

e3x.nn.wrappers.basis(r, *, max_degree, num, radial_fn, angular_fn=functools.partial(<function spherical_harmonics>, r_is_normalized=True, normalization='racah'), cutoff_fn=None, return_cutoff=False, return_norm=False, damping_fn=None, cartesian_order=True)

Convenience wrapper for computing radial-angular basis functions.

This function can be used to compute radial-angular basis functions of the form

\[\mathrm{B}_{n\ell m}(\vec{r}) = R_{n\ell}\left(\lVert\vec{r}\rVert\right) A_\ell^m\left(\frac{\vec{r}}{\lVert\vec{r}\rVert}\right)\]

for input vectors \(\vec{r}=[x\ y\ z]^\intercal \in \mathbb{R}^3\). Here, \(R_{n\ell}\) is the radial component and \(A_\ell^m\) are the angular components (given by angular_fn). In the simplest case, the radial component is independent of \(\ell\) and given by

\[R_{n\ell}(r) = g_n(r)\,,\]

where \(g_n(r)\) is the \(n\)-th output of radial_fn. However, angular functions such as the spherical harmonics \(Y_\ell^m\) with \(\ell > 0\) are undefined for \(\vec{r}=[0\ 0\ 0]^\intercal\), so if zero vectors can occur as inputs, it can be desirable to damp the radial component for small input vectors. When damping_fn is not None, the radial component is instead given by

\[\begin{split}R_{n\ell}(r) = \begin{cases} g_n(r) & \ell = 0 \\ g_n(r) \cdot d(r) & \ell > 0 \\ \end{cases}\,,\end{split}\]

where \(d(r)\) is the output of damping_fn. Similarly, it is often useful to combine the radial function with a cutoff function, so that the radial component vanishes beyond a certain cutoff radius. When cutoff_fn is not None (and damping_fn is None), the radial component is given by

\[R_{n\ell}(r) = g_n(r) \cdot c(r)\,,\]

where \(c(r)\) is the output of cutoff_fn. It is also possible to combine damping_fn and cutoff_fn, in which case the radial component is given by

\[\begin{split}R_{n\ell}(r) = \begin{cases} g_n(r) \cdot c(r) & \ell = 0 \\ g_n(r) \cdot c(r) \cdot d(r) & \ell > 0\,. \\ \end{cases}\end{split}\]
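The four cases above differ only in the per-degree prefactor that multiplies \(g_n(r)\,A_\ell^m\). The following NumPy sketch assembles \(\mathrm{B}_{n\ell m}\) for a fixed max_degree = 1 to make this structure and the output shape concrete. It is not e3x's implementation: gaussian_radial, angular_max_degree_1, basis_sketch, cosine_cutoff and damping are hypothetical stand-ins (not e3x.nn.sinc, e3x.nn.smooth_cutoff or e3x.nn.smooth_damping). The degree-1 angular entries are taken to be the unit-vector components, which match Racah-normalized real spherical harmonics only up to ordering and sign conventions. The sketch also assumes that the \(2\ell+1\) entries of degree \(\ell\) occupy indices \(\ell^2\) to \((\ell+1)^2-1\) of the angular axis, and it maps zero vectors to a zero direction to avoid division by zero, which is a choice of the sketch rather than documented behavior.

```python
import numpy as np


def gaussian_radial(x, num):
    """Hypothetical g_n: num Gaussians with centers evenly spaced on [0, 1]."""
    centers = np.linspace(0.0, 1.0, num)
    return np.exp(-(((x[..., None] - centers) * num) ** 2))  # shape (..., num)


def angular_max_degree_1(u):
    """Hypothetical A_l^m for max_degree = 1: [1, u_x, u_y, u_z] for unit vectors u."""
    return np.concatenate([np.ones(u.shape[:-1] + (1,)), u], axis=-1)  # (..., 4)


def basis_sketch(r, num, cutoff_fn=None, damping_fn=None):
    norm = np.linalg.norm(r, axis=-1)  # |r|, shape (...)
    safe = np.where(norm > 0, norm, 1.0)  # sketch choice: avoid 0/0 for zero vectors
    u = r / safe[..., None]  # unit vectors (a zero vector stays zero)
    g = gaussian_radial(norm, num)  # g_n(|r|), shape (..., num)
    a = angular_max_degree_1(u)  # A_l^m, shape (..., 4)
    ell = np.floor(np.sqrt(np.arange(a.shape[-1]))).astype(int)  # degree of each index
    factor = np.ones(a.shape)  # per-(l, m) prefactor multiplying g_n
    if cutoff_fn is not None:
        factor = factor * cutoff_fn(norm)[..., None]  # c(r) for every degree
    if damping_fn is not None:
        damped = factor * damping_fn(norm)[..., None]
        factor = np.where(ell > 0, damped, factor)  # d(r) only for l > 0
    b = (factor * a)[..., :, None] * g[..., None, :]  # B_{nlm}, shape (..., 4, num)
    return b[..., None, :, :]  # add parity axis: (..., 1, 4, num)


def cosine_cutoff(x):
    """Hypothetical c(r): 0.5 * (cos(pi * r) + 1) for r < 1, else 0."""
    return np.where(x < 1.0, 0.5 * (np.cos(np.pi * x) + 1.0), 0.0)


def damping(x):
    """Hypothetical d(r) with d(0) = 0."""
    return 1.0 - np.exp(-(x**2) / 0.01)


r = np.array([[0.25, 0.6, -0.05], [-0.2, 0.25, 0.6]])
b = basis_sketch(r, num=8, cutoff_fn=cosine_cutoff, damping_fn=damping)
print(b.shape)  # (2, 1, 4, 8)
```

Passing None for both callables reduces the sketch to the simplest case \(R_{n\ell}(r) = g_n(r)\). Because the cutoff and damping factors are plain multiplications, the order in which they are applied does not matter.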

Example

>>> import jax.numpy as jnp
>>> import e3x
>>> r = jnp.asarray([[0.5, 1.2, -0.1], [-0.4, 0.5, 1.2]])
>>> basis = e3x.nn.basis(
...   r=r,
...   max_degree=1,
...   num=8,
...   radial_fn=e3x.nn.sinc,
...   cutoff_fn=e3x.nn.smooth_cutoff,
...   damping_fn=e3x.nn.smooth_damping,
... )
>>> basis.shape
(2, 1, 4, 8)

Parameters:
  • r (Float[Array, '... 3']) – Input array of shape (..., 3) containing Cartesian vectors.

  • max_degree (int) – Maximum degree \(\ell\) of the angular basis functions (spherical harmonics by default).

  • num (int) – Number of radial basis functions.

  • radial_fn (Callable[[Float[Array, '...'], int], Float[Array, '... num']]) – Callable for computing radial basis functions. This function should take an array of shape (...) (the norm of r) and an integer (the number of radial basis functions num) as input and return an array of shape (..., num) (the values of num radial basis functions \(g_n\)).

  • angular_fn (AngularFn, default: functools.partial(<function spherical_harmonics>, r_is_normalized=True, normalization='racah')) – Callable for computing angular basis functions. This function should take an array of shape (..., 3) (Cartesian vectors r), an integer (the maximum degree max_degree), and a boolean (which ordering convention to use, see cartesian_order) as input and return an array of shape (..., (max_degree+1)**2) (the values of the angular basis functions \(A_\ell^m\)). By default, spherical harmonics with Racah’s normalization are used.

  • cutoff_fn (Optional[Callable[[Float[Array, '...']], Float[Array, '...']]], default: None) – Optional Callable for computing cutoff values. This function should take an array of shape (...) (the norm of r) as input and return an array of shape (...) (the values of the cutoff function).

  • return_cutoff (bool, default: False) – If True, also return the values of the cutoff function.

  • return_norm (bool, default: False) – If True, also return the norm of the input vectors r.

  • damping_fn (Optional[Callable[[Float[Array, '...']], Float[Array, '...']]], default: None) – Optional Callable for computing damping values. This function should take an array of shape (...) (the norm of r) as input and return an array of shape (...) (the values of the damping function). If present, basis functions of degree \(\ell > 0\) are multiplied by the values of the damping function.

  • cartesian_order (bool, default: True) – If True, the angular basis functions (spherical harmonics by default) are in Cartesian order. This value is passed on to angular_fn.
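The callable parameters above are specified only through their argument and return shapes. The following sketch shows hypothetical user-defined functions that follow those contracts; my_radial_fn, my_cutoff_fn, my_damping_fn and my_angular_fn are illustrative names, not part of e3x. NumPy is used so the block runs on its own; for use with e3x one would presumably write them with jax.numpy, which provides the same functions. The angular function only covers max_degree ≤ 1 and ignores cartesian_order.

```python
import numpy as np

# Hypothetical callables following the documented contracts (not part of e3x).
# NumPy keeps this block self-contained; jax.numpy provides the same functions.


def my_radial_fn(r, num):
    """Norms of shape (...) and an int -> (..., num) radial basis values."""
    centers = np.linspace(0.0, 1.0, num)  # Gaussian centers evenly spaced on [0, 1]
    return np.exp(-(((r[..., None] - centers) * num) ** 2))


def my_cutoff_fn(r):
    """Norms of shape (...) -> (...): cosine cutoff, 0 at and beyond r = 1."""
    return np.where(r < 1.0, 0.5 * (np.cos(np.pi * r) + 1.0), 0.0)


def my_damping_fn(r):
    """Norms of shape (...) -> (...): 0 at r = 0, approaching 1 for larger r."""
    return 1.0 - np.exp(-(r**2) / 0.01)


def my_angular_fn(r, max_degree, cartesian_order):
    """Vectors (..., 3), int, bool -> (..., (max_degree+1)**2); max_degree <= 1 only."""
    if max_degree > 1:
        raise NotImplementedError("this sketch only supports max_degree <= 1")
    ones = np.ones(r.shape[:-1] + (1,))  # degree 0: constant 1
    if max_degree == 0:
        return ones
    u = r / np.linalg.norm(r, axis=-1, keepdims=True)  # degree 1: unit-vector components
    return np.concatenate([ones, u], axis=-1)  # cartesian_order is ignored here


norms = np.array([0.0, 0.3, 1.2])
print(my_radial_fn(norms, 5).shape)  # (3, 5)
```

Whether e3x passes arguments to these callables positionally or by keyword is not stated above; the argument names were chosen to match the parameter descriptions, which is an assumption.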

Return type:

Union[Float[Array, '... 1 (max_degree+1)**2 num'], Tuple[Float[Array, '... 1 (max_degree+1)**2 num'], Float[Array, '...']], Tuple[Float[Array, '... 1 (max_degree+1)**2 num'], Float[Array, '...'], Float[Array, '...']]]

Returns:

Values of all basis functions for all vectors in r. If the input has shape (..., 3), the output has shape (..., 1, (max_degree+1)**2, num) (the 1 in the shape is a parity axis added for compatibility with other methods). If return_cutoff=True and/or return_norm=True, a tuple is returned instead: the basis function values followed by the cutoff values and/or the vector norms, each with shape (...).