e3x.nn.initializers.fused_tensor_uniform

e3x.nn.initializers.fused_tensor_uniform(scale=1.0, mask=True, dtype=jax.numpy.float64)

Returns an initializer for fused tensor product kernels that samples values from a uniform distribution.

Parameters:
  • scale (Union[float, Float[Array, '*Shape']], default: 1.0) – Scaling factor for the variance.

  • mask (Union[bool, Array], default: True) – Mask for zeroing unused parameters.

  • dtype (Any, default: jax.numpy.float64) – The desired dtype of the parameters.

Return type:

InitializerFn

Returns:

An initializer function.
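The factory pattern above (a function taking scale, mask and dtype that returns an initializer function) can be sketched in plain JAX as follows. This is a minimal illustration, not e3x's implementation. The name `uniform_with_mask`, the `InitializerFn` alias, and the `(key, shape, dtype)` call signature of the returned function are hypothetical. The sketch also assumes that `scale` is the exact target variance of each unmasked entry, with no fan-in or path-count normalization such as e3x may apply to fused tensor kernels, and that `mask` broadcasts against the requested shape. Under that assumption, a symmetric uniform distribution U(-a, a) has variance a²/3, so the bound is a = sqrt(3 * scale).

```python
from typing import Any, Callable, Union

import jax
import jax.numpy as jnp

# Hypothetical alias standing in for e3x's InitializerFn type.
InitializerFn = Callable[..., jax.Array]


def uniform_with_mask(
    scale: Union[float, jax.Array] = 1.0,
    mask: Union[bool, jax.Array] = True,
    dtype: Any = jnp.float32,  # float32 here for simplicity; e3x defaults to float64
) -> InitializerFn:
  """Hypothetical sketch: returns an initializer sampling U(-a, a) with variance `scale`."""

  def init(key: jax.Array, shape, dtype: Any = dtype) -> jax.Array:
    # U(-a, a) has variance a**2 / 3, so a = sqrt(3 * scale) yields variance `scale`.
    # `scale` may also be an array that broadcasts against `shape`.
    bound = jnp.sqrt(3.0 * jnp.asarray(scale, dtype))
    values = jax.random.uniform(key, shape, dtype, minval=-1.0, maxval=1.0) * bound
    # Zero out entries where the mask is False; a scalar True keeps everything.
    return jnp.where(mask, values, jnp.zeros((), dtype))

  return init


# Usage: a 2x3 kernel in which two entries are masked out.
mask = jnp.array([[True, False, True], [False, True, True]])
init = uniform_with_mask(scale=0.5, mask=mask)
w = init(jax.random.PRNGKey(0), (2, 3))
```

Returning a closure rather than sampling directly mirrors the "Returns: An initializer function" contract: the configuration (scale, mask, dtype) is fixed once, and the caller supplies the random key and shape later, when the parameters are actually created.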