e3x.nn.modules.Embed
- class e3x.nn.modules.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Embedding module.
A parameterized function that maps integers in \([0, n)\) to \(d\)-dimensional scalar features.
- num_embeddings
Number of embeddings \(n\).
- features
Dimension \(d\) of the feature space.
- param_dtype
The dtype passed to parameter initializers.
- embedding_init
Embedding initializer.
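The mapping described above can be pictured as a lookup into a learned table with one \(d\)-dimensional row per integer in \([0, n)\). The NumPy sketch below assumes the parameter is stored as an (n, d) array, the layout used by flax.linen.Embed. The names table and lookup are hypothetical, and the random values are for illustration only; they do not reproduce the module's embedding_init.

```python
import numpy as np

# Hypothetical stand-in for the learned parameter: an (n, d) table,
# one d-dimensional row per integer in [0, n).
num_embeddings, features = 5, 3          # n and d
rng = np.random.default_rng(0)
table = rng.normal(size=(num_embeddings, features))  # arbitrary values, not embedding_init


def lookup(index: int) -> np.ndarray:
    """Map an integer in [0, n) to its d-dimensional feature vector."""
    # Bounds checking is a choice made in this sketch only.
    if not 0 <= index < num_embeddings:
        raise IndexError(f"index {index} outside [0, {num_embeddings})")
    return table[index]


vec = lookup(2)
print(vec.shape)  # (3,)
```

Because the table is the only parameter, the module holds num_embeddings × features trainable values under this assumed layout.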
- __call__(inputs)
Embeds the inputs along the last dimension.
Scalar features are returned with shape (..., 1, 1, features), which is consistent with the conventions used in other equivariant operations.
- Parameters:
inputs (Integer[Array, '...']) – Input data; all dimensions are considered batch dimensions.
- Return type:
- Returns:
The embedded input data. The output shape follows the input shape, with additional (1, 1, features) dimensions appended.
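The documented shape behavior of __call__ can be reproduced with a plain NumPy gather followed by two inserted singleton axes. This is a sketch of the input/output shapes only, not the e3x implementation; embed_scalar_features is a hypothetical helper. In e3x the two singleton axes presumably correspond to the parity and degree/order axes that other equivariant operations expect for scalar features (an assumption, not stated on this page).

```python
import numpy as np


def embed_scalar_features(table: np.ndarray, inputs: np.ndarray) -> np.ndarray:
    """Hypothetical sketch of the shapes produced by Embed.__call__.

    table:  (num_embeddings, features) array (assumed layout).
    inputs: integer array of any shape (...); every axis is a batch axis.
    Returns an array of shape (..., 1, 1, features).
    """
    inputs = np.asarray(inputs)
    if not np.issubdtype(inputs.dtype, np.integer):
        raise TypeError("inputs must be integers")
    gathered = table[inputs]             # shape (..., features)
    return gathered[..., None, None, :]  # append singleton axes -> (..., 1, 1, features)


table = np.arange(12.0).reshape(4, 3)    # n = 4 embeddings, d = 3 features
indices = np.array([[0, 3], [1, 1]])     # batch shape (2, 2)
out = embed_scalar_features(table, indices)
print(out.shape)  # (2, 2, 1, 1, 3)
```

Any batch shape works, including a 0-d input, which yields an output of shape (1, 1, features).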