Moment of inertia (regression of equivariant properties)
This tutorial demonstrates how to use E3x to construct a simple model for the prediction of equivariant properties. In this toy example, we want to predict the moment of inertia tensor of a collection of point masses. The \(3\times 3\) inertia tensor \(\mathbf{I}\) for a collection of \(N\) point masses with masses \(m_i\) and positions \(\vec{r}_i = [x_i\ y_i\ z_i]\) is given by
\[
\mathbf{I} = \sum_{i=1}^{N} m_i \left( \lVert \vec{r}_i \rVert^2\, \mathbf{E}_3 - \vec{r}_i \otimes \vec{r}_i \right)
\]
with the components
\[
I_{\alpha\beta} = \sum_{i=1}^{N} m_i \left( \lVert \vec{r}_i \rVert^2\, \delta_{\alpha\beta} - \alpha_i \beta_i \right)\,,
\]
where \(\alpha\) and \(\beta\) can be either \(x\), \(y\), or \(z\), \(\delta_{\alpha\beta}\) is \(1\) if \(\alpha = \beta\) and \(0\) otherwise, and \(\mathbf{E}_3\) is the \(3\times 3\) identity matrix.
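For example, for a single point mass \(m\) at position \([x\ y\ z]\), these expressions give
\[
I_{xx} = m\,(y^2 + z^2)\,, \qquad I_{xy} = I_{yx} = -m\,x\,y\,,
\]
and analogously for the remaining components.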
First, all necessary packages are imported.
[1]:
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
# Disable future warnings.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
Next, we define a function that generates a dataset by randomly drawing positions and masses and calculating the corresponding moment of inertia tensor.
[2]:
def calculate_moment_of_inertia_tensor(masses, positions):
  diag = jnp.sum(positions**2, axis=-1)[..., None, None] * jnp.eye(3)
  outer = positions[..., None, :] * positions[..., :, None]
  return jnp.sum(masses[..., None, None] * (diag - outer), axis=-3)

def generate_datasets(key, num_train=1000, num_valid=100, num_points=10, min_mass=0.0, max_mass=1.0, stdev=1.0):
  # Generate random keys.
  train_position_key, train_masses_key, valid_position_key, valid_masses_key = jax.random.split(key, num=4)

  # Draw random point masses with random positions.
  train_positions = stdev * jax.random.normal(train_position_key, shape=(num_train, num_points, 3))
  train_masses = jax.random.uniform(train_masses_key, shape=(num_train, num_points), minval=min_mass, maxval=max_mass)
  valid_positions = stdev * jax.random.normal(valid_position_key, shape=(num_valid, num_points, 3))
  valid_masses = jax.random.uniform(valid_masses_key, shape=(num_valid, num_points), minval=min_mass, maxval=max_mass)

  # Calculate moment of inertia tensors.
  train_inertia_tensor = calculate_moment_of_inertia_tensor(train_masses, train_positions)
  valid_inertia_tensor = calculate_moment_of_inertia_tensor(valid_masses, valid_positions)

  # Return final train and validation datasets.
  train_data = dict(positions=train_positions, masses=train_masses, inertia_tensor=train_inertia_tensor)
  valid_data = dict(positions=valid_positions, masses=valid_masses, inertia_tensor=valid_inertia_tensor)
  return train_data, valid_data
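As a quick sanity check, the shapes of the generated arrays and the symmetry of the tensors could be inspected like this (a minimal sketch; the example_* names are only illustrative):

example_train, example_valid = generate_datasets(jax.random.PRNGKey(0))
print(example_train['positions'].shape)       # (1000, 10, 3)
print(example_train['masses'].shape)          # (1000, 10)
print(example_train['inertia_tensor'].shape)  # (1000, 3, 3)
# The moment of inertia tensor is symmetric by construction.
print(jnp.allclose(example_train['inertia_tensor'],
                   jnp.swapaxes(example_train['inertia_tensor'], -1, -2)))  # True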
We now define an equivariant model to solve this regression task using the modules in E3x. The architecture takes as input masses and positions and outputs a \(3\times3\) matrix, or rather, a second order tensor. It comprises the following steps:
1. Initialize features by concatenating masses and positions and reshaping to match the feature shape conventions used in E3x.
2. Apply the following transformations: First, we project the mass-position features to a features-dimensional feature space using a Dense layer. Next, a TensorDense layer is applied to allow coupling between the irreps \(\mathbb{0}\) (scalars) and \(\mathbb{1}\) (vectors). A second TensorDense layer is applied because, in general, to predict an arbitrary second order tensor, we need (even) irreps \(\mathbb{0}\), \(\mathbb{1}\), and \(\mathbb{2}\) (since \(\mathbb{1} \otimes \mathbb{1} = \mathbb{0}\oplus\mathbb{1}\oplus\mathbb{2}\)). Thus, the features are “elevated” from maximum degree \(1\) (scalars and vectors) to max_degree=2. Further, since we only want to predict a single second order tensor, the layer also maps from the features-dimensional feature space to single output irreps \(\mathbb{0}\), \(\mathbb{1}\), and \(\mathbb{2}\) (features=1). Note: Since the moment of inertia tensor is symmetric, it really only consists of irreps \(\mathbb{0}\) and \(\mathbb{2}\). We could thus zero out the irrep of degree \(1\) to only predict symmetric tensors. However, let’s pretend we do not know this; the model should learn to predict (almost) symmetric tensors anyway.
3. Sum over contributions from individual points.
4. Build the \(3\times3\) tensor from the irreps by applying the Clebsch-Gordan rule backwards: \(\mathbb{0}\oplus\mathbb{1}\oplus\mathbb{2} = \mathbb{1} \otimes \mathbb{1}\).
[3]:
class Model(nn.Module):
  features = 8
  max_degree = 1

  @nn.compact
  def __call__(self, masses, positions):  # Shapes (..., N) and (..., N, 3).
    # 1. Initialize features.
    x = jnp.concatenate((masses[..., None], positions), axis=-1)  # Shape (..., N, 4).
    x = x[..., None, :, None]  # Shape (..., N, 1, 4, 1).

    # 2. Apply transformations.
    x = e3x.nn.Dense(features=self.features)(x)  # Shape (..., N, 1, 4, features).
    x = e3x.nn.TensorDense(max_degree=self.max_degree)(x)  # Shape (..., N, 2, (max_degree+1)**2, features).
    x = e3x.nn.TensorDense(  # Shape (..., N, 2, 9, 1).
        features=1,
        max_degree=2,
    )(x)
    # Try it: Zero out the irrep of degree 1 to only produce symmetric output tensors.
    # x = x.at[..., :, 1:4, :].set(0)

    # 3. Collect even irreps from feature channel 0 and sum over contributions from individual points.
    x = jnp.sum(x[..., 0, :, 0], axis=-2)  # Shape (..., (max_degree+1)**2).

    # 4. Convert output irreps to 3x3 matrix and return.
    cg = e3x.so3.clebsch_gordan(max_degree1=1, max_degree2=1, max_degree3=2)  # Shape (4, 4, 9).
    y = jnp.einsum('...l,nml->...nm', x, cg[1:, 1:, :])  # Shape (..., 3, 3).
    return y
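As a quick check that the shapes work out, one could run the model on dummy inputs, for example (a minimal sketch; the dummy_* names are only illustrative):

dummy_masses = jnp.ones((2, 5))        # A batch of 2 collections with 5 points each.
dummy_positions = jnp.ones((2, 5, 3))
dummy_params = Model().init(jax.random.PRNGKey(0), dummy_masses, dummy_positions)
print(Model().apply(dummy_params, dummy_masses, dummy_positions).shape)  # (2, 3, 3)
# Total number of trainable parameters.
print(sum(p.size for p in jax.tree_util.tree_leaves(dummy_params)))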
Next, we define our loss function. As is common for regression tasks, we choose the \(L_2\) (squared error) loss.
[4]:
def mean_squared_loss(prediction, target):
  return jnp.mean(optax.l2_loss(prediction, target))
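Note that optax.l2_loss computes \(\tfrac{1}{2}(p - t)^2\) elementwise, so the value above is half of the conventional mean squared error. An equivalent formulation without optax would be (the function name below is only illustrative):

def mean_squared_loss_manual(prediction, target):
  return 0.5 * jnp.mean((prediction - target)**2)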
Now that we have all ingredients, let’s write some boilerplate for training models. Note that model.apply and optimizer.update are ordinary Python callables rather than arrays, so they are passed to jax.jit as static arguments (via static_argnames), while the batch, optimizer state, and parameters are traced as usual.
[5]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update'))
def train_step(model_apply, optimizer_update, batch, opt_state, params):
  def loss_fn(params):
    inertia_tensor = model_apply(params, batch['masses'], batch['positions'])
    loss = mean_squared_loss(inertia_tensor, batch['inertia_tensor'])
    return loss
  loss, grad = jax.value_and_grad(loss_fn)(params)
  updates, opt_state = optimizer_update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

@functools.partial(jax.jit, static_argnames=('model_apply',))
def eval_step(model_apply, batch, params):
  inertia_tensor = model_apply(params, batch['masses'], batch['positions'])
  loss = mean_squared_loss(inertia_tensor, batch['inertia_tensor'])
  return loss

def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size):
  # Initialize model parameters and optimizer state.
  key, init_key = jax.random.split(key)
  optimizer = optax.adam(learning_rate)
  params = model.init(init_key, train_data['masses'][0:1], train_data['positions'][0:1])
  opt_state = optimizer.init(params)

  # Determine the number of training steps per epoch.
  train_size = len(train_data['masses'])
  steps_per_epoch = train_size // batch_size

  # Train for 'num_epochs' epochs.
  for epoch in range(1, num_epochs + 1):
    # Draw random permutations for fetching batches from the train data.
    key, shuffle_key = jax.random.split(key)
    perms = jax.random.permutation(shuffle_key, train_size)
    perms = perms[:steps_per_epoch * batch_size]  # Skip the last batch (if incomplete).
    perms = perms.reshape((steps_per_epoch, batch_size))

    # Loop over all batches.
    train_loss = 0.0  # For keeping a running average of the loss.
    for i, perm in enumerate(perms):
      batch = {k: v[perm, ...] for k, v in train_data.items()}
      params, opt_state, loss = train_step(
          model_apply=model.apply,
          optimizer_update=optimizer.update,
          batch=batch,
          opt_state=opt_state,
          params=params,
      )
      train_loss += (loss - train_loss) / (i + 1)

    # Evaluate on the validation set after each training epoch.
    valid_loss = eval_step(
        model_apply=model.apply,
        batch=valid_data,
        params=params,
    )

    # Print progress.
    print(f"epoch {epoch : 4d} train loss {train_loss : 8.6f} valid loss {valid_loss : 8.6f}")

  # Return final model parameters.
  return params
Finally, let’s create our toy dataset and choose the training hyperparameters.
[6]:
# Initialize PRNGKey for random number generation.
key = jax.random.PRNGKey(0)
# Generate train and validation datasets.
key, data_key = jax.random.split(key)
train_data, valid_data = generate_datasets(data_key)
# Define training hyperparameters.
learning_rate = 0.002
num_epochs = 100
batch_size = 32
Now, we train our model.
[7]:
# Train the model.
key, train_key = jax.random.split(key)
model = Model()
params = train_model(
    key=train_key,
    model=model,
    train_data=train_data,
    valid_data=valid_data,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
)
epoch 1 train loss 1.359933 valid loss 0.650806
epoch 2 train loss 0.471154 valid loss 0.361696
epoch 3 train loss 0.355795 valid loss 0.330646
epoch 4 train loss 0.335975 valid loss 0.313806
epoch 5 train loss 0.313707 valid loss 0.307905
epoch 6 train loss 0.295819 valid loss 0.261203
epoch 7 train loss 0.269274 valid loss 0.236152
epoch 8 train loss 0.247977 valid loss 0.230414
epoch 9 train loss 0.231734 valid loss 0.205375
epoch 10 train loss 0.225083 valid loss 0.209193
epoch 11 train loss 0.207602 valid loss 0.188981
epoch 12 train loss 0.200761 valid loss 0.185399
epoch 13 train loss 0.190793 valid loss 0.175384
epoch 14 train loss 0.178643 valid loss 0.169232
epoch 15 train loss 0.161587 valid loss 0.144264
epoch 16 train loss 0.147181 valid loss 0.133549
epoch 17 train loss 0.129426 valid loss 0.109603
epoch 18 train loss 0.105608 valid loss 0.088042
epoch 19 train loss 0.089911 valid loss 0.065109
epoch 20 train loss 0.063822 valid loss 0.046581
epoch 21 train loss 0.042836 valid loss 0.039833
epoch 22 train loss 0.045359 valid loss 0.037879
epoch 23 train loss 0.040164 valid loss 0.048566
epoch 24 train loss 0.041613 valid loss 0.038852
epoch 25 train loss 0.037659 valid loss 0.036376
epoch 26 train loss 0.034417 valid loss 0.038344
epoch 27 train loss 0.035188 valid loss 0.030785
epoch 28 train loss 0.033791 valid loss 0.031197
epoch 29 train loss 0.033894 valid loss 0.027737
epoch 30 train loss 0.033661 valid loss 0.030736
epoch 31 train loss 0.031063 valid loss 0.025989
epoch 32 train loss 0.029492 valid loss 0.026486
epoch 33 train loss 0.029381 valid loss 0.024327
epoch 34 train loss 0.028494 valid loss 0.023874
epoch 35 train loss 0.029745 valid loss 0.028929
epoch 36 train loss 0.028127 valid loss 0.023014
epoch 37 train loss 0.026591 valid loss 0.024094
epoch 38 train loss 0.030150 valid loss 0.028875
epoch 39 train loss 0.029794 valid loss 0.023051
epoch 40 train loss 0.031116 valid loss 0.028424
epoch 41 train loss 0.028212 valid loss 0.021998
epoch 42 train loss 0.025479 valid loss 0.021279
epoch 43 train loss 0.025208 valid loss 0.024125
epoch 44 train loss 0.026061 valid loss 0.021150
epoch 45 train loss 0.033841 valid loss 0.057384
epoch 46 train loss 0.027158 valid loss 0.020290
epoch 47 train loss 0.023987 valid loss 0.019953
epoch 48 train loss 0.024759 valid loss 0.024593
epoch 49 train loss 0.025928 valid loss 0.024374
epoch 50 train loss 0.023460 valid loss 0.018815
epoch 51 train loss 0.024572 valid loss 0.019993
epoch 52 train loss 0.022887 valid loss 0.018372
epoch 53 train loss 0.026181 valid loss 0.025382
epoch 54 train loss 0.025671 valid loss 0.021562
epoch 55 train loss 0.023084 valid loss 0.017371
epoch 56 train loss 0.024710 valid loss 0.019425
epoch 57 train loss 0.029084 valid loss 0.029276
epoch 58 train loss 0.022432 valid loss 0.015814
epoch 59 train loss 0.020230 valid loss 0.016195
epoch 60 train loss 0.018062 valid loss 0.014702
epoch 61 train loss 0.016403 valid loss 0.013027
epoch 62 train loss 0.016155 valid loss 0.012115
epoch 63 train loss 0.013847 valid loss 0.010955
epoch 64 train loss 0.014790 valid loss 0.011325
epoch 65 train loss 0.019166 valid loss 0.010374
epoch 66 train loss 0.012386 valid loss 0.007711
epoch 67 train loss 0.009181 valid loss 0.006591
epoch 68 train loss 0.006980 valid loss 0.005574
epoch 69 train loss 0.005331 valid loss 0.004197
epoch 70 train loss 0.006010 valid loss 0.013546
epoch 71 train loss 0.004674 valid loss 0.004104
epoch 72 train loss 0.002438 valid loss 0.001107
epoch 73 train loss 0.001398 valid loss 0.000977
epoch 74 train loss 0.001085 valid loss 0.000696
epoch 75 train loss 0.000974 valid loss 0.000871
epoch 76 train loss 0.000811 valid loss 0.000642
epoch 77 train loss 0.000723 valid loss 0.000854
epoch 78 train loss 0.000840 valid loss 0.001953
epoch 79 train loss 0.000751 valid loss 0.000519
epoch 80 train loss 0.000636 valid loss 0.000593
epoch 81 train loss 0.000609 valid loss 0.000431
epoch 82 train loss 0.000456 valid loss 0.000393
epoch 83 train loss 0.000414 valid loss 0.000432
epoch 84 train loss 0.000414 valid loss 0.000325
epoch 85 train loss 0.000374 valid loss 0.000378
epoch 86 train loss 0.000338 valid loss 0.000319
epoch 87 train loss 0.000391 valid loss 0.000368
epoch 88 train loss 0.000297 valid loss 0.000303
epoch 89 train loss 0.000289 valid loss 0.000255
epoch 90 train loss 0.000279 valid loss 0.000331
epoch 91 train loss 0.000330 valid loss 0.000231
epoch 92 train loss 0.000239 valid loss 0.000275
epoch 93 train loss 0.000244 valid loss 0.000312
epoch 94 train loss 0.000205 valid loss 0.000246
epoch 95 train loss 0.000211 valid loss 0.000307
epoch 96 train loss 0.000207 valid loss 0.000279
epoch 97 train loss 0.000195 valid loss 0.000168
epoch 98 train loss 0.000177 valid loss 0.000205
epoch 99 train loss 0.000346 valid loss 0.000148
epoch 100 train loss 0.000156 valid loss 0.000148
The loss goes down very quickly. With longer training, it would eventually approach virtually zero, but the current value is low enough for our purposes. Let’s verify that our model really predicts the correct moment of inertia tensor by evaluating it on an entry from the validation set and comparing with the true value.
[8]:
i = 0
masses, positions, target = valid_data['masses'][i], valid_data['positions'][i], valid_data['inertia_tensor'][i]
prediction = model.apply(params, masses, positions)
print('target')
print(target)
print('prediction')
print(prediction)
print('mean squared error', jnp.mean((prediction-target)**2))
target
[[ 6.013584 1.6290329 -0.17871115]
[ 1.6290329 4.8540945 0.73430276]
[-0.17871115 0.73430276 6.1854286 ]]
prediction
[[ 6.0166373 1.6341103 -0.17790869]
[ 1.634112 4.849129 0.73786813]
[-0.1779084 0.73786616 6.182283 ]]
mean squared error 1.3572028e-05
That looks pretty good! But is our model really equivariant? Let’s randomly rotate the input positions and check: the output of our model should rotate accordingly.
[9]:
key, rotation_key = jax.random.split(key)
rotation = e3x.so3.random_rotation(rotation_key)
rotated_positions = positions@rotation
rotated_target = calculate_moment_of_inertia_tensor(masses, rotated_positions)
rotated_prediction = model.apply(params, masses, rotated_positions)
print('rotated target')
print(rotated_target)
print('rotated prediction')
print(rotated_prediction)
print('mean squared error', jnp.mean((rotated_prediction-rotated_target)**2))
rotated target
[[ 6.4829473 1.0275574 1.0426031]
[ 1.0275574 4.921893 -0.9980936]
[ 1.0426031 -0.9980936 5.6482677]]
rotated prediction
[[ 6.487672 1.0301298 1.0456917]
[ 1.0301319 4.914525 -1.0000154]
[ 1.0456927 -1.0000156 5.645855 ]]
mean squared error 1.3572109e-05
Notice that the individual entries of the moment of inertia tensor have changed quite a bit, but the prediction is still just as good (up to very small differences caused by the finite precision of floating point arithmetic). This is the power of equivariant models!
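This can be made explicit: since the positions are rotated as \(\vec{r}_i \mapsto \vec{r}_i\mathbf{R}\), the predicted tensor should transform as \(\mathbf{I} \mapsto \mathbf{R}^\top \mathbf{I}\,\mathbf{R}\). A minimal check with the arrays from above (up to floating point tolerance):

print(jnp.allclose(rotation.T @ prediction @ rotation, rotated_prediction, atol=1e-4))  # Should print True.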