
Moment of inertia (regression of equivariant properties)

This tutorial demonstrates how to use E3x to construct a simple model for the prediction of equivariant properties. In this toy example, we want to predict the moment of inertia tensor of a collection of point masses. The \(3\times 3\) inertia tensor \(\mathbf{I}\) for a collection of \(N\) point masses with masses \(m_i\) and positions \(\vec{r}_i = [x_i\ y_i\ z_i]\) is given by

\[\begin{split}\mathbf{I} = \begin{bmatrix} I_{xx} & I_{xy} & I_{xz} \\ I_{yx} & I_{yy} & I_{yz} \\ I_{zx} & I_{zy} & I_{zz} \\ \end{bmatrix}\end{split}\]

with the components

\[I_{\alpha\beta} = \sum_{i=1}^{N} m_i \left(\lVert \vec{r}_i \rVert^2\delta_{\alpha\beta} - \alpha_i\beta_i \right)\]

where \(\alpha\) and \(\beta\) can be either \(x\), \(y\), or \(z\), and \(\delta_{\alpha\beta}\) is \(1\) if \(\alpha = \beta\) and \(0\) otherwise.
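
To make the index notation concrete, the sketch below evaluates the component formula literally, one \(I_{\alpha\beta}\) at a time, using plain NumPy loops. The helper name inertia_tensor_loop is hypothetical and not part of E3x. The test case is chosen so the answer can be worked out by hand: two unit masses at \(x=\pm 1\). Rotation about the \(x\) axis costs nothing, while rotation about \(y\) or \(z\) gives \(1\cdot 1^2 + 1\cdot 1^2 = 2\).

```python
import numpy as np

def inertia_tensor_loop(masses, positions):
    """Hypothetical helper: I_ab = sum_i m_i (|r_i|^2 delta_ab - a_i b_i), component by component."""
    I = np.zeros((3, 3))
    for m, r in zip(masses, positions):
        r2 = float(np.dot(r, r))  # squared norm |r_i|^2
        for a in range(3):        # a and b index the Cartesian axes x, y, z
            for b in range(3):
                delta = 1.0 if a == b else 0.0
                I[a, b] += m * (r2 * delta - r[a] * r[b])
    return I

# Two unit masses on the x axis at x = +1 and x = -1.
masses = np.array([1.0, 1.0])
positions = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
I_example = inertia_tensor_loop(masses, positions)
print(I_example)  # expected: diag(0, 2, 2)
```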

First, all necessary packages are imported.

[1]:
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Disable future warnings.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

Next, we define a function that generates a dataset by randomly drawing positions and masses and calculating the corresponding moment of inertia tensor.

[2]:
def calculate_moment_of_inertia_tensor(masses, positions):
  diag = jnp.sum(positions**2, axis=-1)[..., None, None]*jnp.eye(3)
  outer = positions[..., None, :] * positions[..., :, None]
  return jnp.sum(masses[..., None, None] * (diag - outer), axis=-3)

def generate_datasets(key, num_train=1000, num_valid=100, num_points=10, min_mass=0.0, max_mass=1.0, stdev=1.0):
  # Generate random keys.
  train_position_key, train_masses_key, valid_position_key, valid_masses_key = jax.random.split(key, num=4)

  # Draw random point masses with random positions.
  train_positions = stdev * jax.random.normal(train_position_key,  shape=(num_train, num_points, 3))
  train_masses = jax.random.uniform(train_masses_key, shape=(num_train, num_points), minval=min_mass, maxval=max_mass)
  valid_positions = stdev * jax.random.normal(valid_position_key,  shape=(num_valid, num_points, 3))
  valid_masses = jax.random.uniform(valid_masses_key, shape=(num_valid, num_points), minval=min_mass, maxval=max_mass)

  # Calculate moment of inertia tensors.
  train_inertia_tensor = calculate_moment_of_inertia_tensor(train_masses, train_positions)
  valid_inertia_tensor = calculate_moment_of_inertia_tensor(valid_masses, valid_positions)

  # Return final train and validation datasets.
  train_data = dict(positions=train_positions, masses=train_masses, inertia_tensor=train_inertia_tensor)
  valid_data = dict(positions=valid_positions, masses=valid_masses, inertia_tensor=valid_inertia_tensor)
  return train_data, valid_data
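
The vectorized calculate_moment_of_inertia_tensor relies on broadcasting. Both diag and outer have shape (..., N, 3, 3), and the final sum over axis=-3 removes the point axis N.

The sketch below is a NumPy transcription of that function (swapping jnp for np is an assumption made only so the snippet runs without JAX). It compares the result against a literal per-component loop on a small random batch. It also confirms that reordering the points leaves the result unchanged, because the formula is a plain sum over points. Both helper names are hypothetical.

```python
import numpy as np

def calculate_moment_of_inertia_tensor_np(masses, positions):
    # NumPy transcription of calculate_moment_of_inertia_tensor (jnp -> np).
    diag = np.sum(positions**2, axis=-1)[..., None, None] * np.eye(3)  # (..., N, 3, 3): |r_i|^2 * identity
    outer = positions[..., None, :] * positions[..., :, None]          # (..., N, 3, 3): r_i r_i^T
    return np.sum(masses[..., None, None] * (diag - outer), axis=-3)   # sum over the N points -> (..., 3, 3)

def inertia_tensor_loop(masses, positions):
    # Literal per-component evaluation of the formula for a single (unbatched) sample.
    I = np.zeros((3, 3))
    for m, r in zip(masses, positions):
        for a in range(3):
            for b in range(3):
                I[a, b] += m * (np.dot(r, r) * float(a == b) - r[a] * r[b])
    return I

rng = np.random.default_rng(0)
masses = rng.uniform(0.0, 1.0, size=(5, 10))       # batch of 5 samples with 10 points each
positions = rng.normal(0.0, 1.0, size=(5, 10, 3))

batched = calculate_moment_of_inertia_tensor_np(masses, positions)
reference = np.stack([inertia_tensor_loop(m, r) for m, r in zip(masses, positions)])

# Shuffle the order of the points within every sample.
perm = rng.permutation(10)
permuted = calculate_moment_of_inertia_tensor_np(masses[:, perm], positions[:, perm])

print(batched.shape)                      # (5, 3, 3)
print(np.allclose(batched, reference))    # True
print(np.allclose(batched, permuted))     # True
```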

We now define an equivariant model to solve this regression task using the modules in E3x. The architecture takes masses and positions as input and outputs a \(3\times3\) matrix, or rather, a second-order tensor. It comprises the following steps:

  1. Initialize features by concatenating masses and positions and reshaping to match the feature shape conventions used in E3x.

  2. Apply the following transformations. First, we project the mass-position features to a features-dimensional feature space using a Dense layer. Next, a TensorDense layer couples the irreps \(\mathbb{0}\) (scalars) and \(\mathbb{1}\) (vectors). A second TensorDense layer is needed because, in general, predicting an arbitrary second-order tensor requires the (even) irreps \(\mathbb{0}\), \(\mathbb{1}\), and \(\mathbb{2}\) (since \(\mathbb{1} \otimes \mathbb{1} = \mathbb{0}\oplus\mathbb{1}\oplus\mathbb{2}\)). This layer therefore “elevates” the features from maximum degree \(1\) (scalars and vectors) to max_degree=2. Because we only want to predict a single second-order tensor, it also maps from the features-dimensional feature space to a single set of output irreps \(\mathbb{0}\), \(\mathbb{1}\), and \(\mathbb{2}\) (features=1). Note: Since the moment of inertia tensor is symmetric, it consists only of the irreps \(\mathbb{0}\) and \(\mathbb{2}\), so we could zero out the irrep of degree \(1\) to predict only symmetric tensors. Here, we pretend not to know this; the model should learn to predict (almost) symmetric tensors anyway.

  3. Sum over contributions from individual points.

  4. Build the \(3\times3\) tensor from the irreps by applying the Clebsch-Gordan rule backwards: \(\mathbb{0}\oplus\mathbb{1}\oplus\mathbb{2} = \mathbb{1} \otimes \mathbb{1}\).
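
The decomposition \(\mathbb{1} \otimes \mathbb{1} = \mathbb{0}\oplus\mathbb{1}\oplus\mathbb{2}\) has a familiar Cartesian counterpart. Any \(3\times3\) matrix splits into three parts:
- a trace part (1 number),
- an antisymmetric part (3 numbers),
- a symmetric traceless part (5 numbers).

Together these give \(1+3+5=9\) components. The helper decompose below is hypothetical. E3x itself stores irreps in a spherical basis rather than as these Cartesian matrices. This NumPy sketch checks three things:
- the parts reassemble the original matrix,
- a moment of inertia tensor has no antisymmetric (degree \(1\)) part,
- each part rotates on its own without mixing with the others.

```python
import numpy as np

def decompose(M):
    """Hypothetical helper: split a 3x3 matrix into trace, antisymmetric, and symmetric traceless parts."""
    trace_part = np.trace(M) / 3.0 * np.eye(3)      # degree 0: 1 independent number
    antisym = 0.5 * (M - M.T)                       # degree 1: 3 independent numbers
    sym_traceless = 0.5 * (M + M.T) - trace_part    # degree 2: 5 independent numbers
    return trace_part, antisym, sym_traceless

rng = np.random.default_rng(0)

# 1) The three parts always add back up to the original matrix.
M = rng.normal(size=(3, 3))
parts = decompose(M)
print(np.allclose(sum(parts), M))  # True

# 2) A moment of inertia tensor (same formula as above) has no degree-1 part.
m = rng.uniform(size=10)
r = rng.normal(size=(10, 3))
diag = np.sum(r**2, axis=-1)[:, None, None] * np.eye(3)
outer = r[:, :, None] * r[:, None, :]
I = np.sum(m[:, None, None] * (diag - outer), axis=0)
_, I_antisym, _ = decompose(I)
print(np.allclose(I_antisym, 0.0))  # True

# 3) Under a rotation M -> R M R^T, each part maps onto the matching part of the rotated matrix.
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R = Q * np.sign(np.linalg.det(Q))  # flip the sign if needed so that det(R) = +1
rotated_parts = decompose(R @ M @ R.T)
print(all(np.allclose(rp, R @ p @ R.T) for rp, p in zip(rotated_parts, parts)))  # True
```

Zeroing the degree-\(1\) irrep in the model corresponds, in this Cartesian picture, to discarding the antisymmetric part.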

[3]:
class Model(nn.Module):
  features = 8
  max_degree = 1

  @nn.compact
  def __call__(self, masses, positions):  # Shapes (..., N) and (..., N, 3).
    # 1. Initialize features.
    x = jnp.concatenate((masses[..., None], positions), axis=-1) # Shape (..., N, 4).
    x = x[..., None, :, None]  # Shape (..., N, 1, 4, 1).

    # 2. Apply transformations.
    x = e3x.nn.Dense(features=self.features)(x)  # Shape (..., N, 1, 4, features).
    x = e3x.nn.TensorDense(max_degree=self.max_degree)(x)  # Shape (..., N, 2, (max_degree+1)**2, features).
    x = e3x.nn.TensorDense(  # Shape (..., N, 2, 9, 1).
        features=1,
        max_degree=2,
    )(x)
    # Try it: Zero-out irrep of degree 1 to only produce symmetric output tensors.
    # x = x.at[..., :, 1:4, :].set(0)

    # 3. Collect even irreps from feature channel 0 and sum over contributions from individual points.
    x = jnp.sum(x[..., 0, :, 0], axis=-2)  # Shape (..., 9).

    # 4. Convert output irreps to 3x3 matrix and return.
    cg = e3x.so3.clebsch_gordan(max_degree1=1, max_degree2=1, max_degree3=2)  # Shape (4, 4, 9).
    y = jnp.einsum('...l,nml->...nm', x, cg[1:, 1:, :])  # Shape (..., 3, 3).
    return y
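
The shape comments in the model follow from how the irreps are laid out along the second-to-last axis. The slicing in the code suggests the layout: x[..., :, 1:4, :] selects degree \(1\), and cg[1:, 1:, :] is taken from a (4, 4, 9) array. Components therefore appear to be ordered by degree, with degree \(l\) occupying the \(2l+1\) indices from \(l^2\) to \((l+1)^2-1\). The small helper below (hypothetical, not an E3x function) spells out this bookkeeping.

```python
def irrep_slices(max_degree):
    """Hypothetical helper: map each degree l to the index range it occupies along the irrep axis."""
    # Degree l has 2l + 1 components and starts after the l**2 components of lower degrees.
    return {l: slice(l**2, (l + 1)**2) for l in range(max_degree + 1)}

for max_degree in (0, 1, 2):
    slices = irrep_slices(max_degree)
    total = (max_degree + 1)**2  # total number of irrep components up to max_degree
    sizes = {l: s.stop - s.start for l, s in slices.items()}
    print(f"max_degree={max_degree}: {total} components, sizes per degree {sizes}")
```

With max_degree=2, this gives the 9 components that step 4 contracts with the Clebsch-Gordan coefficients. The degree-1 entries sit at indices 1:4, which is exactly the range the commented-out “Try it” line zeroes out.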

Next, we define our loss function. As is common for regression tasks, we choose the \(L_2\) (squared error) loss.

[4]:
def mean_squared_loss(prediction, target):
  return jnp.mean(optax.l2_loss(prediction, target))
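
One detail worth knowing: optax.l2_loss is documented as computing \(\tfrac{1}{2}(\text{prediction}-\text{target})^2\) elementwise. As a result, mean_squared_loss returns half of the ordinary mean squared error. The NumPy sketch below mirrors that definition and compares it with the plain mean squared error on a tiny example. Here l2_loss_np is a hypothetical stand-in, not the optax function.

```python
import numpy as np

def l2_loss_np(prediction, target):
    # Hypothetical stand-in mirroring optax.l2_loss: 0.5 * (prediction - target)**2, elementwise.
    return 0.5 * (prediction - target) ** 2

prediction = np.array([[1.0, 2.0], [3.0, 4.0]])
target = np.array([[1.0, 0.0], [3.0, 5.0]])

loss = np.mean(l2_loss_np(prediction, target))  # what mean_squared_loss computes
mse = np.mean((prediction - target) ** 2)       # the plain mean squared error
print(loss, mse)  # 0.625 1.25
```

The factor of \(\tfrac{1}{2}\) rescales the gradients but changes neither their direction nor the minimizer. It does mean that the training and validation losses printed during training are half of a mean squared error computed directly as jnp.mean((prediction-target)**2).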

Now that we have all ingredients, let’s write some boilerplate for training models.

[5]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update'))
def train_step(model_apply, optimizer_update, batch, opt_state, params):
  def loss_fn(params):
    inertia_tensor = model_apply(params, batch['masses'], batch['positions'])
    loss = mean_squared_loss(inertia_tensor, batch['inertia_tensor'])
    return loss
  loss, grad = jax.value_and_grad(loss_fn)(params)
  updates, opt_state = optimizer_update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

@functools.partial(jax.jit, static_argnames=('model_apply',))
def eval_step(model_apply, batch, params):
  inertia_tensor = model_apply(params, batch['masses'], batch['positions'])
  loss = mean_squared_loss(inertia_tensor, batch['inertia_tensor'])
  return loss

def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size):
  # Initialize model parameters and optimizer state.
  key, init_key = jax.random.split(key)
  optimizer = optax.adam(learning_rate)
  params = model.init(init_key, train_data['masses'][0:1], train_data['positions'][0:1])
  opt_state = optimizer.init(params)

  # Determine the number of training steps per epoch.
  train_size = len(train_data['masses'])
  steps_per_epoch = train_size//batch_size

  # Train for 'num_epochs' epochs.
  for epoch in range(1, num_epochs + 1):
    # Draw random permutations for fetching batches from the train data.
    key, shuffle_key = jax.random.split(key)
    perms = jax.random.permutation(shuffle_key, train_size)
    perms = perms[:steps_per_epoch * batch_size]  # Skip the last batch (if incomplete).
    perms = perms.reshape((steps_per_epoch, batch_size))

    # Loop over all batches.
    train_loss = 0.0  # For keeping a running average of the loss.
    for i, perm in enumerate(perms):
      batch = {k: v[perm, ...] for k, v in train_data.items()}
      params, opt_state, loss = train_step(
          model_apply=model.apply,
          optimizer_update=optimizer.update,
          batch=batch,
          opt_state=opt_state,
          params=params
      )
      train_loss += (loss - train_loss)/(i+1)

    # Evaluate on the validation set after each training epoch.
    valid_loss = eval_step(
        model_apply=model.apply,
        batch=valid_data,
        params=params
    )

    # Print progress.
    print(f"epoch {epoch : 4d} train loss {train_loss : 8.6f} valid loss {valid_loss : 8.6f}")

  # Return final model parameters.
  return params
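
Two small pieces of arithmetic in train_model are easy to gloss over:
- The update train_loss += (loss - train_loss)/(i+1) is an incremental mean. After \(k\) batches, it equals the average of the first \(k\) batch losses without storing them.
- Because incomplete batches are skipped, only \(31\times32=992\) of the 1000 training samples are used per epoch, given the default num_train=1000 of generate_datasets and batch_size=32.

The sketch below checks both, using made-up loss values (the numbers in losses are hypothetical).

```python
import numpy as np

# Hypothetical per-batch losses standing in for the values returned by train_step.
losses = np.array([0.9, 0.5, 0.7, 0.3])

train_loss = 0.0
running = []
for i, loss in enumerate(losses):
    # Incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k.
    train_loss += (loss - train_loss) / (i + 1)
    running.append(train_loss)
print(train_loss, losses.mean())  # both approximately 0.6

# Batching arithmetic with the values used in this tutorial.
train_size, batch_size, num_epochs = 1000, 32, 100
steps_per_epoch = train_size // batch_size            # 31 full batches
skipped = train_size - steps_per_epoch * batch_size   # 8 samples left out per epoch
total_steps = num_epochs * steps_per_epoch            # 3100 optimizer updates in total
print(steps_per_epoch, skipped, total_steps)  # 31 8 3100
```

The permutation is redrawn every epoch, so the samples that fall into the skipped remainder generally change from epoch to epoch.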

Finally, let’s create our toy dataset and choose the training hyperparameters.

[6]:
# Initialize PRNGKey for random number generation.
key = jax.random.PRNGKey(0)

# Generate train and validation datasets.
key, data_key = jax.random.split(key)
train_data, valid_data = generate_datasets(data_key)

# Define training hyperparameters.
learning_rate = 0.002
num_epochs = 100
batch_size = 32

Now, we train our model.

[7]:
# Train the model.
key, train_key = jax.random.split(key)
model = Model()
params = train_model(
  key=train_key,
  model=model,
  train_data=train_data,
  valid_data=valid_data,
  num_epochs=num_epochs,
  learning_rate=learning_rate,
  batch_size=batch_size,
)
epoch    1 train loss  1.359933 valid loss  0.650806
epoch    2 train loss  0.471154 valid loss  0.361696
epoch    3 train loss  0.355795 valid loss  0.330646
epoch    4 train loss  0.335975 valid loss  0.313806
epoch    5 train loss  0.313707 valid loss  0.307905
epoch    6 train loss  0.295819 valid loss  0.261203
epoch    7 train loss  0.269274 valid loss  0.236152
epoch    8 train loss  0.247977 valid loss  0.230414
epoch    9 train loss  0.231734 valid loss  0.205375
epoch   10 train loss  0.225083 valid loss  0.209193
epoch   11 train loss  0.207602 valid loss  0.188981
epoch   12 train loss  0.200761 valid loss  0.185399
epoch   13 train loss  0.190793 valid loss  0.175384
epoch   14 train loss  0.178643 valid loss  0.169232
epoch   15 train loss  0.161587 valid loss  0.144264
epoch   16 train loss  0.147181 valid loss  0.133549
epoch   17 train loss  0.129426 valid loss  0.109603
epoch   18 train loss  0.105608 valid loss  0.088042
epoch   19 train loss  0.089911 valid loss  0.065109
epoch   20 train loss  0.063822 valid loss  0.046581
epoch   21 train loss  0.042836 valid loss  0.039833
epoch   22 train loss  0.045359 valid loss  0.037879
epoch   23 train loss  0.040164 valid loss  0.048566
epoch   24 train loss  0.041613 valid loss  0.038852
epoch   25 train loss  0.037659 valid loss  0.036376
epoch   26 train loss  0.034417 valid loss  0.038344
epoch   27 train loss  0.035188 valid loss  0.030785
epoch   28 train loss  0.033791 valid loss  0.031197
epoch   29 train loss  0.033894 valid loss  0.027737
epoch   30 train loss  0.033661 valid loss  0.030736
epoch   31 train loss  0.031063 valid loss  0.025989
epoch   32 train loss  0.029492 valid loss  0.026486
epoch   33 train loss  0.029381 valid loss  0.024327
epoch   34 train loss  0.028494 valid loss  0.023874
epoch   35 train loss  0.029745 valid loss  0.028929
epoch   36 train loss  0.028127 valid loss  0.023014
epoch   37 train loss  0.026591 valid loss  0.024094
epoch   38 train loss  0.030150 valid loss  0.028875
epoch   39 train loss  0.029794 valid loss  0.023051
epoch   40 train loss  0.031116 valid loss  0.028424
epoch   41 train loss  0.028212 valid loss  0.021998
epoch   42 train loss  0.025479 valid loss  0.021279
epoch   43 train loss  0.025208 valid loss  0.024125
epoch   44 train loss  0.026061 valid loss  0.021150
epoch   45 train loss  0.033841 valid loss  0.057384
epoch   46 train loss  0.027158 valid loss  0.020290
epoch   47 train loss  0.023987 valid loss  0.019953
epoch   48 train loss  0.024759 valid loss  0.024593
epoch   49 train loss  0.025928 valid loss  0.024374
epoch   50 train loss  0.023460 valid loss  0.018815
epoch   51 train loss  0.024572 valid loss  0.019993
epoch   52 train loss  0.022887 valid loss  0.018372
epoch   53 train loss  0.026181 valid loss  0.025382
epoch   54 train loss  0.025671 valid loss  0.021562
epoch   55 train loss  0.023084 valid loss  0.017371
epoch   56 train loss  0.024710 valid loss  0.019425
epoch   57 train loss  0.029084 valid loss  0.029276
epoch   58 train loss  0.022432 valid loss  0.015814
epoch   59 train loss  0.020230 valid loss  0.016195
epoch   60 train loss  0.018062 valid loss  0.014702
epoch   61 train loss  0.016403 valid loss  0.013027
epoch   62 train loss  0.016155 valid loss  0.012115
epoch   63 train loss  0.013847 valid loss  0.010955
epoch   64 train loss  0.014790 valid loss  0.011325
epoch   65 train loss  0.019166 valid loss  0.010374
epoch   66 train loss  0.012386 valid loss  0.007711
epoch   67 train loss  0.009181 valid loss  0.006591
epoch   68 train loss  0.006980 valid loss  0.005574
epoch   69 train loss  0.005331 valid loss  0.004197
epoch   70 train loss  0.006010 valid loss  0.013546
epoch   71 train loss  0.004674 valid loss  0.004104
epoch   72 train loss  0.002438 valid loss  0.001107
epoch   73 train loss  0.001398 valid loss  0.000977
epoch   74 train loss  0.001085 valid loss  0.000696
epoch   75 train loss  0.000974 valid loss  0.000871
epoch   76 train loss  0.000811 valid loss  0.000642
epoch   77 train loss  0.000723 valid loss  0.000854
epoch   78 train loss  0.000840 valid loss  0.001953
epoch   79 train loss  0.000751 valid loss  0.000519
epoch   80 train loss  0.000636 valid loss  0.000593
epoch   81 train loss  0.000609 valid loss  0.000431
epoch   82 train loss  0.000456 valid loss  0.000393
epoch   83 train loss  0.000414 valid loss  0.000432
epoch   84 train loss  0.000414 valid loss  0.000325
epoch   85 train loss  0.000374 valid loss  0.000378
epoch   86 train loss  0.000338 valid loss  0.000319
epoch   87 train loss  0.000391 valid loss  0.000368
epoch   88 train loss  0.000297 valid loss  0.000303
epoch   89 train loss  0.000289 valid loss  0.000255
epoch   90 train loss  0.000279 valid loss  0.000331
epoch   91 train loss  0.000330 valid loss  0.000231
epoch   92 train loss  0.000239 valid loss  0.000275
epoch   93 train loss  0.000244 valid loss  0.000312
epoch   94 train loss  0.000205 valid loss  0.000246
epoch   95 train loss  0.000211 valid loss  0.000307
epoch   96 train loss  0.000207 valid loss  0.000279
epoch   97 train loss  0.000195 valid loss  0.000168
epoch   98 train loss  0.000177 valid loss  0.000205
epoch   99 train loss  0.000346 valid loss  0.000148
epoch  100 train loss  0.000156 valid loss  0.000148

The loss drops by roughly four orders of magnitude over the 100 epochs (the training loss falls from about 1.36 to about 0.00016). With longer training, the loss could likely be pushed even closer to zero, but the current value is low enough for this example. Let’s verify that our model really predicts the correct moment of inertia tensor by evaluating it on an entry from the validation set and comparing the prediction with the true value.

[8]:
i = 0
masses, positions, target = valid_data['masses'][i], valid_data['positions'][i], valid_data['inertia_tensor'][i]
prediction = model.apply(params, masses, positions)

print('target')
print(target)
print('prediction')
print(prediction)
print('mean squared error', jnp.mean((prediction-target)**2))
target
[[ 6.013584    1.6290329  -0.17871115]
 [ 1.6290329   4.8540945   0.73430276]
 [-0.17871115  0.73430276  6.1854286 ]]
prediction
[[ 6.0166373   1.6341103  -0.17790869]
 [ 1.634112    4.849129    0.73786813]
 [-0.1779084   0.73786616  6.182283  ]]
mean squared error 1.3572028e-05

That looks pretty good! But is our model really equivariant? Let’s try to randomly rotate the input positions. The output of our model should rotate accordingly.

[9]:
key, rotation_key = jax.random.split(key)
rotation = e3x.so3.random_rotation(rotation_key)
rotated_positions = positions@rotation
rotated_target = calculate_moment_of_inertia_tensor(masses, rotated_positions)
rotated_prediction = model.apply(params, masses, rotated_positions)

print('rotated target')
print(rotated_target)
print('rotated prediction')
print(rotated_prediction)
print('mean squared error', jnp.mean((rotated_prediction-rotated_target)**2))
rotated target
[[ 6.4829473  1.0275574  1.0426031]
 [ 1.0275574  4.921893  -0.9980936]
 [ 1.0426031 -0.9980936  5.6482677]]
rotated prediction
[[ 6.487672   1.0301298  1.0456917]
 [ 1.0301319  4.914525  -1.0000154]
 [ 1.0456927 -1.0000156  5.645855 ]]
mean squared error 1.3572109e-05

Notice that the individual entries of the moment of inertia tensor have changed quite a bit, but the prediction is still just as good (up to very small differences due to the limited precision of floating-point arithmetic). This is the power of equivariant models!
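
The same test can be made on the target itself. Positions are stored as row vectors and rotated as positions@rotation. The formula then implies that the inertia tensor transforms as \(\mathbf{I} \mapsto \mathbf{R}^\top \mathbf{I}\,\mathbf{R}\), where \(\mathbf{R}\) is the rotation matrix. The entries change, but invariants such as the eigenvalues (the principal moments of inertia) do not.

The NumPy sketch below checks both properties under two assumptions:
- the rotation is generated from a QR decomposition, standing in for e3x.so3.random_rotation;
- inertia_tensor is a NumPy transcription of calculate_moment_of_inertia_tensor.

```python
import numpy as np

def inertia_tensor(masses, positions):
    # NumPy transcription of calculate_moment_of_inertia_tensor.
    diag = np.sum(positions**2, axis=-1)[..., None, None] * np.eye(3)
    outer = positions[..., None, :] * positions[..., :, None]
    return np.sum(masses[..., None, None] * (diag - outer), axis=-3)

rng = np.random.default_rng(42)
masses = rng.uniform(size=10)
positions = rng.normal(size=(10, 3))

# Random proper rotation from a QR decomposition (assumption: stands in for e3x.so3.random_rotation).
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
rotation = Q * np.sign(np.linalg.det(Q))  # ensure det(rotation) = +1

I = inertia_tensor(masses, positions)
I_rot = inertia_tensor(masses, positions @ rotation)  # row vectors: r -> r R, as in the cell above

print(np.allclose(I_rot, rotation.T @ I @ rotation))                  # True
print(np.allclose(np.linalg.eigvalsh(I_rot), np.linalg.eigvalsh(I)))  # True: principal moments unchanged
```

For the model, the analogous check would compare rotated_prediction with rotation.T @ prediction @ rotation using jnp.allclose. A loose tolerance is appropriate, since the model runs in float32.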