e3x.nn.wrappers.ExponentialBasis

class e3x.nn.wrappers.ExponentialBasis(initial_gamma=1.0, param_dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: flax.linen.Module

Exponential basis module.

This module wraps the basis function (e3x.nn.basis) and injects a learnable gamma parameter into the provided radial_fn. It only works with radial functions that accept a gamma keyword argument, which is the case for all exponentially mapped radial functions.
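The injection mechanism described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the e3x implementation. exponential_gaussian, basis and ExponentialBasisSketch are hypothetical stand-ins. In particular, the real module is a Flax module that stores gamma as a trainable parameter. Here, gamma is an ordinary attribute. The key idea is that the current gamma is bound to radial_fn as a keyword argument before the call is forwarded to basis. This is why radial_fn must accept gamma. The sketch does not cover whether e3x constrains gamma (for example, to keep it positive).

```python
import functools
import math


# Hypothetical exponentially mapped radial function (NOT e3x's implementation):
# maps r to x = exp(-gamma * r) in (0, 1] and evaluates `num` Gaussians on x.
# Requires num >= 2 so that the Gaussian centers span [0, 1].
def exponential_gaussian(r, num, gamma=1.0):
    x = math.exp(-gamma * r)
    centers = [k / (num - 1) for k in range(num)]
    width = 1.0 / (num - 1)
    return [math.exp(-(((x - c) / width) ** 2)) for c in centers]


# Hypothetical, heavily simplified stand-in for a basis function:
# it only evaluates radial_fn(r, num) and returns the result.
def basis(r, num, radial_fn):
    return radial_fn(r, num)


class ExponentialBasisSketch:
    """Holds a trainable gamma and injects it into radial_fn on every call."""

    def __init__(self, initial_gamma=1.0):
        # In a Flax module this would be created with self.param(...);
        # here it is a plain attribute that an optimizer could update.
        self.gamma = initial_gamma

    def __call__(self, *args, radial_fn, **kwargs):
        # Bind the current gamma as a keyword argument; radial_fn must accept it.
        bound_radial_fn = functools.partial(radial_fn, gamma=self.gamma)
        # Forward all remaining arguments unchanged to the wrapped basis function.
        return basis(*args, radial_fn=bound_radial_fn, **kwargs)


module = ExponentialBasisSketch(initial_gamma=0.5)
values = module(2.0, num=4, radial_fn=exponential_gaussian)
```

Updating module.gamma changes the output on the next call. No change to the caller is needed, because the binding happens inside __call__. This keeps the wrapped basis function itself unaware of the learnable parameter.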

__call__(*args, **kwargs)