e3x.nn.initializers.tensor_variance_scaling

e3x.nn.initializers.tensor_variance_scaling(scale, mode, distribution, dtype=jax.numpy.float64)

Variance scaling initializer for tensor product kernels.

Equivalent to variance_scaling, but for tensor product kernels.
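
For comparison, here is the dense initializer this function mirrors, assuming that variance_scaling refers to jax.nn.initializers.variance_scaling. The sketch applies it to an ordinary 2D kernel and does not cover how tensor_variance_scaling derives fans from a tensor product kernel shape, which this page does not specify.

```python
import jax
import jax.numpy as jnp

# Dense (non-tensor-product) variance scaling initializer from JAX.
init = jax.nn.initializers.variance_scaling(
    scale=1.0, mode="fan_in", distribution="normal"
)

# For a 2D kernel, JAX treats axis -2 as the input axis by default,
# so fan_in = 64 and the target standard deviation is sqrt(1 / 64) = 0.125.
key = jax.random.PRNGKey(0)
kernel = init(key, (64, 32), jnp.float32)

print(kernel.shape)            # (64, 32)
print(float(jnp.std(kernel)))  # close to 0.125 (random sample)
```

The returned initializer is called with a PRNG key, a shape, and optionally a dtype, which is the same calling pattern an InitializerFn follows.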

Parameters:
  • scale (float) – Scaling factor for the variance.

  • mode ({'fan_in', 'fan_out', 'fan_avg'}) – How the variance is computed; supported values are 'fan_in', 'fan_out', and 'fan_avg'.

  • distribution ({'truncated_normal', 'normal', 'uniform'}) – Random distribution to draw from; supported values are 'truncated_normal', 'normal', and 'uniform'.

  • dtype (Any, default: jax.numpy.float64) – The desired dtype of the parameters.

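In JAX's dense variance scaling rule, mode picks the denominator and distribution determines how the variance scale / denominator becomes a sampling width. The plain-Python sketch below reproduces that rule, assuming tensor_variance_scaling uses the same arithmetic once its fans are known. sampling_scale is a hypothetical helper, not part of e3x. fan_in and fan_out are passed in directly because this page does not say how they are computed for tensor product kernels.

```python
import math

# Standard deviation of a unit normal truncated to [-2, 2], as used by JAX's
# variance_scaling; dividing by it restores the target variance after truncation.
_TRUNC_STD = 0.87962566103423978


def sampling_scale(scale, mode, distribution, fan_in, fan_out):
    """Hypothetical helper: std (normal variants) or half-width (uniform).

    Assumption: fan_in / fan_out are given; deriving them from a tensor
    product kernel shape is not modeled here.
    """
    if mode == "fan_in":
        denominator = fan_in
    elif mode == "fan_out":
        denominator = fan_out
    elif mode == "fan_avg":
        denominator = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"invalid mode: {mode!r}")

    variance = scale / denominator

    if distribution == "truncated_normal":
        return math.sqrt(variance) / _TRUNC_STD
    if distribution == "normal":
        return math.sqrt(variance)
    if distribution == "uniform":
        # Uniform on [-a, a] has variance a**2 / 3, so a = sqrt(3 * variance).
        return math.sqrt(3.0 * variance)
    raise ValueError(f"invalid distribution: {distribution!r}")


print(sampling_scale(1.0, "fan_in", "normal", fan_in=64, fan_out=32))  # 0.125
```

An unsupported mode or distribution raises ValueError, matching the behavior documented under Raises.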
Return type:

InitializerFn

Returns:

An initializer function.

Raises:

ValueError – If mode or distribution is invalid.