e3x.ops.indexed.indexed_softmax
- e3x.ops.indexed.indexed_softmax(inputs, multiplicative_mask=None, *, adj_idx=None, where=None, dst_idx=None, num_segments=None, indices_are_sorted=False, **_)
Computes the softmax of inputs within groups defined by sparse or dense index lists.
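For the sparse case, the computation can be pictured as a segment-wise softmax over a flat array. The following is a minimal plain-JAX sketch, assuming that all entries sharing the same value in dst_idx form one softmax group. The helper segment_softmax_sketch is hypothetical and not part of e3x. It uses jax.ops.segment_max and jax.ops.segment_sum, which may differ from what e3x does internally.

```python
import jax
import jax.numpy as jnp


def segment_softmax_sketch(inputs, dst_idx, num_segments,
                           multiplicative_mask=None, indices_are_sorted=False):
    """Softmax of a flat array `inputs` (shape (P,)) within groups given by `dst_idx`.

    Hypothetical helper illustrating the sparse case; not the e3x implementation.
    """
    # Per-segment maximum, subtracted before exponentiating for numerical stability.
    seg_max = jax.ops.segment_max(inputs, dst_idx, num_segments=num_segments,
                                  indices_are_sorted=indices_are_sorted)
    exp = jnp.exp(inputs - seg_max[dst_idx])
    if multiplicative_mask is not None:
        # The mask scales the raw exponentials, before normalization.
        exp = exp * multiplicative_mask
    # Per-segment normalizer, gathered back to every entry of its segment.
    seg_sum = jax.ops.segment_sum(exp, dst_idx, num_segments=num_segments,
                                  indices_are_sorted=indices_are_sorted)
    return exp / seg_sum[dst_idx]


# Example: four entries split into two groups, {0, 1} and {2, 3}.
x = jnp.array([1.0, 2.0, 3.0, 0.5])
dst = jnp.array([0, 0, 1, 1])
print(segment_softmax_sketch(x, dst, num_segments=2))  # ≈ [0.269, 0.731, 0.924, 0.076]
```

The values in each group sum to 1. If a multiplicative mask zeroes out an entire group, this sketch divides by zero. How e3x handles that case is not covered here.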
- Parameters:
inputs (Union[Float[Array, '... N M'], Float[Array, '... P']]) – Inputs for which to compute the softmax.
multiplicative_mask (Union[Float[Array, '... N M'], Float[Array, '... P'], NoneType], default: None) – Optional mask to multiply with the raw exponentials before normalization. This can be used, for example, to implement smooth cutoffs.
adj_idx (Optional[Integer[Array, '... N M']], default: None) – Adjacency indices (dense index list).
where (Optional[Bool[Array, '... N M']], default: None) – Mask specifying which entries of inputs are valid; entries where the mask is False are excluded. Required for dense index lists.
dst_idx (Optional[Integer[Array, '... P']], default: None) – Destination indices (sparse index list).
num_segments (Optional[int], default: None) – Number of segments, i.e. groups over which the softmax is normalized. Required for sparse index lists.
indices_are_sorted (bool, default: False) – If True, dst_idx is assumed to be sorted, which may increase performance (only used for sparse index lists).
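For the dense case, a plausible reading is that each row of the N × M input is normalized separately over its valid entries (those where where is True). Under this reading, adj_idx only records which neighbor each column refers to, so it is not needed for the arithmetic. The sketch below assumes that behavior and is not the e3x implementation. The helper dense_softmax_sketch, the example data, and the cosine cutoff used as a multiplicative_mask are all hypothetical illustrations of the "smooth cutoffs" use mentioned above.

```python
import jax.numpy as jnp


def dense_softmax_sketch(inputs, where, multiplicative_mask=None):
    """Softmax over the last axis of `inputs` (shape (..., N, M)), restricted to `where`.

    Hypothetical helper illustrating the dense case; not the e3x implementation.
    """
    # Row-wise maximum over valid entries only; padding is treated as -inf.
    row_max = jnp.max(jnp.where(where, inputs, -jnp.inf), axis=-1, keepdims=True)
    # Rows without any valid entry would give -inf; replace it to avoid NaNs.
    row_max = jnp.where(jnp.isfinite(row_max), row_max, 0.0)
    # exp(-inf) == 0, so excluded entries contribute nothing.
    exp = jnp.exp(jnp.where(where, inputs - row_max, -jnp.inf))
    if multiplicative_mask is not None:
        # The mask scales the raw exponentials, before normalization.
        exp = exp * multiplicative_mask
    denom = jnp.sum(exp, axis=-1, keepdims=True)
    # Return zeros (instead of 0/0) for rows whose exponentials sum to zero.
    return exp / jnp.where(denom > 0, denom, 1.0)


# Example: 2 atoms with up to 3 neighbor slots; atom 1's last slot is padding.
scores = jnp.array([[0.3, 1.2, -0.5], [2.0, 0.1, 0.0]])
valid = jnp.array([[True, True, True], [True, True, False]])
distances = jnp.array([[1.0, 2.5, 4.9], [0.8, 3.0, 0.0]])
cutoff = 5.0  # hypothetical cutoff radius
# Hypothetical cosine cutoff: equals 1 at r = 0 and decays smoothly to 0 at r = cutoff.
smooth = jnp.where(distances < cutoff,
                   0.5 * (jnp.cos(jnp.pi * distances / cutoff) + 1.0), 0.0)
weights = dense_softmax_sketch(scores, valid, multiplicative_mask=smooth)
print(weights.sum(axis=-1))  # each row sums to 1 (up to rounding)
```

Because the cutoff multiplies the exponentials rather than the inputs, a neighbor's weight goes continuously to zero as its distance approaches the cutoff.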
- Return type:
Union[Float[Array, '... N M'], Float[Array, '... P']]
- Returns:
An array with the softmax values, with the same shape as inputs.
- Raises:
RuntimeError – If neither dense nor sparse index lists are provided, or if both are provided.
ValueError – If the shape of multiplicative_mask does not match the shape of inputs, or if inputs does not have a floating-point dtype.
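The documented error conditions can be restated as a few explicit checks. The sketch below is a hypothetical helper, not e3x code. It assumes that passing adj_idx selects a dense index list and passing dst_idx selects a sparse one. It does not check the additional requirements on where and num_segments, and the order of the checks and the messages are illustrative only.

```python
import jax.numpy as jnp


def check_indexed_softmax_args(inputs, multiplicative_mask=None, *,
                               adj_idx=None, dst_idx=None):
    """Hypothetical restatement of the documented error conditions; not e3x code.

    Returns "dense" or "sparse" depending on which index list was supplied.
    """
    dense, sparse = adj_idx is not None, dst_idx is not None
    if dense == sparse:  # True both for "neither given" and for "both given"
        raise RuntimeError("expected exactly one of adj_idx (dense) or dst_idx (sparse)")
    if not jnp.issubdtype(inputs.dtype, jnp.floating):
        raise ValueError(f"inputs must have a floating point dtype, got {inputs.dtype}")
    if multiplicative_mask is not None and multiplicative_mask.shape != inputs.shape:
        raise ValueError(
            f"multiplicative_mask shape {multiplicative_mask.shape} "
            f"does not match inputs shape {inputs.shape}")
    return "dense" if dense else "sparse"
```

Comparing the two booleans with == covers both RuntimeError cases (neither and both) in a single branch.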