MD17: Ethanol (force field construction)
This tutorial demonstrates how to use E3x to construct a machine-learned force field. For this example, we use the MD17 dataset for ethanol. The code is written to be easy to adapt to datasets that contain multiple molecules of different sizes.
First, all necessary packages are imported.
[1]:
import functools
import os
import urllib.request
import e3x
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
# Disable future warnings.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
Next, we download the MD17 dataset for ethanol.
[2]:
# Download the dataset.
filename = "md17_ethanol.npz"
if not os.path.exists(filename):
print(f"Downloading {filename} (this may take a while)...")
urllib.request.urlretrieve(f"http://www.quantum-machine.org/gdml/data/npz/{filename}", filename)
Now we write a small helper function to prepare train and validation datasets. Apart from choosing subsets of the full dataset and converting the data to a different format, this function also subtracts the mean energy of the training set from all data points. This is a useful numerical trick: The energy of a molecule is often a very large number (around \(-97\,000\) kcal/mol for ethanol), but the relative energy between different conformations of the same molecule is usually much smaller (a few kcal/mol for ethanol). Since models are typically trained with single-precision floats for efficiency, the large energy offset wastes precious numerical precision. For this reason, large offsets should be subtracted from the raw data (typically stored in double precision) before converting to single precision. Note that subtracting a constant from the energy does not change the physics (and of course, the constant can simply be added back to the model predictions after training).
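To see why this matters, here is a small illustration with made-up energy values (only the order of magnitude matches ethanol): a 0.01 kcal/mol difference between two conformations is distorted by float32 rounding when the raw energies are converted directly, but survives if the mean is subtracted first.
[ ]:
# Made-up raw energies in double precision (values are hypothetical).
raw = np.array([-97208.40, -97208.41])
# Converting directly to float32 quantizes the energies to steps of ~0.008.
print(np.float32(raw[1]) - np.float32(raw[0]))  # -0.0078125, not -0.01
# Subtracting the (large) mean first preserves the small relative energy.
offset = raw.mean()
print(np.float32(raw[1] - offset) - np.float32(raw[0] - offset))  # ~-0.01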
[3]:
def prepare_datasets(key, num_train, num_valid):
# Load the dataset.
dataset = np.load(filename)
# Make sure that the dataset contains enough entries.
num_data = len(dataset['E'])
num_draw = num_train + num_valid
if num_draw > num_data:
raise RuntimeError(
      f'dataset only contains {num_data} points, requested num_train={num_train}, num_valid={num_valid}')
# Randomly draw train and validation sets from dataset.
choice = np.asarray(jax.random.choice(key, num_data, shape=(num_draw,), replace=False))
train_choice = choice[:num_train]
valid_choice = choice[num_train:]
# Determine mean energy of the training set.
mean_energy = np.mean(dataset['E'][train_choice]) # ~ -97000
# Collect and return train and validation sets.
train_data = dict(
energy=jnp.asarray(dataset['E'][train_choice, 0] - mean_energy),
forces=jnp.asarray(dataset['F'][train_choice]),
atomic_numbers=jnp.asarray(dataset['z']),
positions=jnp.asarray(dataset['R'][train_choice]),
)
valid_data = dict(
energy=jnp.asarray(dataset['E'][valid_choice, 0] - mean_energy),
forces=jnp.asarray(dataset['F'][valid_choice]),
atomic_numbers=jnp.asarray(dataset['z']),
positions=jnp.asarray(dataset['R'][valid_choice]),
)
return train_data, valid_data, mean_energy
Next, we define a very simple message-passing neural network.
Let’s first describe the model inputs and outputs: Our model takes atomic_numbers of shape (num_atoms,) (encoding the atom types), positions of shape (num_atoms, 3), and the index lists dst_idx (destination indices) and src_idx (source indices), both of shape (num_pairs,), as inputs. The dst_idx and src_idx index lists specify “which atoms talk to each other” during a message-pass. For example, for a molecule with three atoms, these index lists could be dst_idx = [0, 0, 1, 1, 2, 2] and src_idx = [1, 2, 0, 2, 0, 1]. This would mean that the “destination atom” at index 0 receives messages from the “source atoms” at indices 1 and 2, and so on.
This input format allows the model to handle batches of molecules with different numbers of atoms without any padding (in this example, all molecules are ethanol, so this is not strictly necessary, but it is useful for more complicated tasks). We simply concatenate the atomic_numbers and positions of all molecules in the batch (as if they were one big molecule) and specify the index lists such that atoms in different molecules of the batch do not “talk to each other”. The only additional information we need is a structure batch_segments that tells the model which atoms belong to which molecule in the batch, and the total batch_size. For example, a batch consisting of an H\(_2\)O and an N\(_2\) molecule could be specified with atomic_numbers = [1, 1, 8, 7, 7], dst_idx = [0, 0, 1, 1, 2, 2, 3, 4], src_idx = [1, 2, 0, 2, 0, 1, 4, 3], batch_segments = [0, 0, 0, 1, 1] (meaning the first three atoms belong to molecule 0 and the last two to molecule 1), and batch_size = 2 (the sketch below shows how such index lists can be assembled with E3x helpers). The outputs of our model are the energy of shape (batch_size,) of every molecule in the batch and the forces of shape (num_atoms, 3) acting on each atom.
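As an aside, here is a minimal sketch (for the hypothetical H\(_2\)O + N\(_2\) batch above) of how such index lists could be assembled with the e3x.ops.sparse_pairwise_indices helper introduced below; note that the exact pair ordering it returns may differ from the hand-written lists above.
[ ]:
# All pairs within each molecule, then offset the N2 indices by the number of
# preceding atoms (the three H2O atoms) before concatenating.
dst_h2o, src_h2o = e3x.ops.sparse_pairwise_indices(3)
dst_n2, src_n2 = e3x.ops.sparse_pairwise_indices(2)
dst_idx = jnp.concatenate([dst_h2o, dst_n2 + 3])
src_idx = jnp.concatenate([src_h2o, src_n2 + 3])
batch_segments = jnp.asarray([0, 0, 0, 1, 1])
batch_size = 2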
Since forces are the negative gradient of the energy with respect to the atomic positions, the model only needs to predict the energy; the forces can then be derived with automatic differentiation (a toy illustration of this follows the list below). This has the additional advantage that the forces are guaranteed to be conservative, i.e. they respect the physical principle of energy conservation. The energy prediction comprises the following steps:
1. Calculate the displacement vectors \(\vec{r}_{ij}=\vec{r}_{j}-\vec{r}_{i}\), where \(\vec{r}_{i}\) and \(\vec{r}_{j}\) are the positions of atoms \(i\) (destination) and \(j\) (source), for all pairs \(i,j\) specified by the index lists.
2. Expand the displacement vectors in radial-spherical basis functions to featurize them. Note that the basis functions use a cutoff, so they go to zero beyond a certain distance.
3. Embed the atoms in feature space by assigning them learnable embeddings (one for each element).
4. Perform num_iterations feature refinements. Each iteration performs a message-pass and combines the message with the current features of each atom. The intermediate features are passed through an atom-wise two-layer MLP and added to the original features (residual connection).
5. Predict atomic energy contributions by performing linear regression on the final (scalar) features of each atom. Each atom type/element has its own bias term.
6. Sum the atomic energy contributions within each batch segment to obtain the energy of every molecule in the batch.
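Before looking at the model code, here is a toy illustration of the force-via-autodiff trick with a made-up quadratic energy function (not the real model):
[ ]:
# Toy example: for E(r) = sum(r**2), the forces are F = -dE/dr = -2r.
def toy_energy(positions):
  return jnp.sum(positions**2)

toy_positions = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
print(-jax.grad(toy_energy)(toy_positions))  # equals -2 * toy_positions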
[4]:
class MessagePassingModel(nn.Module):
features: int = 32
max_degree: int = 2
num_iterations: int = 3
num_basis_functions: int = 8
cutoff: float = 5.0
max_atomic_number: int = 118 # This is overkill for most applications.
def energy(self, atomic_numbers, positions, dst_idx, src_idx, batch_segments, batch_size):
# 1. Calculate displacement vectors.
positions_dst = e3x.ops.gather_dst(positions, dst_idx=dst_idx)
positions_src = e3x.ops.gather_src(positions, src_idx=src_idx)
displacements = positions_src - positions_dst # Shape (num_pairs, 3).
# 2. Expand displacement vectors in basis functions.
basis = e3x.nn.basis( # Shape (num_pairs, 1, (max_degree+1)**2, num_basis_functions).
displacements,
num=self.num_basis_functions,
max_degree=self.max_degree,
radial_fn=e3x.nn.reciprocal_bernstein,
cutoff_fn=functools.partial(e3x.nn.smooth_cutoff, cutoff=self.cutoff)
)
# 3. Embed atomic numbers in feature space, x has shape (num_atoms, 1, 1, features).
x = e3x.nn.Embed(num_embeddings=self.max_atomic_number+1, features=self.features)(atomic_numbers)
# 4. Perform iterations (message-passing + atom-wise refinement).
for i in range(self.num_iterations):
# Message-pass.
if i == self.num_iterations-1: # Final iteration.
# Since we will only use scalar features after the final message-pass, we do not want to produce non-scalar
# features for efficiency reasons.
y = e3x.nn.MessagePass(max_degree=0, include_pseudotensors=False)(x, basis, dst_idx=dst_idx, src_idx=src_idx)
# After the final message pass, we can safely throw away all non-scalar features.
x = e3x.nn.change_max_degree_or_type(x, max_degree=0, include_pseudotensors=False)
else:
# In intermediate iterations, the message-pass should consider all possible coupling paths.
y = e3x.nn.MessagePass()(x, basis, dst_idx=dst_idx, src_idx=src_idx)
y = e3x.nn.add(x, y)
# Atom-wise refinement MLP.
y = e3x.nn.Dense(self.features)(y)
y = e3x.nn.silu(y)
y = e3x.nn.Dense(self.features, kernel_init=jax.nn.initializers.zeros)(y)
# Residual connection.
x = e3x.nn.add(x, y)
# 5. Predict atomic energies with an ordinary dense layer.
element_bias = self.param('element_bias', lambda rng, shape: jnp.zeros(shape), (self.max_atomic_number+1))
atomic_energies = nn.Dense(1, use_bias=False, kernel_init=jax.nn.initializers.zeros)(x) # (..., Natoms, 1, 1, 1)
atomic_energies = jnp.squeeze(atomic_energies, axis=(-1, -2, -3)) # Squeeze last 3 dimensions.
atomic_energies += element_bias[atomic_numbers]
# 6. Sum atomic energies to obtain the total energy.
energy = jax.ops.segment_sum(atomic_energies, segment_ids=batch_segments, num_segments=batch_size)
# To be able to efficiently compute forces, our model should return a single output (instead of one for each
# molecule in the batch). Fortunately, since all atomic contributions only influence the energy in their own
# batch segment, we can simply sum the energy of all molecules in the batch to obtain a single proxy output
# to differentiate.
return -jnp.sum(energy), energy # Forces are the negative gradient, hence the minus sign.
@nn.compact
def __call__(self, atomic_numbers, positions, dst_idx, src_idx, batch_segments=None, batch_size=None):
if batch_segments is None:
batch_segments = jnp.zeros_like(atomic_numbers)
batch_size = 1
# Since we want to also predict forces, i.e. the gradient of the energy w.r.t. positions (argument 1), we use
# jax.value_and_grad to create a function for predicting both energy and forces for us.
energy_and_forces = jax.value_and_grad(self.energy, argnums=1, has_aux=True)
(_, energy), forces = energy_and_forces(atomic_numbers, positions, dst_idx, src_idx, batch_segments, batch_size)
return energy, forces
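As a quick sanity check, we can initialize the model for a single hypothetical 3-atom molecule and verify the output shapes (a sketch; training with proper data follows below):
[ ]:
model = MessagePassingModel()
key = jax.random.PRNGKey(0)
atomic_numbers = jnp.array([8, 1, 1])  # a water-like toy molecule
positions = jax.random.normal(key, (3, 3))
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(3)
params = model.init(key, atomic_numbers, positions, dst_idx, src_idx)
energy, forces = model.apply(params, atomic_numbers, positions, dst_idx, src_idx)
print(energy.shape, forces.shape)  # (1,) and (3, 3)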
With our model in place, we now write a function to prepare a batch of molecules in the format described above. For this simple example, the source and destination index lists include all \(N(N-1)\) pairwise combinations of \(N\) atoms (without self-interactions). E3x contains helper functions, such as sparse_pairwise_indices, to construct such \(O(N^2)\) index lists. This is fine for a small molecule like ethanol, but we can do better than \(O(N^2)\) scaling, which becomes important for larger molecules. Recall that our message-passing model uses a cutoff when expanding displacement vectors in basis functions. While it doesn’t hurt to include interactions beyond the cutoff in the index lists (they simply contribute nothing to the message received by the destination atom), it is not very efficient to do so. In a real application, it would be better to construct the index lists with a spatial partitioning method, such that only interactions within the cutoff distance are included (see the sketch below). With this trick, evaluating the message-passing model scales as \(O(NM)\), where \(M \ll N\) is the average number of atoms within the cutoff distance.
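To illustrate the idea (but not the scaling benefit), index lists restricted to the cutoff could be built by brute force as follows; a production implementation would use cell lists or similar spatial partitioning instead of first constructing the full \(O(N^2)\) lists. The helper name is our own, not part of E3x.
[ ]:
def within_cutoff_indices(positions, cutoff):
  # Naive sketch: compute all pairwise distances and keep pairs within the cutoff.
  dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(len(positions))
  dst_idx, src_idx = np.asarray(dst_idx), np.asarray(src_idx)
  positions = np.asarray(positions)
  distances = np.linalg.norm(positions[dst_idx] - positions[src_idx], axis=-1)
  keep = distances < cutoff
  return dst_idx[keep], src_idx[keep]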
[5]:
def prepare_batches(key, data, batch_size):
# Determine the number of training steps per epoch.
data_size = len(data['energy'])
steps_per_epoch = data_size//batch_size
# Draw random permutations for fetching batches from the train data.
perms = jax.random.permutation(key, data_size)
perms = perms[:steps_per_epoch * batch_size] # Skip the last batch (if incomplete).
perms = perms.reshape((steps_per_epoch, batch_size))
# Prepare entries that are identical for each batch.
num_atoms = len(data['atomic_numbers'])
batch_segments = jnp.repeat(jnp.arange(batch_size), num_atoms)
atomic_numbers = jnp.tile(data['atomic_numbers'], batch_size)
offsets = jnp.arange(batch_size) * num_atoms
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(num_atoms)
dst_idx = (dst_idx + offsets[:, None]).reshape(-1)
src_idx = (src_idx + offsets[:, None]).reshape(-1)
# Assemble and return batches.
return [
dict(
energy=data['energy'][perm],
forces=data['forces'][perm].reshape(-1, 3),
atomic_numbers=atomic_numbers,
positions=data['positions'][perm].reshape(-1, 3),
dst_idx=dst_idx,
src_idx=src_idx,
batch_segments = batch_segments,
)
for perm in perms
]
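For ethanol (9 atoms), a batch of 10 molecules is thus stored as one “big molecule” of 90 atoms. A quick way to convince yourself (a sketch, assuming train_data has already been prepared as shown further below):
[ ]:
batches = prepare_batches(jax.random.PRNGKey(0), train_data, batch_size=10)
print(batches[0]['positions'].shape)      # (90, 3)
print(batches[0]['energy'].shape)         # (10,)
print(batches[0]['batch_segments'][:12])  # [0 0 0 0 0 0 0 0 0 1 1 1]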
Next, we define our loss function. As is common for regression tasks, we choose the \(L_2\) (squared error) loss. Since we want to fit two different quantities at once (energies and forces), we simply add their respective losses, with a weighting factor that allows us to tune their relative importance. For convenience, we also define a function that computes the mean absolute error, which we use to keep track of model performance during training.
[6]:
def mean_squared_loss(energy_prediction, energy_target, forces_prediction, forces_target, forces_weight):
energy_loss = jnp.mean(optax.l2_loss(energy_prediction, energy_target))
forces_loss = jnp.mean(optax.l2_loss(forces_prediction, forces_target))
return energy_loss + forces_weight * forces_loss
def mean_absolute_error(prediction, target):
return jnp.mean(jnp.abs(prediction - target))
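Note that optax.l2_loss includes a factor of 0.5, i.e. it computes 0.5 * (prediction - target)**2, so the loss above is a weighted sum of half mean squared errors; the constant factor does not matter for optimization. A one-line check:
[ ]:
print(optax.l2_loss(jnp.asarray(3.0), jnp.asarray(1.0)))  # 0.5 * (3 - 1)**2 = 2.0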
Now that we have all the ingredients, we write some boilerplate for training: a jitted train step, a jitted evaluation step, and the training loop itself.
[7]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update', 'batch_size'))
def train_step(model_apply, optimizer_update, batch, batch_size, forces_weight, opt_state, params):
def loss_fn(params):
energy, forces = model_apply(
params,
atomic_numbers=batch['atomic_numbers'],
positions=batch['positions'],
dst_idx=batch['dst_idx'],
src_idx=batch['src_idx'],
batch_segments=batch['batch_segments'],
batch_size=batch_size
)
loss = mean_squared_loss(
energy_prediction=energy,
energy_target=batch['energy'],
forces_prediction=forces,
forces_target=batch['forces'],
forces_weight=forces_weight
)
return loss, (energy, forces)
(loss, (energy, forces)), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
updates, opt_state = optimizer_update(grad, opt_state, params)
params = optax.apply_updates(params, updates)
energy_mae = mean_absolute_error(energy, batch['energy'])
forces_mae = mean_absolute_error(forces, batch['forces'])
return params, opt_state, loss, energy_mae, forces_mae
@functools.partial(jax.jit, static_argnames=('model_apply', 'batch_size'))
def eval_step(model_apply, batch, batch_size, forces_weight, params):
energy, forces = model_apply(
params,
atomic_numbers=batch['atomic_numbers'],
positions=batch['positions'],
dst_idx=batch['dst_idx'],
src_idx=batch['src_idx'],
batch_segments=batch['batch_segments'],
batch_size=batch_size
)
loss = mean_squared_loss(
energy_prediction=energy,
energy_target=batch['energy'],
forces_prediction=forces,
forces_target=batch['forces'],
forces_weight=forces_weight
)
energy_mae = mean_absolute_error(energy, batch['energy'])
forces_mae = mean_absolute_error(forces, batch['forces'])
return loss, energy_mae, forces_mae
def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, forces_weight, batch_size):
# Initialize model parameters and optimizer state.
key, init_key = jax.random.split(key)
optimizer = optax.adam(learning_rate)
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(len(train_data['atomic_numbers']))
params = model.init(init_key,
atomic_numbers=train_data['atomic_numbers'],
positions=train_data['positions'][0],
dst_idx=dst_idx,
src_idx=src_idx,
)
opt_state = optimizer.init(params)
# Batches for the validation set need to be prepared only once.
key, shuffle_key = jax.random.split(key)
valid_batches = prepare_batches(shuffle_key, valid_data, batch_size)
# Train for 'num_epochs' epochs.
for epoch in range(1, num_epochs + 1):
# Prepare batches.
key, shuffle_key = jax.random.split(key)
train_batches = prepare_batches(shuffle_key, train_data, batch_size)
# Loop over train batches.
train_loss = 0.0
train_energy_mae = 0.0
train_forces_mae = 0.0
for i, batch in enumerate(train_batches):
params, opt_state, loss, energy_mae, forces_mae = train_step(
model_apply=model.apply,
optimizer_update=optimizer.update,
batch=batch,
batch_size=batch_size,
forces_weight=forces_weight,
opt_state=opt_state,
params=params
)
train_loss += (loss - train_loss)/(i+1)
train_energy_mae += (energy_mae - train_energy_mae)/(i+1)
train_forces_mae += (forces_mae - train_forces_mae)/(i+1)
# Evaluate on validation set.
valid_loss = 0.0
valid_energy_mae = 0.0
valid_forces_mae = 0.0
for i, batch in enumerate(valid_batches):
loss, energy_mae, forces_mae = eval_step(
model_apply=model.apply,
batch=batch,
batch_size=batch_size,
forces_weight=forces_weight,
params=params
)
valid_loss += (loss - valid_loss)/(i+1)
valid_energy_mae += (energy_mae - valid_energy_mae)/(i+1)
valid_forces_mae += (forces_mae - valid_forces_mae)/(i+1)
# Print progress.
print(f"epoch: {epoch: 3d} train: valid:")
print(f" loss [a.u.] {train_loss : 8.3f} {valid_loss : 8.3f}")
print(f" energy mae [kcal/mol] {train_energy_mae: 8.3f} {valid_energy_mae: 8.3f}")
print(f" forces mae [kcal/mol/Å] {train_forces_mae: 8.3f} {valid_forces_mae: 8.3f}")
# Return final model parameters.
return params
The last step before training the model is to define the hyperparameters.
[8]:
# Model hyperparameters.
features = 32
max_degree = 2
num_iterations = 3
num_basis_functions = 16
cutoff = 5.0
# Training hyperparameters.
num_train = 900
num_valid = 100
num_epochs = 100
learning_rate = 0.01
forces_weight = 1.0
batch_size = 10
Finally, we can train our model.
[9]:
# Create PRNGKeys.
data_key, train_key = jax.random.split(jax.random.PRNGKey(0), 2)
# Draw training and validation sets.
train_data, valid_data, _ = prepare_datasets(data_key, num_train=num_train, num_valid=num_valid)
# Create and train model.
message_passing_model = MessagePassingModel(
features=features,
max_degree=max_degree,
num_iterations=num_iterations,
num_basis_functions=num_basis_functions,
cutoff=cutoff,
)
params = train_model(
key=train_key,
model=message_passing_model,
train_data=train_data,
valid_data=valid_data,
num_epochs=num_epochs,
learning_rate=learning_rate,
forces_weight=forces_weight,
batch_size=batch_size,
)
epoch: 1 train: valid:
loss [a.u.] 354.857 356.750
energy mae [kcal/mol] 3.333 3.731
forces mae [kcal/mol/Å] 19.506 19.367
epoch: 2 train: valid:
loss [a.u.] 299.191 193.864
energy mae [kcal/mol] 3.777 2.866
forces mae [kcal/mol/Å] 17.600 14.072
epoch: 3 train: valid:
loss [a.u.] 145.379 101.722
energy mae [kcal/mol] 4.058 5.459
forces mae [kcal/mol/Å] 11.632 9.264
epoch: 4 train: valid:
loss [a.u.] 92.005 58.962
energy mae [kcal/mol] 3.685 2.793
forces mae [kcal/mol/Å] 9.238 7.459
epoch: 5 train: valid:
loss [a.u.] 65.880 47.165
energy mae [kcal/mol] 3.112 1.574
forces mae [kcal/mol/Å] 7.773 6.711
epoch: 6 train: valid:
loss [a.u.] 47.643 38.291
energy mae [kcal/mol] 3.033 1.962
forces mae [kcal/mol/Å] 6.433 6.058
epoch: 7 train: valid:
loss [a.u.] 34.482 24.528
energy mae [kcal/mol] 2.234 1.509
forces mae [kcal/mol/Å] 5.536 4.864
epoch: 8 train: valid:
loss [a.u.] 28.742 25.687
energy mae [kcal/mol] 2.656 3.414
forces mae [kcal/mol/Å] 4.924 4.488
epoch: 9 train: valid:
loss [a.u.] 25.716 18.172
energy mae [kcal/mol] 2.366 1.424
forces mae [kcal/mol/Å] 4.737 4.169
epoch: 10 train: valid:
loss [a.u.] 21.506 19.457
energy mae [kcal/mol] 1.999 1.346
forces mae [kcal/mol/Å] 4.391 4.381
epoch: 11 train: valid:
loss [a.u.] 18.409 15.401
energy mae [kcal/mol] 1.890 1.706
forces mae [kcal/mol/Å] 4.072 3.850
epoch: 12 train: valid:
loss [a.u.] 16.933 13.269
energy mae [kcal/mol] 1.873 1.201
forces mae [kcal/mol/Å] 3.894 3.652
epoch: 13 train: valid:
loss [a.u.] 14.334 11.579
energy mae [kcal/mol] 1.484 1.516
forces mae [kcal/mol/Å] 3.685 3.294
epoch: 14 train: valid:
loss [a.u.] 13.378 22.506
energy mae [kcal/mol] 1.676 4.767
forces mae [kcal/mol/Å] 3.475 3.423
epoch: 15 train: valid:
loss [a.u.] 13.606 14.388
energy mae [kcal/mol] 1.731 1.928
forces mae [kcal/mol/Å] 3.520 3.769
epoch: 16 train: valid:
loss [a.u.] 12.465 9.474
energy mae [kcal/mol] 1.499 1.412
forces mae [kcal/mol/Å] 3.372 2.958
epoch: 17 train: valid:
loss [a.u.] 10.487 7.631
energy mae [kcal/mol] 1.457 1.041
forces mae [kcal/mol/Å] 3.085 2.664
epoch: 18 train: valid:
loss [a.u.] 11.442 8.065
energy mae [kcal/mol] 1.522 0.983
forces mae [kcal/mol/Å] 3.234 2.783
epoch: 19 train: valid:
loss [a.u.] 9.678 11.261
energy mae [kcal/mol] 1.424 2.474
forces mae [kcal/mol/Å] 2.981 2.852
epoch: 20 train: valid:
loss [a.u.] 8.938 7.979
energy mae [kcal/mol] 1.549 0.960
forces mae [kcal/mol/Å] 2.775 2.928
epoch: 21 train: valid:
loss [a.u.] 8.589 6.849
energy mae [kcal/mol] 1.344 1.351
forces mae [kcal/mol/Å] 2.770 2.497
epoch: 22 train: valid:
loss [a.u.] 6.768 14.650
energy mae [kcal/mol] 1.230 4.167
forces mae [kcal/mol/Å] 2.466 2.446
epoch: 23 train: valid:
loss [a.u.] 7.701 6.763
energy mae [kcal/mol] 1.289 1.685
forces mae [kcal/mol/Å] 2.625 2.279
epoch: 24 train: valid:
loss [a.u.] 6.788 6.545
energy mae [kcal/mol] 1.036 0.752
forces mae [kcal/mol/Å] 2.522 2.520
epoch: 25 train: valid:
loss [a.u.] 6.357 6.105
energy mae [kcal/mol] 1.389 1.183
forces mae [kcal/mol/Å] 2.313 2.469
epoch: 26 train: valid:
loss [a.u.] 6.928 5.101
energy mae [kcal/mol] 1.424 1.209
forces mae [kcal/mol/Å] 2.403 2.120
epoch: 27 train: valid:
loss [a.u.] 5.723 4.971
energy mae [kcal/mol] 1.118 0.576
forces mae [kcal/mol/Å] 2.263 2.182
epoch: 28 train: valid:
loss [a.u.] 5.307 6.905
energy mae [kcal/mol] 0.959 0.617
forces mae [kcal/mol/Å] 2.226 2.763
epoch: 29 train: valid:
loss [a.u.] 4.863 6.667
energy mae [kcal/mol] 1.120 2.057
forces mae [kcal/mol/Å] 2.064 2.120
epoch: 30 train: valid:
loss [a.u.] 5.157 5.295
energy mae [kcal/mol] 1.157 1.480
forces mae [kcal/mol/Å] 2.115 2.083
epoch: 31 train: valid:
loss [a.u.] 5.064 5.305
energy mae [kcal/mol] 1.061 1.803
forces mae [kcal/mol/Å] 2.135 1.934
epoch: 32 train: valid:
loss [a.u.] 5.590 3.384
energy mae [kcal/mol] 1.073 0.494
forces mae [kcal/mol/Å] 2.242 1.830
epoch: 33 train: valid:
loss [a.u.] 3.675 4.018
energy mae [kcal/mol] 0.891 0.784
forces mae [kcal/mol/Å] 1.810 1.983
epoch: 34 train: valid:
loss [a.u.] 4.757 3.721
energy mae [kcal/mol] 1.216 0.734
forces mae [kcal/mol/Å] 1.983 1.847
epoch: 35 train: valid:
loss [a.u.] 4.804 3.326
energy mae [kcal/mol] 1.054 0.506
forces mae [kcal/mol/Å] 2.058 1.835
epoch: 36 train: valid:
loss [a.u.] 3.732 3.326
energy mae [kcal/mol] 0.759 0.479
forces mae [kcal/mol/Å] 1.872 1.888
epoch: 37 train: valid:
loss [a.u.] 4.103 3.825
energy mae [kcal/mol] 0.953 1.059
forces mae [kcal/mol/Å] 1.911 1.927
epoch: 38 train: valid:
loss [a.u.] 4.199 5.174
energy mae [kcal/mol] 1.206 0.752
forces mae [kcal/mol/Å] 1.844 2.273
epoch: 39 train: valid:
loss [a.u.] 3.006 2.711
energy mae [kcal/mol] 0.632 0.444
forces mae [kcal/mol/Å] 1.703 1.620
epoch: 40 train: valid:
loss [a.u.] 3.244 4.489
energy mae [kcal/mol] 0.876 1.486
forces mae [kcal/mol/Å] 1.711 1.855
epoch: 41 train: valid:
loss [a.u.] 3.123 7.015
energy mae [kcal/mol] 0.890 2.608
forces mae [kcal/mol/Å] 1.650 1.996
epoch: 42 train: valid:
loss [a.u.] 4.496 4.365
energy mae [kcal/mol] 1.126 1.243
forces mae [kcal/mol/Å] 1.974 1.948
epoch: 43 train: valid:
loss [a.u.] 2.913 2.353
energy mae [kcal/mol] 0.679 0.432
forces mae [kcal/mol/Å] 1.657 1.536
epoch: 44 train: valid:
loss [a.u.] 2.285 2.577
energy mae [kcal/mol] 0.591 0.423
forces mae [kcal/mol/Å] 1.478 1.625
epoch: 45 train: valid:
loss [a.u.] 2.743 2.041
energy mae [kcal/mol] 0.779 0.511
forces mae [kcal/mol/Å] 1.566 1.398
epoch: 46 train: valid:
loss [a.u.] 2.836 2.580
energy mae [kcal/mol] 0.570 0.662
forces mae [kcal/mol/Å] 1.663 1.510
epoch: 47 train: valid:
loss [a.u.] 2.846 3.848
energy mae [kcal/mol] 0.787 1.620
forces mae [kcal/mol/Å] 1.603 1.650
epoch: 48 train: valid:
loss [a.u.] 2.563 1.875
energy mae [kcal/mol] 0.688 0.410
forces mae [kcal/mol/Å] 1.540 1.368
epoch: 49 train: valid:
loss [a.u.] 2.900 2.622
energy mae [kcal/mol] 0.823 0.975
forces mae [kcal/mol/Å] 1.591 1.541
epoch: 50 train: valid:
loss [a.u.] 3.487 2.985
energy mae [kcal/mol] 0.900 0.484
forces mae [kcal/mol/Å] 1.725 1.705
epoch: 51 train: valid:
loss [a.u.] 3.613 4.224
energy mae [kcal/mol] 1.081 0.633
forces mae [kcal/mol/Å] 1.714 2.192
epoch: 52 train: valid:
loss [a.u.] 2.610 2.156
energy mae [kcal/mol] 0.686 0.734
forces mae [kcal/mol/Å] 1.544 1.361
epoch: 53 train: valid:
loss [a.u.] 2.147 1.746
energy mae [kcal/mol] 0.652 0.349
forces mae [kcal/mol/Å] 1.402 1.291
epoch: 54 train: valid:
loss [a.u.] 3.319 4.700
energy mae [kcal/mol] 0.889 0.662
forces mae [kcal/mol/Å] 1.708 2.355
epoch: 55 train: valid:
loss [a.u.] 3.721 2.377
energy mae [kcal/mol] 0.980 1.106
forces mae [kcal/mol/Å] 1.801 1.362
epoch: 56 train: valid:
loss [a.u.] 2.098 2.240
energy mae [kcal/mol] 0.688 0.606
forces mae [kcal/mol/Å] 1.377 1.456
epoch: 57 train: valid:
loss [a.u.] 1.907 1.330
energy mae [kcal/mol] 0.534 0.598
forces mae [kcal/mol/Å] 1.353 1.077
epoch: 58 train: valid:
loss [a.u.] 1.924 1.640
energy mae [kcal/mol] 0.637 0.311
forces mae [kcal/mol/Å] 1.300 1.276
epoch: 59 train: valid:
loss [a.u.] 2.122 2.328
energy mae [kcal/mol] 0.797 0.984
forces mae [kcal/mol/Å] 1.347 1.337
epoch: 60 train: valid:
loss [a.u.] 1.857 1.852
energy mae [kcal/mol] 0.541 0.712
forces mae [kcal/mol/Å] 1.325 1.254
epoch: 61 train: valid:
loss [a.u.] 2.010 2.182
energy mae [kcal/mol] 0.580 0.303
forces mae [kcal/mol/Å] 1.365 1.520
epoch: 62 train: valid:
loss [a.u.] 2.462 2.200
energy mae [kcal/mol] 0.741 0.444
forces mae [kcal/mol/Å] 1.465 1.486
epoch: 63 train: valid:
loss [a.u.] 1.674 1.520
energy mae [kcal/mol] 0.493 0.508
forces mae [kcal/mol/Å] 1.268 1.236
epoch: 64 train: valid:
loss [a.u.] 4.390 4.820
energy mae [kcal/mol] 0.742 1.436
forces mae [kcal/mol/Å] 1.979 1.987
epoch: 65 train: valid:
loss [a.u.] 3.699 2.103
energy mae [kcal/mol] 0.808 0.496
forces mae [kcal/mol/Å] 1.802 1.486
epoch: 66 train: valid:
loss [a.u.] 1.601 1.467
energy mae [kcal/mol] 0.557 0.425
forces mae [kcal/mol/Å] 1.220 1.191
epoch: 67 train: valid:
loss [a.u.] 1.972 1.711
energy mae [kcal/mol] 0.601 0.523
forces mae [kcal/mol/Å] 1.355 1.237
epoch: 68 train: valid:
loss [a.u.] 1.483 1.373
energy mae [kcal/mol] 0.514 0.827
forces mae [kcal/mol/Å] 1.187 1.041
epoch: 69 train: valid:
loss [a.u.] 1.825 2.062
energy mae [kcal/mol] 0.640 0.632
forces mae [kcal/mol/Å] 1.271 1.388
epoch: 70 train: valid:
loss [a.u.] 1.973 1.732
energy mae [kcal/mol] 0.617 0.392
forces mae [kcal/mol/Å] 1.332 1.283
epoch: 71 train: valid:
loss [a.u.] 1.816 3.849
energy mae [kcal/mol] 0.530 1.314
forces mae [kcal/mol/Å] 1.309 1.911
epoch: 72 train: valid:
loss [a.u.] 1.774 2.064
energy mae [kcal/mol] 0.585 0.539
forces mae [kcal/mol/Å] 1.274 1.477
epoch: 73 train: valid:
loss [a.u.] 2.003 1.536
energy mae [kcal/mol] 0.611 0.576
forces mae [kcal/mol/Å] 1.362 1.175
epoch: 74 train: valid:
loss [a.u.] 1.913 2.213
energy mae [kcal/mol] 0.644 0.961
forces mae [kcal/mol/Å] 1.312 1.252
epoch: 75 train: valid:
loss [a.u.] 1.862 2.396
energy mae [kcal/mol] 0.643 0.364
forces mae [kcal/mol/Å] 1.277 1.615
epoch: 76 train: valid:
loss [a.u.] 1.563 1.570
energy mae [kcal/mol] 0.413 0.556
forces mae [kcal/mol/Å] 1.239 1.168
epoch: 77 train: valid:
loss [a.u.] 1.869 2.005
energy mae [kcal/mol] 0.769 0.539
forces mae [kcal/mol/Å] 1.247 1.451
epoch: 78 train: valid:
loss [a.u.] 1.565 1.496
energy mae [kcal/mol] 0.694 0.311
forces mae [kcal/mol/Å] 1.154 1.227
epoch: 79 train: valid:
loss [a.u.] 1.476 1.275
energy mae [kcal/mol] 0.551 0.611
forces mae [kcal/mol/Å] 1.162 1.022
epoch: 80 train: valid:
loss [a.u.] 3.263 4.741
energy mae [kcal/mol] 0.807 0.839
forces mae [kcal/mol/Å] 1.653 2.087
epoch: 81 train: valid:
loss [a.u.] 2.422 1.622
energy mae [kcal/mol] 0.643 0.434
forces mae [kcal/mol/Å] 1.491 1.230
epoch: 82 train: valid:
loss [a.u.] 1.173 1.157
energy mae [kcal/mol] 0.443 0.438
forces mae [kcal/mol/Å] 1.043 1.037
epoch: 83 train: valid:
loss [a.u.] 2.153 1.241
energy mae [kcal/mol] 0.745 0.636
forces mae [kcal/mol/Å] 1.363 1.026
epoch: 84 train: valid:
loss [a.u.] 2.218 2.663
energy mae [kcal/mol] 0.616 0.495
forces mae [kcal/mol/Å] 1.437 1.517
epoch: 85 train: valid:
loss [a.u.] 1.928 1.388
energy mae [kcal/mol] 0.558 0.297
forces mae [kcal/mol/Å] 1.347 1.223
epoch: 86 train: valid:
loss [a.u.] 1.785 1.662
energy mae [kcal/mol] 0.759 0.458
forces mae [kcal/mol/Å] 1.183 1.247
epoch: 87 train: valid:
loss [a.u.] 1.565 1.159
energy mae [kcal/mol] 0.587 0.661
forces mae [kcal/mol/Å] 1.195 0.979
epoch: 88 train: valid:
loss [a.u.] 1.192 1.318
energy mae [kcal/mol] 0.526 0.453
forces mae [kcal/mol/Å] 1.040 1.105
epoch: 89 train: valid:
loss [a.u.] 1.398 2.636
energy mae [kcal/mol] 0.494 0.424
forces mae [kcal/mol/Å] 1.141 1.771
epoch: 90 train: valid:
loss [a.u.] 1.490 1.044
energy mae [kcal/mol] 0.487 0.582
forces mae [kcal/mol/Å] 1.188 0.941
epoch: 91 train: valid:
loss [a.u.] 1.229 2.328
energy mae [kcal/mol] 0.515 0.928
forces mae [kcal/mol/Å] 1.047 1.350
epoch: 92 train: valid:
loss [a.u.] 1.562 0.969
energy mae [kcal/mol] 0.464 0.452
forces mae [kcal/mol/Å] 1.218 0.941
epoch: 93 train: valid:
loss [a.u.] 1.891 1.960
energy mae [kcal/mol] 0.673 0.906
forces mae [kcal/mol/Å] 1.270 1.267
epoch: 94 train: valid:
loss [a.u.] 1.389 1.734
energy mae [kcal/mol] 0.485 0.517
forces mae [kcal/mol/Å] 1.126 1.242
epoch: 95 train: valid:
loss [a.u.] 2.041 4.921
energy mae [kcal/mol] 0.612 0.908
forces mae [kcal/mol/Å] 1.369 2.325
epoch: 96 train: valid:
loss [a.u.] 2.202 1.152
energy mae [kcal/mol] 0.618 0.320
forces mae [kcal/mol/Å] 1.433 1.067
epoch: 97 train: valid:
loss [a.u.] 2.549 2.517
energy mae [kcal/mol] 0.768 0.582
forces mae [kcal/mol/Å] 1.481 1.556
epoch: 98 train: valid:
loss [a.u.] 2.426 2.014
energy mae [kcal/mol] 0.665 0.424
forces mae [kcal/mol/Å] 1.492 1.364
epoch: 99 train: valid:
loss [a.u.] 1.519 1.761
energy mae [kcal/mol] 0.579 0.347
forces mae [kcal/mol/Å] 1.167 1.301
epoch: 100 train: valid:
loss [a.u.] 1.584 1.085
energy mae [kcal/mol] 0.547 0.286
forces mae [kcal/mol/Å] 1.211 1.066
Even a very simple model can achieve energy prediction errors well below 1 kcal/mol (“chemical accuracy”) after training for a hundred epochs (the model is not converged yet, but it’s good enough for now).
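In practice, it is also a good idea to persist the trained parameters to disk. Here is a minimal sketch using flax.serialization (the file name is made up):
[ ]:
import flax.serialization
# Save the trained parameters.
with open('md17_ethanol_params.msgpack', 'wb') as f:
  f.write(flax.serialization.to_bytes(params))
# Restore them later; the existing 'params' serves as a template for the tree structure.
with open('md17_ethanol_params.msgpack', 'rb') as f:
  params = flax.serialization.from_bytes(params, f.read())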
Once the model is trained, we can use it as a force field, e.g. to optimize structures or to run molecular dynamics simulations. To demonstrate this, we use the Atomic Simulation Environment (ASE) and py3Dmol (for visualization), so a few additional imports are necessary.
[10]:
import io
import ase
import ase.calculators.calculator as ase_calc
import ase.io as ase_io
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation
from ase.md.verlet import VelocityVerlet
import ase.optimize as ase_opt
import matplotlib.pyplot as plt
import py3Dmol
Before we can use our model as a Calculator in ASE, we need to write a small interface. The following implementation is very bare-bones, but it does the trick for a few simple demonstrations.
[11]:
@jax.jit
def evaluate_energies_and_forces(atomic_numbers, positions, dst_idx, src_idx):
return message_passing_model.apply(params,
atomic_numbers=atomic_numbers,
positions=positions,
dst_idx=dst_idx,
src_idx=src_idx,
)
class MessagePassingCalculator(ase_calc.Calculator):
implemented_properties = ["energy", "forces"]
  def calculate(self, atoms, properties, system_changes=ase_calc.all_changes):
ase_calc.Calculator.calculate(self, atoms, properties, system_changes)
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(len(atoms))
energy, forces = evaluate_energies_and_forces(
atomic_numbers=atoms.get_atomic_numbers(),
positions=atoms.get_positions(),
dst_idx=dst_idx,
src_idx=src_idx
)
self.results['energy'] = energy * ase.units.kcal/ase.units.mol
self.results['forces'] = forces * ase.units.kcal/ase.units.mol
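ASE works with eV and Å internally, which is why the predicted energies and forces (in kcal/mol and kcal/mol/Å) are multiplied by the conversion factor above:
[ ]:
print(ase.units.kcal / ase.units.mol)  # 1 kcal/mol expressed in eV, ~0.0434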
First, let’s find an optimized structure for ethanol. We arbitrarily initialize the structure to an entry from our training data and then use ASE’s BFGS optimizer to relax it.
[12]:
# Initialize atoms object and attach calculator.
atoms = ase.Atoms(train_data['atomic_numbers'], train_data['positions'][0])
atoms.calc = MessagePassingCalculator()  # set_calculator is deprecated in recent ASE versions.
# Run structure optimization with BFGS.
_ = ase_opt.BFGS(atoms).run(fmax=0.05)
Step Time Energy fmax
BFGS: 0 08:51:53 0.048474 3.0194
BFGS: 1 08:51:53 -0.251758 1.6836
BFGS: 2 08:51:53 -0.396961 1.4604
BFGS: 3 08:51:53 -0.464969 0.9650
BFGS: 4 08:51:53 -0.499852 0.4686
BFGS: 5 08:51:53 -0.516434 0.3922
BFGS: 6 08:51:53 -0.542414 0.4775
BFGS: 7 08:51:53 -0.552331 0.2690
BFGS: 8 08:51:53 -0.558380 0.2286
BFGS: 9 08:51:53 -0.563047 0.2404
BFGS: 10 08:51:53 -0.567419 0.2482
BFGS: 11 08:51:53 -0.570238 0.1684
BFGS: 12 08:51:53 -0.571843 0.1195
BFGS: 13 08:51:53 -0.573060 0.1111
BFGS: 14 08:51:53 -0.574244 0.1020
BFGS: 15 08:51:53 -0.575149 0.0860
BFGS: 16 08:51:53 -0.575878 0.1010
BFGS: 17 08:51:53 -0.576817 0.1314
BFGS: 18 08:51:53 -0.578436 0.1580
BFGS: 19 08:51:53 -0.580875 0.1718
BFGS: 20 08:51:53 -0.583505 0.1523
BFGS: 21 08:51:53 -0.585652 0.1337
BFGS: 22 08:51:53 -0.587525 0.1372
BFGS: 23 08:51:53 -0.589606 0.1612
BFGS: 24 08:51:53 -0.591745 0.1458
BFGS: 25 08:51:53 -0.593183 0.1071
BFGS: 26 08:51:53 -0.593988 0.0820
BFGS: 27 08:51:53 -0.594515 0.0670
BFGS: 28 08:51:53 -0.594925 0.0593
BFGS: 29 08:51:53 -0.595153 0.0346
A promising start! Note that ASE uses electron volts (eV) as energy units, so the optimized structure is lower in energy than the starting structure by roughly 0.6 eV, or about 14 kcal/mol. Let’s check how the optimized structure looks.
[13]:
# Write structure to xyz file.
xyz = io.StringIO()
ase_io.write(xyz, atoms, format='xyz')
# Visualize the structure with py3Dmol.
view = py3Dmol.view()
view.addModel(xyz.getvalue(), 'xyz')
view.setStyle({'stick': {'radius': 0.15}, 'sphere': {'scale': 0.25}})
view.show()
(py3Dmol displays an interactive 3D view of the optimized structure here.)
Now let’s run a molecular dynamics simulation. We draw initial momenta from a Maxwell-Boltzmann distribution at 300 K, remove center-of-mass translation and overall rotation, and then use velocity Verlet integration to propagate the equations of motion for 2000 steps with a timestep of 0.5 fs. For visualization purposes, we also save each frame and keep track of the potential, kinetic, and total energies.
[14]:
# Parameters.
temperature = 300
timestep_fs = 0.5
num_steps = 2000
# Draw initial momenta.
MaxwellBoltzmannDistribution(atoms, temperature_K=temperature)
Stationary(atoms) # Remove center of mass translation.
ZeroRotation(atoms) # Remove rotations.
# Initialize Velocity Verlet integrator.
integrator = VelocityVerlet(atoms, timestep=timestep_fs*ase.units.fs)
# Run molecular dynamics.
frames = np.zeros((num_steps, len(atoms), 3))
potential_energy = np.zeros((num_steps,))
kinetic_energy = np.zeros((num_steps,))
total_energy = np.zeros((num_steps,))
for i in range(num_steps):
# Run 1 time step.
integrator.run(1)
# Save current frame and keep track of energies.
frames[i] = atoms.get_positions()
potential_energy[i] = atoms.get_potential_energy()
kinetic_energy[i] = atoms.get_kinetic_energy()
total_energy[i] = atoms.get_total_energy()
# Occasionally print progress.
if i % 100 == 0:
print(f"step {i:5d} epot {potential_energy[i]: 5.3f} ekin {kinetic_energy[i]: 5.3f} etot {total_energy[i]: 5.3f}")
step 0 epot -0.588 ekin 0.166 etot -0.422
step 100 epot -0.498 ekin 0.077 etot -0.421
step 200 epot -0.476 ekin 0.055 etot -0.421
step 300 epot -0.514 ekin 0.092 etot -0.422
step 400 epot -0.507 ekin 0.085 etot -0.422
step 500 epot -0.499 ekin 0.078 etot -0.421
step 600 epot -0.502 ekin 0.080 etot -0.422
step 700 epot -0.518 ekin 0.097 etot -0.422
step 800 epot -0.504 ekin 0.083 etot -0.421
step 900 epot -0.523 ekin 0.102 etot -0.422
step 1000 epot -0.498 ekin 0.076 etot -0.422
step 1100 epot -0.500 ekin 0.078 etot -0.422
step 1200 epot -0.528 ekin 0.107 etot -0.422
step 1300 epot -0.487 ekin 0.065 etot -0.422
step 1400 epot -0.497 ekin 0.075 etot -0.422
step 1500 epot -0.527 ekin 0.105 etot -0.422
step 1600 epot -0.523 ekin 0.102 etot -0.422
step 1700 epot -0.514 ekin 0.092 etot -0.422
step 1800 epot -0.495 ekin 0.074 etot -0.422
step 1900 epot -0.501 ekin 0.080 etot -0.421
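Since we imported matplotlib earlier, we can also plot the recorded energies to check that the total energy is approximately conserved over the trajectory:
[ ]:
time_fs = np.arange(num_steps) * timestep_fs
plt.plot(time_fs, potential_energy, label='potential')
plt.plot(time_fs, kinetic_energy, label='kinetic')
plt.plot(time_fs, total_energy, label='total')
plt.xlabel('time [fs]')
plt.ylabel('energy [eV]')
plt.legend()
plt.show()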
Let’s have a look at the trajectory.
[15]:
view.getModel().setCoordinates(frames, 'array')
view.animate({'loop': 'forward', 'interval': 0.1})
view.show()
(py3Dmol displays an interactive animation of the trajectory here.)