e3x.nn.initializers.tensor_variance_scaling
- e3x.nn.initializers.tensor_variance_scaling(scale, mode, distribution, dtype=<class 'jax.numpy.float64'>)
Variance scaling initializer for tensor product kernels.
Equivalent to variance_scaling, but for tensor product kernels.
- Parameters:
scale (float) – Scaling factor for the variance.
mode ({'fan_in', 'fan_out', 'fan_avg'}) – How the variance is computed. Supported values are 'fan_in', 'fan_out', and 'fan_avg'.
distribution ({'truncated_normal', 'normal', 'uniform'}) – Random distribution to draw from. Supported values are 'truncated_normal', 'normal', and 'uniform'.
dtype (Any, default: <class 'jax.numpy.float64'>) – The desired dtype of the parameters.
- Return type:
- Returns:
An initializer function.
- Raises:
ValueError – If mode or distribution is invalid.
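This initializer is described as the tensor product counterpart of variance_scaling. The sketch below assumes that name refers to JAX's standard jax.nn.initializers.variance_scaling and uses it to show what scale, mode, and distribution control. It does not reproduce how e3x computes fan sizes for tensor product kernels, because this page does not describe that.

The setup is a plain 2D dense kernel of shape (512, 256). With JAX's defaults (in_axis=-2, out_axis=-1), fan_in is 512, fan_out is 256, and fan_avg is their mean. The kernel shape and scale=2.0 are arbitrary illustrative choices. float32 is passed explicitly so that the sketch does not depend on JAX's 64-bit mode.

```python
import jax
import jax.numpy as jnp

# Illustrative plain dense kernel (not a tensor product kernel): 512 inputs,
# 256 outputs. With JAX's default in_axis=-2 / out_axis=-1 this gives
# fan_in = 512 and fan_out = 256.
SHAPE = (512, 256)
FANS = {"fan_in": 512.0, "fan_out": 256.0, "fan_avg": (512.0 + 256.0) / 2}
DISTRIBUTIONS = ("truncated_normal", "normal", "uniform")


def target_std(scale: float, mode: str) -> float:
    """Standard deviation variance scaling aims for: sqrt(scale / n)."""
    return (scale / FANS[mode]) ** 0.5


def empirical_std(scale: float, mode: str, distribution: str, seed: int = 0) -> float:
    """Sample a kernel with JAX's variance_scaling and measure its spread."""
    init = jax.nn.initializers.variance_scaling(scale, mode, distribution)
    kernel = init(jax.random.PRNGKey(seed), SHAPE, jnp.float32)
    return float(jnp.std(kernel))


if __name__ == "__main__":
    for mode in FANS:
        for distribution in DISTRIBUTIONS:
            got = empirical_std(2.0, mode, distribution)
            want = target_std(2.0, mode)
            print(f"{mode:8s} {distribution:16s} sample std={got:.4f} target={want:.4f}")
```

In JAX's version, the distribution setting only changes the shape of the samples, not their variance. The truncated normal is rescaled to compensate for truncation, and the uniform limit is sqrt(3 * scale / n). As a result, every combination should report a sample standard deviation close to sqrt(scale / n), where n is the fan selected by mode.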