Source code for e3x.nn.functions.window

# Copyright 2024 The e3x Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Window functions."""

from typing import Union
from e3x.nn.functions import smooth_switch
import jax
import jax.numpy as jnp
import jaxtyping

Array = jaxtyping.Array
Float = jaxtyping.Float


[docs] def rectangular_window( x: Float[Array, '...'], num: int, limit: Union[Float[Array, ''], float] = 1.0, ) -> Float[Array, '... num']: r"""Rectangular window basis functions. Computes the basis functions .. math:: \mathrm{rectangular\_window}_k(x) = \begin{cases} 0, & x < \frac{k l}{K}\\ 1, & \frac{k l}{K+1} \le x \le \frac{(k+1) l}{K} \\ 0, & x > \frac{(k+1) l}{K} \end{cases} where :math:`k=0 \dots K-1` with :math:`K` = ``num`` and :math:`l` = ``limit``. Plot for :math:`K = 5` and :math:`l = 1`: .. jupyter-execute:: :hide-code: import numpy as np, matplotlib.pyplot as plt import matplotlib_inline.backend_inline as inl from e3x.nn import rectangular_window inl.set_matplotlib_formats('pdf', 'svg') plt.subplots_adjust(left=0, right=1, bottom=0, top=1) x = np.linspace(0, 1.0, num=1001); K = 5; l = 1.0 y = rectangular_window(x, num=K, limit=l) plt.xlabel(r'$x$'); plt.ylabel(r'$\mathrm{rectangular\_window}_k(x)$') for k in range(K): plt.plot(x, y[:,k], lw=3, label=r'$k$'+f'={k}') plt.legend(); plt.grid() Args: x: Input array. num: Number of basis functions :math:`K`. limit: Basis functions are distributed between 0 and ``limit``. Returns: Value of all basis functions for all values in ``x``. The output shape follows the input, with an additional dimension of size ``num`` appended. """ with jax.ensure_compile_time_eval(): index = jnp.arange(num) bound = jnp.linspace(0.0, limit, num=num + 1) lower = bound[index] upper = bound[index + 1] x_1 = jnp.expand_dims(x, -1) return jnp.where( x_1 < lower, jnp.zeros_like(x_1), jnp.where(x_1 > upper, jnp.zeros_like(x_1), jnp.ones_like(x_1)), )
[docs] def triangular_window( x: Float[Array, '...'], num: int, limit: Union[Float[Array, ''], float] = 1.0, ) -> Float[Array, '... num']: r"""Triangular window basis functions. Computes the basis functions .. math:: \mathrm{triangular\_window}_k(x) = \max\left( \min\left(\frac{K}{l}x - k - 1, \frac{K}{l}x + k + 1\right),0\right) where :math:`k=0 \dots K-1` with :math:`K` = ``num`` and :math:`l` = ``limit``. Plot for :math:`K = 5` and :math:`l = 1`: .. jupyter-execute:: :hide-code: import numpy as np, matplotlib.pyplot as plt import matplotlib_inline.backend_inline as inl from e3x.nn import triangular_window inl.set_matplotlib_formats('pdf', 'svg') plt.subplots_adjust(left=0, right=1, bottom=0, top=1) x = np.linspace(0, 1.0, num=1001); K = 5; l = 1.0 y = triangular_window(x, num=K, limit=l) plt.xlabel(r'$x$'); plt.ylabel(r'$\mathrm{triangular\_window}_k(x)$') for k in range(K): plt.plot(x, y[:,k], lw=3, label=r'$k$'+f'={k}') plt.legend(); plt.grid() Args: x: Input array. num: Number of basis functions :math:`K`. limit: Basis functions are distributed between 0 and ``limit``. Returns: Value of all basis functions for all values in ``x``. The output shape follows the input, with an additional dimension of size ``num`` appended. """ with jax.ensure_compile_time_eval(): width = limit / num center = jnp.linspace(0.0, limit, num=num + 1)[:-1] lower = center - width upper = center + width x_1 = jnp.expand_dims(x, -1) return jnp.maximum( jnp.minimum((x_1 - lower) / width, -(x_1 - upper) / width), 0 )
[docs] def smooth_window( x: Float[Array, '...'], num: int, limit: Union[Float[Array, ''], float] = 1.0, ) -> Float[Array, '... num']: r"""Smooth window basis functions. Computes the basis functions (see :func:`smooth_switch <e3x.nn.functions.smooth_switch>`) .. math:: \mathrm{smooth\_window}_k(x) = \mathrm{smooth\_switch}\left(\frac{K}{l}x - \frac{kl}{K}\right) - \mathrm{smooth\_switch}\left(1 -\frac{K}{l}x + \frac{kl}{K}\right) where :math:`k=0 \dots K-1` with :math:`K` = ``num`` and :math:`l` = ``limit``. Plot for :math:`K = 5` and :math:`l = 1`: .. jupyter-execute:: :hide-code: import numpy as np, matplotlib.pyplot as plt import matplotlib_inline.backend_inline as inl from e3x.nn import smooth_window inl.set_matplotlib_formats('pdf', 'svg') plt.subplots_adjust(left=0, right=1, bottom=0, top=1) x = np.linspace(0, 1.0, num=1001); K = 5; l = 1.0 y = smooth_window(x, num=K, limit=l) plt.xlabel(r'$x$'); plt.ylabel(r'$\mathrm{smooth\_window}_k(x)$') for k in range(K): plt.plot(x, y[:,k], lw=3, label=r'$k$'+f'={k}') plt.legend(); plt.grid() Args: x: Input array. num: Number of basis functions :math:`K`. limit: Basis functions are distributed between 0 and ``limit``. Returns: Value of all basis functions for all values in ``x``. The output shape follows the input, with an additional dimension of size ``num`` appended. """ with jax.ensure_compile_time_eval(): center = jnp.linspace(0.0, limit, num=num + 1)[:-1] x_1 = num / limit * (jnp.expand_dims(x, -1) - center) return smooth_switch(x_1 + 1) - smooth_switch(x_1)