e3x.nn.initializers.tensor_glorot_normal

e3x.nn.initializers.tensor_glorot_normal(dtype=<class 'jax.numpy.float64'>)

Glorot normal initializer for tensor product kernels.

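Equivalent to jax.nn.initializers.glorot_normal (variance scaling with scale=1.0, mode="fan_avg" and a truncated normal distribution), but for tensor product kernels.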
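For reference, the standard Glorot normal scheme chooses the weight variance from the fan-in n_in and fan-out n_out of the kernel, so that the variance is 1 divided by their average. How tensor_glorot_normal defines n_in and n_out for the additional axes of a tensor product kernel is not stated on this page, so the formula below describes only the standard scheme:

$$
\operatorname{Var}(W) = \frac{1}{(n_\text{in} + n_\text{out})/2} = \frac{2}{n_\text{in} + n_\text{out}}
$$

For example, n_in = n_out = 64 gives Var(W) = 1/64, i.e. a standard deviation of 0.125. In jax.nn.initializers.glorot_normal the samples come from a normal distribution truncated to two standard deviations, with the scale corrected so that the truncated distribution still has this variance.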

Parameters:

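dtype (Any, default: <class 'jax.numpy.float64'>) – The desired dtype of the parameters. As elsewhere in JAX, 64-bit values are only produced when jax_enable_x64 is enabled; otherwise JAX falls back to float32.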

Return type:

InitializerFn

Returns:

An initializer function.
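This page does not show how the returned function is called. The following is a minimal sketch, assuming (not confirmed here) that it follows the standard JAX initializer convention init(key, shape, dtype). It uses JAX's own jax.nn.initializers.glorot_normal as a stand-in on an ordinary 2D kernel and compares the empirical variance with the Glorot target. It does not call e3x and does not reproduce how e3x treats the extra axes of tensor product kernels.

```python
import jax
import jax.numpy as jnp

# Stand-in: JAX's standard Glorot normal initializer (NOT e3x's tensor variant).
# Assumption: e3x's InitializerFn uses the same (key, shape, dtype) call signature.
init_fn = jax.nn.initializers.glorot_normal()

key = jax.random.PRNGKey(0)
fan_in, fan_out = 512, 256

# For a 2D kernel, JAX reads fan_in from axis -2 and fan_out from axis -1.
kernel = init_fn(key, (fan_in, fan_out), jnp.float32)

# Glorot normal targets Var(W) = 2 / (fan_in + fan_out).
target_var = 2.0 / (fan_in + fan_out)
empirical_var = float(jnp.var(kernel))

print(kernel.shape, kernel.dtype)
print(f"target variance:    {target_var:.6f}")
print(f"empirical variance: {empirical_var:.6f}")
```

With this many samples the empirical variance should land within a few percent of the target.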