e3x.nn.wrappers.basis
- e3x.nn.wrappers.basis(r, *, max_degree, num, radial_fn, angular_fn=functools.partial(<function spherical_harmonics>, r_is_normalized=True, normalization='racah'), cutoff_fn=None, return_cutoff=False, return_norm=False, damping_fn=None, cartesian_order=True)
Convenience wrapper for computing radial-angular basis functions.
This function can be used to compute radial-angular basis functions of the form
\[\mathrm{B}_{n\ell m}(\vec{r}) = R_{n\ell}\left(\lVert\vec{r}\rVert\right) A_\ell^m\left(\frac{\vec{r}}{\lVert\vec{r}\rVert}\right)\]
for input vectors \(\vec{r}=[x\ y\ z]^\intercal \in \mathbb{R}^3\). Here, \(R_{n\ell}\) is the radial component and \(A_\ell^m\) are the angular components (given by angular_fn). In the simplest case, the radial component is independent of \(\ell\) and given by
\[R_{n\ell}(r) = g_n(r)\,,\]
where \(g_n(r)\) is one of the outputs of radial_fn. However, angular functions such as the spherical harmonics \(Y_\ell^m\) with \(\ell > 0\) are undefined for \(\vec{r}=[0\ 0\ 0]^\intercal\). If zero vectors can occur as inputs, it can therefore be desirable to damp the radial component for small input vectors. When damping_fn is not None, the radial component is instead given by
\[R_{n\ell}(r) = \begin{cases} g_n(r) & \ell = 0 \\ g_n(r) \cdot d(r) & \ell > 0 \end{cases}\,,\]
where \(d(r)\) is the output of damping_fn. Similarly, it is often useful to combine the radial function with a cutoff function, such that the radial component is zero beyond a certain cutoff radius. When cutoff_fn is not None, the radial component is given by
\[R_{n\ell}(r) = g_n(r) \cdot c(r)\,,\]
where \(c(r)\) is the output of cutoff_fn. It is also possible to combine damping_fn and cutoff_fn, in which case the radial component is given by
\[R_{n\ell}(r) = \begin{cases} g_n(r) \cdot c(r) & \ell = 0 \\ g_n(r) \cdot c(r) \cdot d(r) & \ell > 0 \end{cases}\,.\]
Example
>>> import jax.numpy as jnp
>>> import e3x
>>> r = jnp.asarray([[0.5, 1.2, -0.1], [-0.4, 0.5, 1.2]])
>>> basis = e3x.nn.basis(
...     r=r,
...     max_degree=1,
...     num=8,
...     radial_fn=e3x.nn.sinc,
...     cutoff_fn=e3x.nn.smooth_cutoff,
...     damping_fn=e3x.nn.smooth_damping,
... )
>>> basis.shape
(2, 1, 4, 8)
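To make the case distinctions for \(R_{n\ell}\) concrete, here is a minimal pure-JAX sketch of \(\mathrm{B}_{n\ell m}(\vec{r}) = R_{n\ell}(\lVert\vec{r}\rVert)\, A_\ell^m(\vec{r}/\lVert\vec{r}\rVert)\). It is not the e3x implementation. The helper radial_angular_basis_sketch and the stand-ins gaussian_radial, toy_angular, cosine_cutoff and toy_damping are hypothetical. toy_angular only covers degrees 0 and 1 (returning [1, x, y, z]) and, unlike the documented angular_fn, takes no ordering flag. The sketch also assumes that the (max_degree+1)**2 angular components are grouped by degree, with the 2ℓ+1 components of degree ℓ stored contiguously.

```python
import jax.numpy as jnp


# Hypothetical stand-ins for radial_fn, angular_fn, cutoff_fn and damping_fn.
def gaussian_radial(x, num):
  # g_n(x): num Gaussians with centers evenly spaced on [0, 2].
  centers = jnp.linspace(0.0, 2.0, num)
  return jnp.exp(-((x[..., None] - centers) ** 2) / 0.1)


def toy_angular(u, max_degree):
  # [1, x, y, z]; only meant to make the sketch runnable for max_degree=1.
  if max_degree != 1:
    raise NotImplementedError("toy_angular only supports max_degree=1")
  ones = jnp.ones((*u.shape[:-1], 1), dtype=u.dtype)
  return jnp.concatenate([ones, u], axis=-1)


def cosine_cutoff(x, cutoff=2.0):
  # c(x): decays smoothly from 1 at x = 0 to 0 at x = cutoff, zero beyond.
  return jnp.where(x < cutoff, 0.5 * (jnp.cos(jnp.pi * x / cutoff) + 1.0), 0.0)


def toy_damping(x):
  # d(x): zero at x = 0, approaching 1 for large x.
  return 1.0 - jnp.exp(-(x**2))


def radial_angular_basis_sketch(r, max_degree, num, radial_fn, angular_fn,
                                cutoff_fn=None, damping_fn=None):
  """B_{nlm}(r) = R_{nl}(|r|) A_l^m(r/|r|); hypothetical, not e3x's code."""
  norm = jnp.linalg.norm(r, axis=-1)  # |r|, shape (...)
  # Zero vectors have no direction; dividing by 1 maps them to zero vectors.
  u = r / jnp.where(norm > 0, norm, 1.0)[..., None]

  g = radial_fn(norm, num)  # g_n(r), shape (..., num)
  if cutoff_fn is not None:
    g = g * cutoff_fn(norm)[..., None]  # g_n(r) * c(r)

  # Degree l of each angular component, assuming degree-major ordering.
  degrees = jnp.concatenate(
      [jnp.full((2 * l + 1,), l) for l in range(max_degree + 1)])
  # R_{nl}, shape (..., (max_degree+1)**2, num).
  radial = jnp.broadcast_to(g[..., None, :],
                            (*g.shape[:-1], degrees.size, num))
  if damping_fn is not None:
    d = damping_fn(norm)[..., None, None]  # d(r)
    # Damp only the entries with degree l > 0.
    radial = jnp.where(degrees[:, None] > 0, radial * d, radial)

  a = angular_fn(u, max_degree)  # A_l^m, shape (..., (max_degree+1)**2)
  basis = a[..., :, None] * radial
  return basis[..., None, :, :]  # parity axis: (..., 1, (L+1)**2, num)
```

With the two vectors from the example above, radial_angular_basis_sketch(r, 1, 8, gaussian_radial, toy_angular, cosine_cutoff, toy_damping) returns an array of shape (2, 1, 4, 8). Because d(r) multiplies only the ℓ > 0 entries, a damping function that vanishes at r = 0 turns those entries into zeros for a zero input vector, while the ℓ = 0 entries keep the value g_n(0) · c(0).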
- Parameters:
  - r (Float[Array, '... 3']) – Input array of shape (..., 3) containing Cartesian vectors.
  - max_degree (int) – Maximum degree of the spherical harmonics.
  - num (int) – Number of radial basis functions.
  - radial_fn (Callable[[Float[Array, '...'], int], Float[Array, '... num']]) – Callable for computing radial basis functions. This function should take an array of shape (...) (the norm of r) and an integer (the number of radial basis functions num) as input and return an array of shape (..., num) (the values of the num radial basis functions \(g_n\)).
  - angular_fn (AngularFn, default: functools.partial(<function spherical_harmonics>, r_is_normalized=True, normalization='racah')) – Callable for computing angular basis functions. This function should take an array of shape (..., 3) (Cartesian vectors r), an integer (the maximum degree max_degree), and a boolean (which ordering convention to use, see cartesian_order) as input and return an array of shape (..., (max_degree+1)**2) (the values of the angular basis functions \(A_\ell^m\)). By default, spherical harmonics with Racah's normalization are used.
  - cutoff_fn (Optional[Callable[[Float[Array, '...']], Float[Array, '...']]], default: None) – Optional callable for computing cutoff values. This function should take an array of shape (...) (the norm of r) as input and return an array of shape (...) (the values of the cutoff function).
  - return_cutoff (bool, default: False) – If True, also return the values of the cutoff function.
  - return_norm (bool, default: False) – If True, also return the norm of the input vectors r.
  - damping_fn (Optional[Callable[[Float[Array, '...']], Float[Array, '...']]], default: None) – Optional callable for computing damping values. This function should take an array of shape (...) (the norm of r) as input and return an array of shape (...) (the values of the damping function). If present, basis functions with spherical harmonics degree \(\ell > 0\) are multiplied by the values of the damping function.
  - cartesian_order (bool, default: True) – If True, spherical harmonics are in Cartesian order.
- Return type:
- Returns:
  Values of all basis functions for all vectors in r. If the input has shape (..., 3), the output has shape (..., 1, (max_degree+1)**2, num) (the 1 in the shape is a parity axis added for compatibility with other methods). If return_cutoff=True or return_norm=True, the values of the cutoff function and/or the vector norms, each with shape (...), are also returned (a tuple of basis function values and cutoff values and/or norms is returned).
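The parameter list specifies shape contracts for the user-supplied callables radial_fn, cutoff_fn, damping_fn and angular_fn. The following is a minimal sketch of hypothetical checker helpers that test candidate callables against those contracts, together with a few hypothetical candidates (sinc_radial, polynomial_cutoff, toy_damping, toy_angular) that are not part of e3x. The angular check assumes the ordering flag is passed by keyword as cartesian_order; the description above only states that a boolean is passed.

```python
import math

import jax.numpy as jnp


def _sample_norms(batch_shape):
  # Positive sample norms arranged in an arbitrary batch shape (...).
  return jnp.linspace(0.1, 2.0, math.prod(batch_shape)).reshape(batch_shape)


def check_radial_fn(radial_fn, num, batch_shape=(5, 2)):
  """radial_fn: shape (...) and int num -> shape (..., num)."""
  out = radial_fn(_sample_norms(batch_shape), num)
  assert out.shape == (*batch_shape, num), out.shape


def check_scalar_fn(fn, batch_shape=(5, 2)):
  """cutoff_fn / damping_fn: shape (...) -> shape (...)."""
  out = fn(_sample_norms(batch_shape))
  assert out.shape == tuple(batch_shape), out.shape


def check_angular_fn(angular_fn, max_degree, batch_shape=(4,)):
  """angular_fn: shape (..., 3), int, bool -> shape (..., (max_degree+1)**2)."""
  u = jnp.ones((*batch_shape, 3)) / jnp.sqrt(3.0)  # unit vectors
  # Assumption: the ordering flag is passed by keyword.
  out = angular_fn(u, max_degree, cartesian_order=True)
  assert out.shape == (*batch_shape, (max_degree + 1) ** 2), out.shape


# Hypothetical candidate callables (not part of e3x).
def sinc_radial(x, num, cutoff=2.0):
  n = jnp.arange(1, num + 1)
  # jnp.sinc(t) = sin(pi t) / (pi t), which is finite (equal to 1) at t = 0.
  return jnp.sinc(n * x[..., None] / cutoff)


def polynomial_cutoff(x, cutoff=2.0):
  # (1 - (x/cutoff)**2)**2 for x < cutoff, exactly zero beyond the cutoff.
  y = jnp.clip(x / cutoff, 0.0, 1.0)
  return (1.0 - y**2) ** 2


def toy_damping(x):
  # Vanishes at x = 0 and approaches 1 for large x.
  return 1.0 - jnp.exp(-(x**2))


def toy_angular(u, max_degree, cartesian_order=True):
  # Degrees 0 and 1 only: returns [1] or [1, x, y, z]; ordering flag ignored.
  del cartesian_order
  if max_degree > 1:
    raise NotImplementedError("toy_angular only covers max_degree <= 1")
  ones = jnp.ones((*u.shape[:-1], 1), dtype=u.dtype)
  return ones if max_degree == 0 else jnp.concatenate([ones, u], axis=-1)


check_radial_fn(sinc_radial, num=8)
check_scalar_fn(polynomial_cutoff)
check_scalar_fn(toy_damping)
check_angular_fn(toy_angular, max_degree=1)
```

These checks only test shapes. They say nothing about smoothness, behavior near r = 0, or whether a cutoff function actually reaches zero at the intended radius; those properties have to be checked separately.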