# Source code for e3x.nn.functions.mappings

# Copyright 2024 The e3x Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Functions for mapping the interval :math:`[0,\infty)` to :math:`(0,1]`.

Each mapping takes the value :math:`1` at :math:`x = 0` and approaches
:math:`0` as :math:`x \to \infty` (for ``exponential_mapping`` this requires a
positive decay constant ``gamma``).
"""

from typing import Literal, Union, get_args

import jax.numpy as jnp
import jaxtyping

# Short module-level aliases for the jaxtyping helpers used to annotate the
# function signatures below.
Array, Float = jaxtyping.Array, jaxtyping.Float

# The kinds of reciprocal mapping accepted by `reciprocal_mapping`.
# The `Literal` alias is spelled out with literal values because static type
# checkers reject `Literal[<variable>]`; the runtime tuple used for input
# validation is derived from the alias so that the two can never drift apart.
ReciprocalMapping = Literal['shifted', 'damped', 'cuspless']
_valid_reciprocal_mappings = get_args(ReciprocalMapping)


def reciprocal_mapping(
    x: Float[Array, '...'],
    kind: ReciprocalMapping = 'shifted',
) -> Float[Array, '...']:
  r"""Reciprocal mapping function.

  Computes the function (when ``kind = 'shifted'``)

  .. math::
    \mathrm{reciprocal\_mapping}(x) = \frac{1}{x+1}

  which is :math:`1` for :math:`x = 0` and :math:`0` for :math:`x \to \infty`,
  or (when ``kind = 'damped'``)

  .. math::
    \mathrm{reciprocal\_mapping}(x) = \frac{1-e^{-x}}{x}

  which is :math:`1` for :math:`x = 0` and :math:`\sim \frac{1}{x}` for
  :math:`x \gg 1`, or (when ``kind = 'cuspless'``)

  .. math::
    \mathrm{reciprocal\_mapping}(x) = \frac{1}{x+e^{-x}}

  which is similar to ``kind = 'damped'``, but has no cusp at :math:`x = 0`.

  .. jupyter-execute::
    :hide-code:

    import numpy as np, matplotlib.pyplot as plt
    import matplotlib_inline.backend_inline as inl
    from e3x.nn import reciprocal_mapping
    inl.set_matplotlib_formats('pdf', 'svg')
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    x = np.linspace(0, 5, num=1001)
    y1 = reciprocal_mapping(x, 'shifted')
    y2 = reciprocal_mapping(x, 'damped')
    y3 = reciprocal_mapping(x, 'cuspless')
    plt.xlabel(r'$x$'); plt.ylabel(r'$\mathrm{reciprocal\_mapping}(x)$')
    plt.plot(x, y1, lw=3, label='shifted');
    plt.plot(x, y2, lw=3, label='damped');
    plt.plot(x, y3, lw=3, label='cuspless');
    plt.legend()
    plt.grid()

  Args:
    x: Input array (the mappings are intended for :math:`x \ge 0`).
    kind: Which kind of mapping is used.

  Returns:
    The function value.

  Raises:
    ValueError: If ``kind`` is not one of ``'shifted'``, ``'damped'`` or
      ``'cuspless'``.
  """
  # Check that kind is a valid value.
  if kind not in _valid_reciprocal_mappings:
    raise ValueError(
        f"kind must be in {_valid_reciprocal_mappings}, received '{kind}'"
    )

  if kind == 'shifted':
    return 1 / (x + 1)
  elif kind == 'damped':
    # Accept Python scalars and integer arrays too: the threshold is taken
    # from the floating point dtype that x is promoted to in the arithmetic.
    x = jnp.asarray(x)
    eps = jnp.finfo(jnp.result_type(x, 1.0)).eps
    # (1 - exp(-x)) / x is 0/0 at x = 0, so only inputs whose magnitude is
    # below machine epsilon use the Taylor expansion 1 - x/2 + x^2/6; all
    # other inputs (including negative ones) use the exact expression.
    small = jnp.abs(x) < eps
    # Substitute a harmless value in the masked-out entries so that neither
    # the value nor the gradient of the unused branch produces NaNs.
    safe_x = jnp.where(small, 1, x)
    return jnp.where(
        small, 1 - x / 2 + x * x / 6, -jnp.expm1(-safe_x) / safe_x
    )
  elif kind == 'cuspless':
    return 1 / (x + jnp.exp(-x))
  else:
    # Protection from potential bugs if other valid values are added. An
    # explicit raise (unlike `assert`) is not stripped under `python -O`.
    raise AssertionError(f"Missing implementation of kind '{kind}'!")
def exponential_mapping(
    x: Float[Array, '...'],
    gamma: Union[Float[Array, ''], float] = 1.0,
    cuspless: bool = False,
) -> Float[Array, '...']:
  r"""Exponential mapping function.

  Computes the function (when ``cuspless = False``)

  .. math::
    \mathrm{exponential\_mapping}(x) = \exp\left(-\gamma x\right)\,,

  or (when ``cuspless = True``)

  .. math::
    \mathrm{exponential\_mapping}(x) = \exp\left(-\gamma (x+e^{-x}-1)\right)\,,

  where :math:`\gamma` = ``gamma``.

  Plots for ``cuspless = False``:

  .. jupyter-execute::
    :hide-code:

    import numpy as np, matplotlib.pyplot as plt
    import matplotlib_inline.backend_inline as inl
    from e3x.nn import exponential_mapping
    inl.set_matplotlib_formats('pdf', 'svg')
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    x = np.linspace(0, 3, num=1001)
    y1 = exponential_mapping(x, gamma=1.0, cuspless=False)
    y2 = exponential_mapping(x, gamma=2.0, cuspless=False)
    y3 = exponential_mapping(x, gamma=0.5, cuspless=False)
    plt.xlabel(r'$x$'); plt.ylabel(r'$\mathrm{exponential\_mapping}(x)$')
    plt.plot(x, y1, lw=3, ls='-', label='gamma = 1.0')
    plt.plot(x, y2, lw=3, ls='--', label='gamma = 2.0')
    plt.plot(x, y3, lw=3, ls=':', label='gamma = 0.5')
    plt.legend(); plt.grid()

  Plots for ``cuspless = True``:

  .. jupyter-execute::
    :hide-code:

    inl.set_matplotlib_formats('pdf', 'svg')
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    x = np.linspace(0, 3, num=1001)
    y1 = exponential_mapping(x, gamma=1.0, cuspless=True)
    y2 = exponential_mapping(x, gamma=2.0, cuspless=True)
    y3 = exponential_mapping(x, gamma=0.5, cuspless=True)
    plt.xlabel(r'$x$'); plt.ylabel(r'$\mathrm{exponential\_mapping}(x)$')
    plt.plot(x, y1, lw=3, ls='-', label='gamma = 1.0')
    plt.plot(x, y2, lw=3, ls='--', label='gamma = 2.0')
    plt.plot(x, y3, lw=3, ls=':', label='gamma = 0.5')
    plt.legend(); plt.grid()

  Args:
    x: Input array.
    gamma: Exponential decay constant.
    cuspless: If this is ``True``, a cuspless exponential mapping is returned.

  Returns:
    The function value.
  """
  if cuspless:
    # Computing x + exp(-x) - 1 directly cancels catastrophically for small x
    # (exp(-x) - 1 loses most significant digits); expm1 evaluates
    # exp(-x) - 1 accurately, so the shifted argument stays precise near 0.
    x = x + jnp.expm1(-x)
  return jnp.exp(-gamma * x)