e3x.nn.modules.Embed

class e3x.nn.modules.Embed(num_embeddings, features, dtype=None, param_dtype=jax.numpy.float32, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Embedding module.

A parameterized function from integers \([0, n)\) to \(d\)-dimensional scalar features.
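One way to picture this function is as a learned table \(W\) of shape \((n, d)\) whose rows are looked up by index, which is the same as multiplying a one-hot encoding of the input by \(W\). The sketch below is a mental model under that assumption, not the e3x source; the random table is a hypothetical stand-in for the learned parameter.

```python
import jax
import jax.numpy as jnp

n, d = 5, 3  # num_embeddings, features
key = jax.random.PRNGKey(0)
# Hypothetical stand-in for the learned parameter: one d-dimensional row per integer.
table = jax.random.normal(key, (n, d))

inputs = jnp.array([0, 4, 2])
# Embedding as a row lookup ...
looked_up = jnp.take(table, inputs, axis=0)  # shape (3, d)
# ... computes the same function as a one-hot encoding times the table.
one_hot = jax.nn.one_hot(inputs, n)  # shape (3, n)
via_matmul = one_hot @ table  # shape (3, d)
```

The lookup form is what makes embeddings cheap: no \(n\)-dimensional one-hot vector ever needs to be materialized.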

num_embeddings

Number of embeddings \(n\).

features

Dimension \(d\) of the feature space.

dtype

The dtype of the embedding vectors.

param_dtype

The dtype passed to parameter initializers.

embedding_init

Embedding initializer.
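The signature shows that the default embedding_init is a jax.nn.initializers.variance_scaling initializer, but it does not list its arguments. The sketch below assumes the configuration Flax uses for its own nn.Embed (scale 1.0, fan-in mode, normal distribution, out_axis=0), which gives entries a standard deviation of roughly 1/sqrt(features); these arguments are an assumption, not e3x's documented default. It also illustrates param_dtype being passed to the initializer.

```python
import jax
import jax.numpy as jnp

num_embeddings, features = 100, 16
# Assumed configuration (mirrors Flax's nn.Embed default; not confirmed for e3x).
embedding_init = jax.nn.initializers.variance_scaling(
    1.0, "fan_in", "normal", out_axis=0
)
param_dtype = jnp.float32
# The initializer receives a PRNG key, the table shape (n, d), and param_dtype.
table = embedding_init(jax.random.PRNGKey(0), (num_embeddings, features), param_dtype)
print(table.shape, table.dtype)  # (100, 16) float32
print(float(table.std()))  # approximately 1 / sqrt(16) = 0.25
```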

__call__(inputs)

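Embeds the inputs, appending the embedding dimensions after the last input dimension.

The scalar features are returned with shape (..., 1, 1, features), consistent with the conventions used in other e3x equivariant operations: the first singleton axis is the parity axis, and the second is the axis of size (max_degree+1)**2 with max_degree = 0.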

Parameters:

inputs (Integer[Array, '...']) – Input data; all dimensions are treated as batch dimensions.

Return type:

Float[Array, '... 1 1 F']

Returns:

The embedded input data. The output shape is the input shape with three dimensions of sizes 1, 1, and features appended.
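To make the shape contract concrete, here is a pure-JAX sketch that looks up rows of a table and appends the two singleton axes. The helper embed_as_scalar_features is hypothetical, not part of e3x, and the sketch assumes e3x's feature layout (..., 1 or 2, (max_degree+1)**2, features), in which scalar-only features have a parity axis of size 1 and a degree axis of size 1.

```python
import jax
import jax.numpy as jnp


def embed_as_scalar_features(table: jax.Array, inputs: jax.Array) -> jax.Array:
    """Hypothetical helper reproducing the documented output shape.

    table: (num_embeddings, features) array; inputs: integer array of any shape.
    Returns an array of shape inputs.shape + (1, 1, features).
    """
    x = jnp.take(table, inputs, axis=0)  # (..., features)
    # Append a parity axis of size 1 and a degree axis of size (0 + 1)**2 = 1.
    return x[..., None, None, :]  # (..., 1, 1, features)


table = jnp.arange(12.0).reshape(4, 3)  # 4 embeddings, 3 features
inputs = jnp.array([[1, 3], [0, 2]])  # two batch dimensions
out = embed_as_scalar_features(table, inputs)
print(out.shape)  # (2, 2, 1, 1, 3)
```

In the actual module, the table would be a learned parameter created by embedding_init rather than a fixed array.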