e3x.nn.initializers.fused_tensor_normal
- e3x.nn.initializers.fused_tensor_normal(scale=1.0, mask=True, dtype=<class 'jax.numpy.float64'>)
Initializer for fused tensor product kernels that samples parameters from a normal distribution.
- Parameters:
  - scale (Union[float, Float[Array, '*Shape']], default: 1.0) – Scaling factor for the variance.
  - mask (Union[bool, Array], default: True) – Mask for zeroing unused parameters.
  - dtype (Any, default: <class 'jax.numpy.float64'>) – The desired dtype of the parameters.
- Return type:
- Returns:
An initializer function.
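The sketch below illustrates the general pattern the parameters describe: a factory that returns a JAX-style initializer `init(key, shape, dtype)`, which draws normal samples, scales their variance by `scale`, and zeroes entries where `mask` is false. It is a hypothetical stand-in named `masked_normal_init`, not the e3x implementation. It assumes that `scale` multiplies the variance of a standard normal directly and that `mask` broadcasts against the kernel shape; the real `fused_tensor_normal` may normalize by fan-in and derive its mask from the fused tensor structure. It defaults to `float32` because `float64` arrays require `jax_enable_x64`.

```python
from typing import Any, Callable, Sequence, Union

import jax
import jax.numpy as jnp


def masked_normal_init(
    scale: Union[float, jax.Array] = 1.0,
    mask: Union[bool, jax.Array] = True,
    dtype: Any = jnp.float32,
) -> Callable[..., jax.Array]:
  """Hypothetical masked normal initializer (NOT the e3x implementation)."""

  def init(key: jax.Array, shape: Sequence[int], dtype: Any = dtype) -> jax.Array:
    # Draw standard normal samples (mean 0, variance 1).
    x = jax.random.normal(key, shape, dtype)
    # Multiply by sqrt(scale) so the variance becomes `scale`;
    # an array-valued `scale` broadcasts against `shape`.
    x = x * jnp.sqrt(jnp.asarray(scale, dtype))
    # Zero out entries where `mask` is False; `mask` broadcasts against `shape`.
    return jnp.where(mask, x, jnp.zeros_like(x))

  return init


# Usage: a 4x3 kernel whose middle column is (hypothetically) unused.
key = jax.random.PRNGKey(0)
mask = jnp.array([True, False, True])
init = masked_normal_init(scale=2.0, mask=mask)
kernel = init(key, (4, 3))
```

Returning a closure keeps the configuration (`scale`, `mask`, `dtype`) separate from the per-call arguments (`key`, `shape`), which matches how JAX and Flax initializers are usually passed to layers.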