# Source code for e3x.nn.functions.damping

# Copyright 2024 The e3x Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Damping functions.

Functions that go to :math:`0` for small :math:`x` and are close to :math:`1`
for large :math:`x` (assuming :math:`x>0`).
"""

from typing import Union
import jax.numpy as jnp
import jaxtyping

Array = jaxtyping.Array
Float = jaxtyping.Float


def smooth_damping(
    x: Float[Array, '...'], gamma: Union[Float[Array, ''], float] = 1.0
) -> Float[Array, '...']:
  r"""Smooth damping function.

  Evaluates

  .. math::
    \mathrm{smooth\_damping}(x) =
    \exp\left(-\frac{1}{\exp(\gamma x) - 1} \right)

  with :math:`\gamma` = ``gamma``. The function and all of its derivatives
  are :math:`0` at :math:`x=0`, and the function quickly approaches :math:`1`
  for :math:`x > 0`.

  .. jupyter-execute::
    :hide-code:

    import numpy as np, matplotlib.pyplot as plt
    import matplotlib_inline.backend_inline as inl
    from e3x.nn import smooth_damping
    inl.set_matplotlib_formats('pdf', 'svg')
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    x = np.linspace(0, 2, num=1001)
    y1 = smooth_damping(x, gamma=1.0)
    y2 = smooth_damping(x, gamma=5.0)
    plt.xlabel(r'$x$'); plt.ylabel(r'$\mathrm{smooth\_damping}(x)$')
    plt.plot(x, y1, lw=3, ls='-', label='gamma = 1.0')
    plt.plot(x, y2, lw=3, ls='--', label='gamma = 5.0')
    plt.legend(); plt.grid()

  Args:
    x: Input array.
    gamma: Exponential decay constant.

  Returns:
    The function value.
  """
  # Work with exp(-gamma * x): the exponent is rewritten as
  # -exp(-gamma x) / (1 - exp(-gamma x)), which is algebraically equal to
  # -1 / (exp(gamma x) - 1).
  decay = jnp.exp(-gamma * x)
  denominator = 1 - decay

  # An entry is "usable" only if the squared denominator is strictly positive.
  # This rejects an exactly-zero denominator (x == 0) as well as denominators
  # so small that their square underflows to zero.
  # NOTE(review): the squared test is presumably there to keep derivatives,
  # which involve 1 / denominator**2, finite -- confirm with the authors.
  usable = denominator * denominator > 0

  # Substitute a harmless 1 where the denominator is unusable so the division
  # below never sees a zero, even in entries that are discarded afterwards.
  guarded_denominator = jnp.where(usable, denominator, 1)
  damped = jnp.exp(-decay / guarded_denominator)

  # Unusable entries are set to the limiting value 0.
  return jnp.where(usable, damped, 0)