
MD17: Ethanol (force field construction)

This tutorial demonstrates how to use E3x to construct a machine-learned force field. For this example, we use the MD17 dataset for ethanol. The code is written so that it is easy to adapt to datasets that contain multiple molecules of different sizes.

First, all necessary packages are imported.

[1]:
import functools
import os
import urllib.request
import e3x
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

# Disable future warnings.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

Next, we download the MD17 dataset for ethanol.

[2]:
# Download the dataset.
filename = "md17_ethanol.npz"
if not os.path.exists(filename):
  print(f"Downloading {filename} (this may take a while)...")
  urllib.request.urlretrieve(f"http://www.quantum-machine.org/gdml/data/npz/{filename}", filename)

Now we write a small helper function to prepare train and validation datasets. Apart from choosing subsets of the full dataset and converting the data to a different format, this function also subtracts the mean energy of the training set from all data points. This is a useful numerical trick: The energy of a molecule is often a very large number (around \(-97\,000\) kcal/mol for ethanol), but the relative energy between different conformations of the same molecule is usually much smaller (a few kcal/mol for ethanol). Since models are typically trained with single-precision floats for efficiency, the large energy offset wastes precious numerical precision. For this reason, large offsets should be subtracted from the raw data (typically stored in double precision) before converting to single precision. Note that subtracting an arbitrary constant from the energy does not change the physics (of course, the constant can also just be added back to the model predictions after training, which is why the function also returns mean_energy).
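To make the precision argument concrete, here is a small NumPy sketch. The two conformation energies and the offset are made-up values, not taken from MD17. It compares the energy gap between two conformations when the raw energies are cast to float32 directly, and when the offset is first subtracted in double precision.

```python
import numpy as np

# Made-up energies (kcal/mol) of two conformations, stored in double precision.
e_a, e_b = -97000.1234, -97001.5678
offset = -97000.0  # Stand-in for the mean energy of the training set.
true_gap = e_a - e_b  # Reference gap computed in double precision.

# Distance between adjacent float32 values at this magnitude.
print(np.spacing(np.float32(abs(e_a))))  # 0.0078125 (= 2**-7 kcal/mol)

# Cast to float32 directly vs. after subtracting the offset in double precision.
gap_raw = float(np.float32(e_a) - np.float32(e_b))
gap_shifted = float(np.float32(e_a - offset) - np.float32(e_b - offset))
error_raw = abs(gap_raw - true_gap)
error_shifted = abs(gap_shifted - true_gap)
print(error_raw, error_shifted)
```

Near \(10^5\), float32 can only resolve steps of about 0.008 kcal/mol, so the directly cast gap is off by roughly \(10^{-3}\) kcal/mol, while the error of the shifted gap is several orders of magnitude smaller.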

[3]:
def prepare_datasets(key, num_train, num_valid):
  # Load the dataset.
  dataset = np.load(filename)

  # Make sure that the dataset contains enough entries.
  num_data = len(dataset['E'])
  num_draw = num_train + num_valid
  if num_draw > num_data:
    raise RuntimeError(
      f'dataset only contains {num_data} points, requested num_train={num_train}, num_valid={num_valid}')

  # Randomly draw train and validation sets from dataset.
  choice = np.asarray(jax.random.choice(key, num_data, shape=(num_draw,), replace=False))
  train_choice = choice[:num_train]
  valid_choice = choice[num_train:]

  # Determine mean energy of the training set.
  mean_energy = np.mean(dataset['E'][train_choice])  # ~ -97000

  # Collect and return train and validation sets.
  train_data = dict(
    energy=jnp.asarray(dataset['E'][train_choice, 0] - mean_energy),
    forces=jnp.asarray(dataset['F'][train_choice]),
    atomic_numbers=jnp.asarray(dataset['z']),
    positions=jnp.asarray(dataset['R'][train_choice]),
  )
  valid_data = dict(
    energy=jnp.asarray(dataset['E'][valid_choice, 0] - mean_energy),
    forces=jnp.asarray(dataset['F'][valid_choice]),
    atomic_numbers=jnp.asarray(dataset['z']),
    positions=jnp.asarray(dataset['R'][valid_choice]),
  )
  return train_data, valid_data, mean_energy

Next, we define a very simple message-passing neural network.

Let’s first describe the model inputs and outputs. Our model takes as inputs atomic_numbers of shape (num_atoms,) (encoding the atom types), positions of shape (num_atoms, 3), and the index lists dst_idx (destination index) and src_idx (source index), both of shape (num_pairs,). The dst_idx and src_idx index lists specify “which atoms talk with each other” during a message-pass. For example, for a molecule with three atoms, these index lists could be dst_idx = [0, 0, 1, 1, 2, 2] and src_idx = [1, 2, 0, 2, 0, 1]. This would mean that the “destination atom” at index 0 receives messages from the “source atoms” at indices 1 and 2, and so on.

This input format allows the model to handle batches of molecules with different numbers of atoms without any padding. (In this example, all molecules are ethanol, so padding would not be an issue anyway, but the format is useful in general for more complicated tasks.) We simply concatenate the atomic_numbers and positions of all molecules in the batch (as if they were one big molecule) and specify the index lists such that atoms in different molecules of the batch do not “talk to each other”. The only additional information we need is an array batch_segments that tells the model which atom belongs to which molecule in the batch, and the total batch_size. For example, a batch consisting of an H\(_2\)O and an N\(_2\) molecule could be specified with atomic_numbers = [1, 1, 8, 7, 7], dst_idx = [0, 0, 1, 1, 2, 2, 3, 4], src_idx = [1, 2, 0, 2, 0, 1, 4, 3], batch_segments = [0, 0, 0, 1, 1] (meaning the first three atoms belong to molecule 0 and the last two to molecule 1), and batch_size = 2.

The outputs of our model are the energy of every molecule in the batch, with shape (batch_size,), and the forces acting on each atom, with shape (num_atoms, 3).
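The batch layout described above can be assembled with a few lines of plain Python and NumPy. The helper batch_molecules below is hypothetical (it is not part of E3x); it builds all \(N(N-1)\) ordered pairs within each molecule and no pairs across molecules, and it reproduces the H\(_2\)O + N\(_2\) example from the text.

```python
import numpy as np


def batch_molecules(atomic_numbers_per_molecule):
  """Hypothetical helper (not part of E3x): concatenate molecules into one batch.

  Returns atomic_numbers, dst_idx, src_idx, batch_segments and batch_size, with index
  lists containing all ordered pairs i != j inside each molecule (none across molecules).
  """
  atomic_numbers, dst_idx, src_idx, batch_segments = [], [], [], []
  offset = 0  # Index of the first atom of the current molecule in the concatenated batch.
  for molecule, z in enumerate(atomic_numbers_per_molecule):
    n = len(z)
    for i in range(n):
      for j in range(n):
        if i != j:  # No self-interactions.
          dst_idx.append(offset + i)
          src_idx.append(offset + j)
    atomic_numbers.extend(z)
    batch_segments.extend([molecule] * n)
    offset += n
  return (np.array(atomic_numbers), np.array(dst_idx), np.array(src_idx),
          np.array(batch_segments), len(atomic_numbers_per_molecule))


# H2O (H, H, O) followed by N2.
atomic_numbers, dst_idx, src_idx, batch_segments, batch_size = batch_molecules([[1, 1, 8], [7, 7]])
print(dst_idx.tolist())  # [0, 0, 1, 1, 2, 2, 3, 4]
print(src_idx.tolist())  # [1, 2, 0, 2, 0, 1, 4, 3]
```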

Since forces are the negative gradient of the energy with respect to the atomic positions, the model only needs to predict the energy; the forces can then be derived with automatic differentiation. This has the additional advantage that the forces are conservative, i.e. they respect the physical principle of energy conservation. The energy prediction comprises the following steps:

  1. Calculate the displacement vectors \(\vec{r}_{ij}=\vec{r}_{j}-\vec{r}_{i}\), where \(\vec{r}_{i}\) and \(\vec{r}_{j}\) are the positions of atoms \(i\) (destination) and \(j\) (source), for all pairs \(i,j\) specified by the index lists.

  2. Expand the displacement vectors in radial-spherical basis functions to featurize them. Note that the basis functions use a cutoff, so that they go to zero beyond a certain distance.

  3. Embed the atoms in feature space by assigning them to learnable embeddings (one for each element).

  4. Perform num_iterations feature refinements. Each iteration performs a message-pass and combines the message with the current features of each atom. These intermediate features are passed through an atom-wise two-layer MLP and added to the original features (residual connection).

  5. Predict atomic energy contributions by performing linear regression on the final (scalar) features of each atom. Each atom type/element has its own bias term.

  6. Sum the atomic energy contributions within each batch segment to obtain the energy for every molecule in the batch.
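Steps 1 and 6 can be sketched with plain jax.numpy indexing and jax.ops.segment_sum, used here in place of the E3x gather helpers. Steps 2 to 5 are replaced by a toy stand-in (a cosine cutoff weight per pair, summed onto the destination atom), so this is not the model's actual featurization. The coordinates and the cutoff of 1.2 Å are arbitrary choices that put the H-H pair outside the cutoff, to show that such pairs contribute nothing.

```python
import jax
import jax.numpy as jnp

# The H2O + N2 batch from the text, with arbitrary toy coordinates (Å).
positions = jnp.array([
    [0.0, 0.0, 0.0],    # H
    [1.5, 0.0, 0.0],    # H
    [0.75, 0.6, 0.0],   # O
    [10.0, 0.0, 0.0],   # N
    [11.1, 0.0, 0.0],   # N
])
dst_idx = jnp.array([0, 0, 1, 1, 2, 2, 3, 4])
src_idx = jnp.array([1, 2, 0, 2, 0, 1, 4, 3])
batch_segments = jnp.array([0, 0, 0, 1, 1])
batch_size = 2
cutoff = 1.2  # Toy value, chosen so that the H-H pair (1.5 Å apart) lies beyond it.

# Step 1: displacement vectors r_ij = r_j - r_i (source minus destination).
displacements = positions[src_idx] - positions[dst_idx]  # Shape (num_pairs, 3).
distances = jnp.linalg.norm(displacements, axis=-1)      # Shape (num_pairs,).

# Toy stand-in for steps 2 to 5: a cosine cutoff weight that is exactly zero beyond `cutoff`,
# summed over all messages received by each destination atom.
weights = jnp.where(distances < cutoff, 0.5 * (jnp.cos(jnp.pi * distances / cutoff) + 1.0), 0.0)
atomic_energies = jax.ops.segment_sum(weights, segment_ids=dst_idx, num_segments=positions.shape[0])

# Step 6: sum the atomic contributions within each batch segment (one value per molecule).
energy = jax.ops.segment_sum(atomic_energies, segment_ids=batch_segments, num_segments=batch_size)
print(energy.shape)  # (2,)
```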

[4]:
class MessagePassingModel(nn.Module):
  features: int = 32
  max_degree: int = 2
  num_iterations: int = 3
  num_basis_functions: int = 8
  cutoff: float = 5.0
  max_atomic_number: int = 118  # This is overkill for most applications.


  def energy(self, atomic_numbers, positions, dst_idx, src_idx, batch_segments, batch_size):
    # 1. Calculate displacement vectors.
    positions_dst = e3x.ops.gather_dst(positions, dst_idx=dst_idx)
    positions_src = e3x.ops.gather_src(positions, src_idx=src_idx)
    displacements = positions_src - positions_dst  # Shape (num_pairs, 3).

    # 2. Expand displacement vectors in basis functions.
    basis = e3x.nn.basis(  # Shape (num_pairs, 1, (max_degree+1)**2, num_basis_functions).
      displacements,
      num=self.num_basis_functions,
      max_degree=self.max_degree,
      radial_fn=e3x.nn.reciprocal_bernstein,
      cutoff_fn=functools.partial(e3x.nn.smooth_cutoff, cutoff=self.cutoff)
    )

    # 3. Embed atomic numbers in feature space, x has shape (num_atoms, 1, 1, features).
    x = e3x.nn.Embed(num_embeddings=self.max_atomic_number+1, features=self.features)(atomic_numbers)

    # 4. Perform iterations (message-passing + atom-wise refinement).
    for i in range(self.num_iterations):
      # Message-pass.
      if i == self.num_iterations-1:  # Final iteration.
        # Since we will only use scalar features after the final message-pass, we do not want to produce non-scalar
        # features for efficiency reasons.
        y = e3x.nn.MessagePass(max_degree=0, include_pseudotensors=False)(x, basis, dst_idx=dst_idx, src_idx=src_idx)
        # After the final message pass, we can safely throw away all non-scalar features.
        x = e3x.nn.change_max_degree_or_type(x, max_degree=0, include_pseudotensors=False)
      else:
        # In intermediate iterations, the message-pass should consider all possible coupling paths.
        y = e3x.nn.MessagePass()(x, basis, dst_idx=dst_idx, src_idx=src_idx)
      y = e3x.nn.add(x, y)

      # Atom-wise refinement MLP.
      y = e3x.nn.Dense(self.features)(y)
      y = e3x.nn.silu(y)
      y = e3x.nn.Dense(self.features, kernel_init=jax.nn.initializers.zeros)(y)

      # Residual connection.
      x = e3x.nn.add(x, y)

    # 5. Predict atomic energies with an ordinary dense layer.
    element_bias = self.param('element_bias', lambda rng, shape: jnp.zeros(shape), (self.max_atomic_number+1,))
    atomic_energies = nn.Dense(1, use_bias=False, kernel_init=jax.nn.initializers.zeros)(x)  # (..., Natoms, 1, 1, 1)
    atomic_energies = jnp.squeeze(atomic_energies, axis=(-1, -2, -3))  # Squeeze last 3 dimensions.
    atomic_energies += element_bias[atomic_numbers]

    # 6. Sum atomic energies to obtain the total energy.
    energy = jax.ops.segment_sum(atomic_energies, segment_ids=batch_segments, num_segments=batch_size)

    # To be able to efficiently compute forces, our model should return a single output (instead of one for each
    # molecule in the batch). Fortunately, since all atomic contributions only influence the energy in their own
    # batch segment, we can simply sum the energy of all molecules in the batch to obtain a single proxy output
    # to differentiate.
    return -jnp.sum(energy), energy  # Forces are the negative gradient, hence the minus sign.


  @nn.compact
  def __call__(self, atomic_numbers, positions, dst_idx, src_idx, batch_segments=None, batch_size=None):
    if batch_segments is None:
      batch_segments = jnp.zeros_like(atomic_numbers)
      batch_size = 1

    # Since we want to also predict forces, i.e. the gradient of the energy w.r.t. positions (argument 1), we use
    # jax.value_and_grad to create a function for predicting both energy and forces for us.
    energy_and_forces = jax.value_and_grad(self.energy, argnums=1, has_aux=True)
    (_, energy), forces = energy_and_forces(atomic_numbers, positions, dst_idx, src_idx, batch_segments, batch_size)

    return energy, forces
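The two ideas behind __call__ (forces as the negative gradient of the energy, and differentiating a single scalar that sums the energies of all molecules in the batch) can be checked on a toy problem. The sketch below replaces MessagePassingModel.energy with a hypothetical harmonic bond energy \(E = \tfrac{1}{2}k(r - r_0)^2\) for two diatomic "molecules" in one batch, and calls jax.value_and_grad with has_aux=True in the same way as the model. For a bond stretched to \(r = 1.5\) with \(k = 2\) and \(r_0 = 1\), the force magnitude is \(k(r - r_0) = 1\), pointing back toward the bonded partner.

```python
import jax
import jax.numpy as jnp


def toy_energy(positions, dst_idx, src_idx, batch_segments, batch_size, k=2.0, r0=1.0):
  """Hypothetical harmonic pair energy, a stand-in for MessagePassingModel.energy.

  Each ordered pair (i, j) contributes 0.25 * k * (r_ij - r0)**2 to atom i, so each
  bond (two ordered pairs) contributes 0.5 * k * (r_ij - r0)**2 to its molecule.
  """
  displacements = positions[src_idx] - positions[dst_idx]
  r = jnp.linalg.norm(displacements, axis=-1)
  pair_energies = 0.25 * k * (r - r0) ** 2
  atomic_energies = jax.ops.segment_sum(pair_energies, segment_ids=dst_idx, num_segments=positions.shape[0])
  energy = jax.ops.segment_sum(atomic_energies, segment_ids=batch_segments, num_segments=batch_size)
  # Same trick as the model: a scalar proxy (negated sum over molecules) plus per-molecule energies.
  return -jnp.sum(energy), energy


# Two diatomic "molecules" in one batch, both stretched to r = 1.5.
positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0],
                       [5.0, 0.0, 0.0], [5.0, 1.5, 0.0]])
dst_idx = jnp.array([0, 1, 2, 3])
src_idx = jnp.array([1, 0, 3, 2])
batch_segments = jnp.array([0, 0, 1, 1])

energy_and_forces = jax.value_and_grad(toy_energy, argnums=0, has_aux=True)
(_, energy), forces = energy_and_forces(positions, dst_idx, src_idx, batch_segments, 2)
```

Because the energy of molecule 1 does not depend on the positions of molecule 0 (and vice versa), the gradient of the summed energy yields the correct per-atom forces for every molecule in a single backward pass.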

With our model in place, we now write a function to prepare a batch of molecules in the format described above. For this simple example, the source and destination index lists include all \(N(N-1)\) pairwise combinations of \(N\) atoms (without self-interactions). E3x contains helper functions, such as sparse_pairwise_indices, to construct such \(O(N^2)\) index lists. This is fine for a small molecule like ethanol, but for larger molecules it pays off to do better than \(O(N^2)\) scaling. Recall that our message-passing model uses a cutoff when expanding displacement vectors in basis functions. While it doesn’t hurt to include interactions beyond the cutoff in the index lists (they simply contribute nothing to the message received by the destination atom), it is inefficient. In a real application, it would be better to construct the index lists with a spatial partitioning method, such that only interactions within the cutoff distance are included. With this trick, evaluating the message-passing model scales as \(O(NM)\), where \(M \ll N\) is the average number of atoms within the cutoff distance.
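One common spatial partitioning method is a cell list. The following is a minimal NumPy sketch of that idea, not an E3x function: cell_list_pairs is a hypothetical helper, and it assumes an isolated (non-periodic) system. Atoms are binned into cubes with edge length equal to the cutoff, so each atom only needs to be compared with atoms in its own cell and the 26 adjacent cells.

```python
import itertools
import numpy as np


def cell_list_pairs(positions, cutoff):
  """Hypothetical helper: ordered pairs (i, j), i != j, with |r_j - r_i| < cutoff.

  Atoms are binned into cubic cells with edge length `cutoff`. Two atoms closer than
  `cutoff` can only sit in the same or in adjacent cells, so each atom is compared only
  with the atoms in its own cell and the 26 neighboring cells (no periodic boundaries).
  """
  positions = np.asarray(positions, dtype=np.float64)
  cell_ids = np.floor((positions - positions.min(axis=0)) / cutoff).astype(int)

  # Map each occupied cell to the indices of the atoms it contains.
  cells = {}
  for atom, cell in enumerate(map(tuple, cell_ids)):
    cells.setdefault(cell, []).append(atom)

  neighbor_offsets = list(itertools.product((-1, 0, 1), repeat=3))
  dst_idx, src_idx = [], []
  for cell, atoms in cells.items():
    # Candidate partners: all atoms in this cell and in the adjacent cells.
    candidates = np.array([
        j
        for offset in neighbor_offsets
        for j in cells.get(tuple(c + o for c, o in zip(cell, offset)), [])
    ])
    for i in atoms:
      distances = np.linalg.norm(positions[candidates] - positions[i], axis=-1)
      mask = (distances < cutoff) & (candidates != i)
      dst_idx.extend([i] * int(mask.sum()))
      src_idx.extend(candidates[mask].tolist())
  return np.array(dst_idx, dtype=int), np.array(src_idx, dtype=int)


# Usage: 200 random "atoms" in a 20 Å box with the tutorial's cutoff of 5 Å.
rng = np.random.default_rng(0)
positions = rng.uniform(0.0, 20.0, size=(200, 3))
dst_idx, src_idx = cell_list_pairs(positions, cutoff=5.0)
print(len(dst_idx), "pairs instead of", 200 * 199)
```

For molecules with roughly uniform density, each cell holds a bounded number of atoms, so the work grows linearly with the number of atoms instead of quadratically. Production codes typically rely on optimized neighbor-list libraries instead of a Python loop like this one.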

[5]:
def prepare_batches(key, data, batch_size):
  # Determine the number of training steps per epoch.
  data_size = len(data['energy'])
  steps_per_epoch = data_size//batch_size

  # Draw random permutations for fetching batches from the train data.
  perms = jax.random.permutation(key, data_size)
  perms = perms[:steps_per_epoch * batch_size]  # Skip the last batch (if incomplete).
  perms = perms.reshape((steps_per_epoch, batch_size))

  # Prepare entries that are identical for each batch.
  num_atoms = len(data['atomic_numbers'])
  batch_segments = jnp.repeat(jnp.arange(batch_size), num_atoms)
  atomic_numbers = jnp.tile(data['atomic_numbers'], batch_size)
  offsets = jnp.arange(batch_size) * num_atoms
  dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(num_atoms)
  dst_idx = (dst_idx + offsets[:, None]).reshape(-1)
  src_idx = (src_idx + offsets[:, None]).reshape(-1)

  # Assemble and return batches.
  return [
    dict(
        energy=data['energy'][perm],
        forces=data['forces'][perm].reshape(-1, 3),
        atomic_numbers=atomic_numbers,
        positions=data['positions'][perm].reshape(-1, 3),
        dst_idx=dst_idx,
        src_idx=src_idx,
        batch_segments=batch_segments,
    )
    for perm in perms
  ]
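To see why the offset trick in prepare_batches keeps molecules separate, here is the same construction in NumPy for a hypothetical 3-atom molecule and a batch of two. The all-pairs lists are built with np.nonzero as a stand-in for e3x.ops.sparse_pairwise_indices, whose pair ordering may differ.

```python
import numpy as np

num_atoms, batch_size = 3, 2  # Hypothetical: two copies of a 3-atom molecule.

# All ordered pairs i != j of one molecule (stand-in for e3x.ops.sparse_pairwise_indices).
dst_one, src_one = np.nonzero(~np.eye(num_atoms, dtype=bool))

# Same construction as in prepare_batches.
batch_segments = np.repeat(np.arange(batch_size), num_atoms)  # [0 0 0 1 1 1]
offsets = np.arange(batch_size) * num_atoms                   # [0 3]
dst_idx = (dst_one + offsets[:, None]).reshape(-1)
src_idx = (src_one + offsets[:, None]).reshape(-1)

# Per-molecule positions of shape (batch_size, num_atoms, 3), flattened as in prepare_batches:
# row b * num_atoms + k of flat_positions is atom k of molecule b.
positions = np.arange(batch_size * num_atoms * 3, dtype=float).reshape(batch_size, num_atoms, 3)
flat_positions = positions.reshape(-1, 3)

print(dst_idx.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
print(src_idx.tolist())  # [1, 2, 0, 2, 0, 1, 4, 5, 3, 5, 3, 4]
```

Every pair connects two atoms with the same batch_segments entry, so no messages flow between molecules.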

Next, we define our loss function. As is common for regression tasks, we choose the \(L_2\) (squared error) loss; note that optax.l2_loss computes half the squared error, \(\tfrac{1}{2}(\text{prediction} - \text{target})^2\). Since we want to fit two different quantities at once (energy and forces), we simply add their respective losses together, with a weighting factor that allows us to tune the relative importance of the two loss terms. For convenience, we also define a function to compute the mean absolute error, which we use to keep track of the model performance during training.
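A NumPy re-implementation of the weighted loss on made-up toy numbers shows how forces_weight trades off the two terms. The sketch uses the same factor of 0.5 as optax.l2_loss and, like the tutorial code, averages the force errors over all atoms and all three Cartesian components.

```python
import numpy as np


def half_squared_error(prediction, target):
  # Elementwise 0.5 * (prediction - target)**2, the same convention as optax.l2_loss.
  return 0.5 * (prediction - target) ** 2


def weighted_loss(energy_prediction, energy_target, forces_prediction, forces_target, forces_weight):
  energy_loss = np.mean(half_squared_error(energy_prediction, energy_target))  # Mean over molecules.
  forces_loss = np.mean(half_squared_error(forces_prediction, forces_target))  # Mean over atoms and x, y, z.
  return energy_loss + forces_weight * forces_loss


# Toy batch: 2 molecules with 4 atoms in total.
energy_prediction, energy_target = np.array([1.0, -2.0]), np.array([0.0, -2.0])
forces_prediction, forces_target = np.zeros((4, 3)), np.ones((4, 3))
print(weighted_loss(energy_prediction, energy_target, forces_prediction, forces_target, 1.0))   # 0.75
print(weighted_loss(energy_prediction, energy_target, forces_prediction, forces_target, 10.0))  # 5.25
```

Here the energy term is 0.25 and the force term is 0.5, so raising forces_weight from 1 to 10 makes the force errors dominate the total loss.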

[6]:
def mean_squared_loss(energy_prediction, energy_target, forces_prediction, forces_target, forces_weight):
  energy_loss = jnp.mean(optax.l2_loss(energy_prediction, energy_target))
  forces_loss = jnp.mean(optax.l2_loss(forces_prediction, forces_target))
  return energy_loss + forces_weight * forces_loss

def mean_absolute_error(prediction, target):
  return jnp.mean(jnp.abs(prediction - target))

Now that we have all the ingredients, we write some boilerplate code for training the model.

[7]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update', 'batch_size'))
def train_step(model_apply, optimizer_update, batch, batch_size, forces_weight, opt_state, params):
  def loss_fn(params):
    energy, forces = model_apply(
      params,
      atomic_numbers=batch['atomic_numbers'],
      positions=batch['positions'],
      dst_idx=batch['dst_idx'],
      src_idx=batch['src_idx'],
      batch_segments=batch['batch_segments'],
      batch_size=batch_size
    )
    loss = mean_squared_loss(
      energy_prediction=energy,
      energy_target=batch['energy'],
      forces_prediction=forces,
      forces_target=batch['forces'],
      forces_weight=forces_weight
    )
    return loss, (energy, forces)
  (loss, (energy, forces)), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
  updates, opt_state = optimizer_update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)
  energy_mae = mean_absolute_error(energy, batch['energy'])
  forces_mae = mean_absolute_error(forces, batch['forces'])
  return params, opt_state, loss, energy_mae, forces_mae


@functools.partial(jax.jit, static_argnames=('model_apply', 'batch_size'))
def eval_step(model_apply, batch, batch_size, forces_weight, params):
  energy, forces = model_apply(
    params,
    atomic_numbers=batch['atomic_numbers'],
    positions=batch['positions'],
    dst_idx=batch['dst_idx'],
    src_idx=batch['src_idx'],
    batch_segments=batch['batch_segments'],
    batch_size=batch_size
  )
  loss = mean_squared_loss(
    energy_prediction=energy,
    energy_target=batch['energy'],
    forces_prediction=forces,
    forces_target=batch['forces'],
    forces_weight=forces_weight
  )
  energy_mae = mean_absolute_error(energy, batch['energy'])
  forces_mae = mean_absolute_error(forces, batch['forces'])
  return loss, energy_mae, forces_mae


def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, forces_weight, batch_size):
  # Initialize model parameters and optimizer state.
  key, init_key = jax.random.split(key)
  optimizer = optax.adam(learning_rate)
  dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(len(train_data['atomic_numbers']))
  params = model.init(init_key,
    atomic_numbers=train_data['atomic_numbers'],
    positions=train_data['positions'][0],
    dst_idx=dst_idx,
    src_idx=src_idx,
  )
  opt_state = optimizer.init(params)

  # Batches for the validation set need to be prepared only once.
  key, shuffle_key = jax.random.split(key)
  valid_batches = prepare_batches(shuffle_key, valid_data, batch_size)

  # Train for 'num_epochs' epochs.
  for epoch in range(1, num_epochs + 1):
    # Prepare batches.
    key, shuffle_key = jax.random.split(key)
    train_batches = prepare_batches(shuffle_key, train_data, batch_size)

    # Loop over train batches.
    train_loss = 0.0
    train_energy_mae = 0.0
    train_forces_mae = 0.0
    for i, batch in enumerate(train_batches):
      params, opt_state, loss, energy_mae, forces_mae = train_step(
        model_apply=model.apply,
        optimizer_update=optimizer.update,
        batch=batch,
        batch_size=batch_size,
        forces_weight=forces_weight,
        opt_state=opt_state,
        params=params
      )
      train_loss += (loss - train_loss)/(i+1)
      train_energy_mae += (energy_mae - train_energy_mae)/(i+1)
      train_forces_mae += (forces_mae - train_forces_mae)/(i+1)

    # Evaluate on validation set.
    valid_loss = 0.0
    valid_energy_mae = 0.0
    valid_forces_mae = 0.0
    for i, batch in enumerate(valid_batches):
      loss, energy_mae, forces_mae = eval_step(
        model_apply=model.apply,
        batch=batch,
        batch_size=batch_size,
        forces_weight=forces_weight,
        params=params
      )
      valid_loss += (loss - valid_loss)/(i+1)
      valid_energy_mae += (energy_mae - valid_energy_mae)/(i+1)
      valid_forces_mae += (forces_mae - valid_forces_mae)/(i+1)

    # Print progress.
    print(f"epoch: {epoch: 3d}                    train:   valid:")
    print(f"    loss [a.u.]             {train_loss : 8.3f} {valid_loss : 8.3f}")
    print(f"    energy mae [kcal/mol]   {train_energy_mae: 8.3f} {valid_energy_mae: 8.3f}")
    print(f"    forces mae [kcal/mol/Å] {train_forces_mae: 8.3f} {valid_forces_mae: 8.3f}")


  # Return final model parameters.
  return params
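The training and validation loops accumulate averages with the incremental update \(m_i = m_{i-1} + (x_i - m_{i-1})/i\) (written with a 0-based loop index as (i+1) in the code), which yields the arithmetic mean without storing all per-batch values. Because prepare_batches drops the incomplete last batch, all batches have the same size, so the mean of the per-batch values equals the mean over all samples used in that epoch. A quick check with made-up numbers:

```python
import numpy as np

batch_values = np.array([3.0, 1.0, 4.0, 1.0, 5.0])  # Made-up per-batch losses.

running_mean = 0.0
for i, value in enumerate(batch_values):
  # Same update as in train_model: move the estimate a fraction 1/(i+1) toward the new value.
  running_mean += (value - running_mean) / (i + 1)

print(running_mean, np.mean(batch_values))  # Both are 2.8 up to floating-point rounding.
```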

The last step before training the model is to define the hyperparameters.

[8]:
# Model hyperparameters.
features = 32
max_degree = 2
num_iterations = 3
num_basis_functions = 16
cutoff = 5.0

# Training hyperparameters.
num_train = 900
num_valid = 100
num_epochs = 100
learning_rate = 0.01
forces_weight = 1.0
batch_size = 10
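For orientation, a few quantities derived from these settings, computed in plain Python. Ethanol (C\(_2\)H\(_5\)OH) has 9 atoms, and the basis shape follows the shape comment in MessagePassingModel.energy.

```python
# Derived quantities for the hyperparameters above (plain Python arithmetic).
num_atoms = 9  # Ethanol, C2H5OH: 2 C + 6 H + 1 O.
num_train, num_valid, num_epochs, batch_size = 900, 100, 100, 10
max_degree, num_basis_functions = 2, 16

steps_per_epoch = num_train // batch_size             # 90 parameter updates per epoch.
total_steps = steps_per_epoch * num_epochs            # 9000 updates in total.
num_valid_batches = num_valid // batch_size           # 10 validation batches.
num_pairs = batch_size * num_atoms * (num_atoms - 1)  # 720 (dst, src) pairs per batch.
basis_shape = (num_pairs, 1, (max_degree + 1) ** 2, num_basis_functions)  # (720, 1, 9, 16).
```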

Finally, we can train our model.

[9]:
# Create PRNGKeys.
data_key, train_key = jax.random.split(jax.random.PRNGKey(0), 2)

# Draw training and validation sets.
train_data, valid_data, _ = prepare_datasets(data_key, num_train=num_train, num_valid=num_valid)

# Create and train model.
message_passing_model = MessagePassingModel(
  features=features,
  max_degree=max_degree,
  num_iterations=num_iterations,
  num_basis_functions=num_basis_functions,
  cutoff=cutoff,
)
params = train_model(
  key=train_key,
  model=message_passing_model,
  train_data=train_data,
  valid_data=valid_data,
  num_epochs=num_epochs,
  learning_rate=learning_rate,
  forces_weight=forces_weight,
  batch_size=batch_size,
)
epoch:   1                    train:   valid:
    loss [a.u.]              354.857  356.750
    energy mae [kcal/mol]      3.333    3.731
    forces mae [kcal/mol/Å]   19.506   19.367
epoch:   2                    train:   valid:
    loss [a.u.]              299.191  193.864
    energy mae [kcal/mol]      3.777    2.866
    forces mae [kcal/mol/Å]   17.600   14.072
epoch:   3                    train:   valid:
    loss [a.u.]              145.379  101.722
    energy mae [kcal/mol]      4.058    5.459
    forces mae [kcal/mol/Å]   11.632    9.264
epoch:   4                    train:   valid:
    loss [a.u.]               92.005   58.962
    energy mae [kcal/mol]      3.685    2.793
    forces mae [kcal/mol/Å]    9.238    7.459
epoch:   5                    train:   valid:
    loss [a.u.]               65.880   47.165
    energy mae [kcal/mol]      3.112    1.574
    forces mae [kcal/mol/Å]    7.773    6.711
epoch:   6                    train:   valid:
    loss [a.u.]               47.643   38.291
    energy mae [kcal/mol]      3.033    1.962
    forces mae [kcal/mol/Å]    6.433    6.058
epoch:   7                    train:   valid:
    loss [a.u.]               34.482   24.528
    energy mae [kcal/mol]      2.234    1.509
    forces mae [kcal/mol/Å]    5.536    4.864
epoch:   8                    train:   valid:
    loss [a.u.]               28.742   25.687
    energy mae [kcal/mol]      2.656    3.414
    forces mae [kcal/mol/Å]    4.924    4.488
epoch:   9                    train:   valid:
    loss [a.u.]               25.716   18.172
    energy mae [kcal/mol]      2.366    1.424
    forces mae [kcal/mol/Å]    4.737    4.169
epoch:  10                    train:   valid:
    loss [a.u.]               21.506   19.457
    energy mae [kcal/mol]      1.999    1.346
    forces mae [kcal/mol/Å]    4.391    4.381
epoch:  11                    train:   valid:
    loss [a.u.]               18.409   15.401
    energy mae [kcal/mol]      1.890    1.706
    forces mae [kcal/mol/Å]    4.072    3.850
epoch:  12                    train:   valid:
    loss [a.u.]               16.933   13.269
    energy mae [kcal/mol]      1.873    1.201
    forces mae [kcal/mol/Å]    3.894    3.652
epoch:  13                    train:   valid:
    loss [a.u.]               14.334   11.579
    energy mae [kcal/mol]      1.484    1.516
    forces mae [kcal/mol/Å]    3.685    3.294
epoch:  14                    train:   valid:
    loss [a.u.]               13.378   22.506
    energy mae [kcal/mol]      1.676    4.767
    forces mae [kcal/mol/Å]    3.475    3.423
epoch:  15                    train:   valid:
    loss [a.u.]               13.606   14.388
    energy mae [kcal/mol]      1.731    1.928
    forces mae [kcal/mol/Å]    3.520    3.769
epoch:  16                    train:   valid:
    loss [a.u.]               12.465    9.474
    energy mae [kcal/mol]      1.499    1.412
    forces mae [kcal/mol/Å]    3.372    2.958
epoch:  17                    train:   valid:
    loss [a.u.]               10.487    7.631
    energy mae [kcal/mol]      1.457    1.041
    forces mae [kcal/mol/Å]    3.085    2.664
epoch:  18                    train:   valid:
    loss [a.u.]               11.442    8.065
    energy mae [kcal/mol]      1.522    0.983
    forces mae [kcal/mol/Å]    3.234    2.783
epoch:  19                    train:   valid:
    loss [a.u.]                9.678   11.261
    energy mae [kcal/mol]      1.424    2.474
    forces mae [kcal/mol/Å]    2.981    2.852
epoch:  20                    train:   valid:
    loss [a.u.]                8.938    7.979
    energy mae [kcal/mol]      1.549    0.960
    forces mae [kcal/mol/Å]    2.775    2.928
epoch:  21                    train:   valid:
    loss [a.u.]                8.589    6.849
    energy mae [kcal/mol]      1.344    1.351
    forces mae [kcal/mol/Å]    2.770    2.497
epoch:  22                    train:   valid:
    loss [a.u.]                6.768   14.650
    energy mae [kcal/mol]      1.230    4.167
    forces mae [kcal/mol/Å]    2.466    2.446
epoch:  23                    train:   valid:
    loss [a.u.]                7.701    6.763
    energy mae [kcal/mol]      1.289    1.685
    forces mae [kcal/mol/Å]    2.625    2.279
epoch:  24                    train:   valid:
    loss [a.u.]                6.788    6.545
    energy mae [kcal/mol]      1.036    0.752
    forces mae [kcal/mol/Å]    2.522    2.520
epoch:  25                    train:   valid:
    loss [a.u.]                6.357    6.105
    energy mae [kcal/mol]      1.389    1.183
    forces mae [kcal/mol/Å]    2.313    2.469
epoch:  26                    train:   valid:
    loss [a.u.]                6.928    5.101
    energy mae [kcal/mol]      1.424    1.209
    forces mae [kcal/mol/Å]    2.403    2.120
epoch:  27                    train:   valid:
    loss [a.u.]                5.723    4.971
    energy mae [kcal/mol]      1.118    0.576
    forces mae [kcal/mol/Å]    2.263    2.182
epoch:  28                    train:   valid:
    loss [a.u.]                5.307    6.905
    energy mae [kcal/mol]      0.959    0.617
    forces mae [kcal/mol/Å]    2.226    2.763
epoch:  29                    train:   valid:
    loss [a.u.]                4.863    6.667
    energy mae [kcal/mol]      1.120    2.057
    forces mae [kcal/mol/Å]    2.064    2.120
epoch:  30                    train:   valid:
    loss [a.u.]                5.157    5.295
    energy mae [kcal/mol]      1.157    1.480
    forces mae [kcal/mol/Å]    2.115    2.083
epoch:  31                    train:   valid:
    loss [a.u.]                5.064    5.305
    energy mae [kcal/mol]      1.061    1.803
    forces mae [kcal/mol/Å]    2.135    1.934
epoch:  32                    train:   valid:
    loss [a.u.]                5.590    3.384
    energy mae [kcal/mol]      1.073    0.494
    forces mae [kcal/mol/Å]    2.242    1.830
epoch:  33                    train:   valid:
    loss [a.u.]                3.675    4.018
    energy mae [kcal/mol]      0.891    0.784
    forces mae [kcal/mol/Å]    1.810    1.983
epoch:  34                    train:   valid:
    loss [a.u.]                4.757    3.721
    energy mae [kcal/mol]      1.216    0.734
    forces mae [kcal/mol/Å]    1.983    1.847
epoch:  35                    train:   valid:
    loss [a.u.]                4.804    3.326
    energy mae [kcal/mol]      1.054    0.506
    forces mae [kcal/mol/Å]    2.058    1.835
epoch:  36                    train:   valid:
    loss [a.u.]                3.732    3.326
    energy mae [kcal/mol]      0.759    0.479
    forces mae [kcal/mol/Å]    1.872    1.888
epoch:  37                    train:   valid:
    loss [a.u.]                4.103    3.825
    energy mae [kcal/mol]      0.953    1.059
    forces mae [kcal/mol/Å]    1.911    1.927
epoch:  38                    train:   valid:
    loss [a.u.]                4.199    5.174
    energy mae [kcal/mol]      1.206    0.752
    forces mae [kcal/mol/Å]    1.844    2.273
epoch:  39                    train:   valid:
    loss [a.u.]                3.006    2.711
    energy mae [kcal/mol]      0.632    0.444
    forces mae [kcal/mol/Å]    1.703    1.620
epoch:  40                    train:   valid:
    loss [a.u.]                3.244    4.489
    energy mae [kcal/mol]      0.876    1.486
    forces mae [kcal/mol/Å]    1.711    1.855
epoch:  41                    train:   valid:
    loss [a.u.]                3.123    7.015
    energy mae [kcal/mol]      0.890    2.608
    forces mae [kcal/mol/Å]    1.650    1.996
epoch:  42                    train:   valid:
    loss [a.u.]                4.496    4.365
    energy mae [kcal/mol]      1.126    1.243
    forces mae [kcal/mol/Å]    1.974    1.948
epoch:  43                    train:   valid:
    loss [a.u.]                2.913    2.353
    energy mae [kcal/mol]      0.679    0.432
    forces mae [kcal/mol/Å]    1.657    1.536
epoch:  44                    train:   valid:
    loss [a.u.]                2.285    2.577
    energy mae [kcal/mol]      0.591    0.423
    forces mae [kcal/mol/Å]    1.478    1.625
epoch:  45                    train:   valid:
    loss [a.u.]                2.743    2.041
    energy mae [kcal/mol]      0.779    0.511
    forces mae [kcal/mol/Å]    1.566    1.398
epoch:  46                    train:   valid:
    loss [a.u.]                2.836    2.580
    energy mae [kcal/mol]      0.570    0.662
    forces mae [kcal/mol/Å]    1.663    1.510
epoch:  47                    train:   valid:
    loss [a.u.]                2.846    3.848
    energy mae [kcal/mol]      0.787    1.620
    forces mae [kcal/mol/Å]    1.603    1.650
epoch:  48                    train:   valid:
    loss [a.u.]                2.563    1.875
    energy mae [kcal/mol]      0.688    0.410
    forces mae [kcal/mol/Å]    1.540    1.368
epoch:  49                    train:   valid:
    loss [a.u.]                2.900    2.622
    energy mae [kcal/mol]      0.823    0.975
    forces mae [kcal/mol/Å]    1.591    1.541
epoch:  50                    train:   valid:
    loss [a.u.]                3.487    2.985
    energy mae [kcal/mol]      0.900    0.484
    forces mae [kcal/mol/Å]    1.725    1.705
epoch:  51                    train:   valid:
    loss [a.u.]                3.613    4.224
    energy mae [kcal/mol]      1.081    0.633
    forces mae [kcal/mol/Å]    1.714    2.192
epoch:  52                    train:   valid:
    loss [a.u.]                2.610    2.156
    energy mae [kcal/mol]      0.686    0.734
    forces mae [kcal/mol/Å]    1.544    1.361
epoch:  53                    train:   valid:
    loss [a.u.]                2.147    1.746
    energy mae [kcal/mol]      0.652    0.349
    forces mae [kcal/mol/Å]    1.402    1.291
epoch:  54                    train:   valid:
    loss [a.u.]                3.319    4.700
    energy mae [kcal/mol]      0.889    0.662
    forces mae [kcal/mol/Å]    1.708    2.355
epoch:  55                    train:   valid:
    loss [a.u.]                3.721    2.377
    energy mae [kcal/mol]      0.980    1.106
    forces mae [kcal/mol/Å]    1.801    1.362
epoch:  56                    train:   valid:
    loss [a.u.]                2.098    2.240
    energy mae [kcal/mol]      0.688    0.606
    forces mae [kcal/mol/Å]    1.377    1.456
epoch:  57                    train:   valid:
    loss [a.u.]                1.907    1.330
    energy mae [kcal/mol]      0.534    0.598
    forces mae [kcal/mol/Å]    1.353    1.077
epoch:  58                    train:   valid:
    loss [a.u.]                1.924    1.640
    energy mae [kcal/mol]      0.637    0.311
    forces mae [kcal/mol/Å]    1.300    1.276
epoch:  59                    train:   valid:
    loss [a.u.]                2.122    2.328
    energy mae [kcal/mol]      0.797    0.984
    forces mae [kcal/mol/Å]    1.347    1.337
epoch:  60                    train:   valid:
    loss [a.u.]                1.857    1.852
    energy mae [kcal/mol]      0.541    0.712
    forces mae [kcal/mol/Å]    1.325    1.254
epoch:  61                    train:   valid:
    loss [a.u.]                2.010    2.182
    energy mae [kcal/mol]      0.580    0.303
    forces mae [kcal/mol/Å]    1.365    1.520
epoch:  62                    train:   valid:
    loss [a.u.]                2.462    2.200
    energy mae [kcal/mol]      0.741    0.444
    forces mae [kcal/mol/Å]    1.465    1.486
epoch:  63                    train:   valid:
    loss [a.u.]                1.674    1.520
    energy mae [kcal/mol]      0.493    0.508
    forces mae [kcal/mol/Å]    1.268    1.236
epoch:  64                    train:   valid:
    loss [a.u.]                4.390    4.820
    energy mae [kcal/mol]      0.742    1.436
    forces mae [kcal/mol/Å]    1.979    1.987
epoch:  65                    train:   valid:
    loss [a.u.]                3.699    2.103
    energy mae [kcal/mol]      0.808    0.496
    forces mae [kcal/mol/Å]    1.802    1.486
epoch:  66                    train:   valid:
    loss [a.u.]                1.601    1.467
    energy mae [kcal/mol]      0.557    0.425
    forces mae [kcal/mol/Å]    1.220    1.191
epoch:  67                    train:   valid:
    loss [a.u.]                1.972    1.711
    energy mae [kcal/mol]      0.601    0.523
    forces mae [kcal/mol/Å]    1.355    1.237
epoch:  68                    train:   valid:
    loss [a.u.]                1.483    1.373
    energy mae [kcal/mol]      0.514    0.827
    forces mae [kcal/mol/Å]    1.187    1.041
epoch:  69                    train:   valid:
    loss [a.u.]                1.825    2.062
    energy mae [kcal/mol]      0.640    0.632
    forces mae [kcal/mol/Å]    1.271    1.388
epoch:  70                    train:   valid:
    loss [a.u.]                1.973    1.732
    energy mae [kcal/mol]      0.617    0.392
    forces mae [kcal/mol/Å]    1.332    1.283
epoch:  71                    train:   valid:
    loss [a.u.]                1.816    3.849
    energy mae [kcal/mol]      0.530    1.314
    forces mae [kcal/mol/Å]    1.309    1.911
epoch:  72                    train:   valid:
    loss [a.u.]                1.774    2.064
    energy mae [kcal/mol]      0.585    0.539
    forces mae [kcal/mol/Å]    1.274    1.477
epoch:  73                    train:   valid:
    loss [a.u.]                2.003    1.536
    energy mae [kcal/mol]      0.611    0.576
    forces mae [kcal/mol/Å]    1.362    1.175
epoch:  74                    train:   valid:
    loss [a.u.]                1.913    2.213
    energy mae [kcal/mol]      0.644    0.961
    forces mae [kcal/mol/Å]    1.312    1.252
epoch:  75                    train:   valid:
    loss [a.u.]                1.862    2.396
    energy mae [kcal/mol]      0.643    0.364
    forces mae [kcal/mol/Å]    1.277    1.615
epoch:  76                    train:   valid:
    loss [a.u.]                1.563    1.570
    energy mae [kcal/mol]      0.413    0.556
    forces mae [kcal/mol/Å]    1.239    1.168
epoch:  77                    train:   valid:
    loss [a.u.]                1.869    2.005
    energy mae [kcal/mol]      0.769    0.539
    forces mae [kcal/mol/Å]    1.247    1.451
epoch:  78                    train:   valid:
    loss [a.u.]                1.565    1.496
    energy mae [kcal/mol]      0.694    0.311
    forces mae [kcal/mol/Å]    1.154    1.227
epoch:  79                    train:   valid:
    loss [a.u.]                1.476    1.275
    energy mae [kcal/mol]      0.551    0.611
    forces mae [kcal/mol/Å]    1.162    1.022
epoch:  80                    train:   valid:
    loss [a.u.]                3.263    4.741
    energy mae [kcal/mol]      0.807    0.839
    forces mae [kcal/mol/Å]    1.653    2.087
epoch:  81                    train:   valid:
    loss [a.u.]                2.422    1.622
    energy mae [kcal/mol]      0.643    0.434
    forces mae [kcal/mol/Å]    1.491    1.230
epoch:  82                    train:   valid:
    loss [a.u.]                1.173    1.157
    energy mae [kcal/mol]      0.443    0.438
    forces mae [kcal/mol/Å]    1.043    1.037
epoch:  83                    train:   valid:
    loss [a.u.]                2.153    1.241
    energy mae [kcal/mol]      0.745    0.636
    forces mae [kcal/mol/Å]    1.363    1.026
epoch:  84                    train:   valid:
    loss [a.u.]                2.218    2.663
    energy mae [kcal/mol]      0.616    0.495
    forces mae [kcal/mol/Å]    1.437    1.517
epoch:  85                    train:   valid:
    loss [a.u.]                1.928    1.388
    energy mae [kcal/mol]      0.558    0.297
    forces mae [kcal/mol/Å]    1.347    1.223
epoch:  86                    train:   valid:
    loss [a.u.]                1.785    1.662
    energy mae [kcal/mol]      0.759    0.458
    forces mae [kcal/mol/Å]    1.183    1.247
epoch:  87                    train:   valid:
    loss [a.u.]                1.565    1.159
    energy mae [kcal/mol]      0.587    0.661
    forces mae [kcal/mol/Å]    1.195    0.979
epoch:  88                    train:   valid:
    loss [a.u.]                1.192    1.318
    energy mae [kcal/mol]      0.526    0.453
    forces mae [kcal/mol/Å]    1.040    1.105
epoch:  89                    train:   valid:
    loss [a.u.]                1.398    2.636
    energy mae [kcal/mol]      0.494    0.424
    forces mae [kcal/mol/Å]    1.141    1.771
epoch:  90                    train:   valid:
    loss [a.u.]                1.490    1.044
    energy mae [kcal/mol]      0.487    0.582
    forces mae [kcal/mol/Å]    1.188    0.941
epoch:  91                    train:   valid:
    loss [a.u.]                1.229    2.328
    energy mae [kcal/mol]      0.515    0.928
    forces mae [kcal/mol/Å]    1.047    1.350
epoch:  92                    train:   valid:
    loss [a.u.]                1.562    0.969
    energy mae [kcal/mol]      0.464    0.452
    forces mae [kcal/mol/Å]    1.218    0.941
epoch:  93                    train:   valid:
    loss [a.u.]                1.891    1.960
    energy mae [kcal/mol]      0.673    0.906
    forces mae [kcal/mol/Å]    1.270    1.267
epoch:  94                    train:   valid:
    loss [a.u.]                1.389    1.734
    energy mae [kcal/mol]      0.485    0.517
    forces mae [kcal/mol/Å]    1.126    1.242
epoch:  95                    train:   valid:
    loss [a.u.]                2.041    4.921
    energy mae [kcal/mol]      0.612    0.908
    forces mae [kcal/mol/Å]    1.369    2.325
epoch:  96                    train:   valid:
    loss [a.u.]                2.202    1.152
    energy mae [kcal/mol]      0.618    0.320
    forces mae [kcal/mol/Å]    1.433    1.067
epoch:  97                    train:   valid:
    loss [a.u.]                2.549    2.517
    energy mae [kcal/mol]      0.768    0.582
    forces mae [kcal/mol/Å]    1.481    1.556
epoch:  98                    train:   valid:
    loss [a.u.]                2.426    2.014
    energy mae [kcal/mol]      0.665    0.424
    forces mae [kcal/mol/Å]    1.492    1.364
epoch:  99                    train:   valid:
    loss [a.u.]                1.519    1.761
    energy mae [kcal/mol]      0.579    0.347
    forces mae [kcal/mol/Å]    1.167    1.301
epoch:  100                    train:   valid:
    loss [a.u.]                1.584    1.085
    energy mae [kcal/mol]      0.547    0.286
    forces mae [kcal/mol/Å]    1.211    1.066

After training for a hundred epochs, even this very simple model reaches energy prediction errors well below 1 kcal/mol (“chemical accuracy”) on the validation set. The model has not converged yet, but it is good enough for now.
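The "energy mae" and "forces mae" rows in the log are mean absolute errors. A minimal NumPy sketch of these metrics, assuming energy errors are averaged over molecules and force errors over every Cartesian component of every atom (the tutorial's own loss and metric code may average differently); the helper name mean_absolute_error and the toy numbers are illustrative only:

```python
import numpy as np

CHEMICAL_ACCURACY = 1.0  # kcal/mol

def mean_absolute_error(predicted, reference):
  """Mean of |predicted - reference| over all entries of the arrays."""
  predicted = np.asarray(predicted, dtype=np.float64)
  reference = np.asarray(reference, dtype=np.float64)
  return float(np.mean(np.abs(predicted - reference)))

# Toy values (not MD17 data): 2 molecules with 3 atoms each.
energies_pred = np.array([0.4, -1.2])   # shape (num_molecules,), kcal/mol
energies_ref = np.array([0.1, -1.0])
forces_pred = np.zeros((2, 3, 3))       # shape (num_molecules, num_atoms, 3), kcal/mol/Å
forces_ref = np.full((2, 3, 3), 0.5)

energy_mae = mean_absolute_error(energies_pred, energies_ref)
forces_mae = mean_absolute_error(forces_pred, forces_ref)
print(f"energy mae {energy_mae:.3f} kcal/mol, chemically accurate: {energy_mae < CHEMICAL_ACCURACY}")
print(f"forces mae {forces_mae:.3f} kcal/mol/Å")
```

Because every force component counts as one sample, the force MAE weights larger molecules more heavily than the per-molecule energy MAE does.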

Once the model is trained, we can use it as a force field, e.g. to do structure optimization, or to run molecular dynamics simulations. To demonstrate this, we use the atomic simulation environment (ASE) and py3Dmol (for visualization), so a few additional imports are necessary.

[10]:
import io
import ase
import ase.calculators.calculator as ase_calc
import ase.io as ase_io
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation
from ase.md.verlet import VelocityVerlet
import ase.optimize as ase_opt
import matplotlib.pyplot as plt
import py3Dmol

Before we can use our model as a Calculator in ASE, we need to write a small interface. The following implementation is very bare-bones, but it does the trick for a few simple demonstrations.

[11]:
@jax.jit
def evaluate_energies_and_forces(atomic_numbers, positions, dst_idx, src_idx):
  # The trained parameters `params` are closed over (baked into the compiled function).
  return message_passing_model.apply(params,
    atomic_numbers=atomic_numbers,
    positions=positions,
    dst_idx=dst_idx,
    src_idx=src_idx,
  )


class MessagePassingCalculator(ase_calc.Calculator):
  implemented_properties = ["energy", "forces"]

  def calculate(self, atoms, properties, system_changes=ase_calc.all_changes):
    ase_calc.Calculator.calculate(self, atoms, properties, system_changes)
    # Every atom exchanges messages with every other atom (fully connected graph).
    dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(len(atoms))
    energy, forces = evaluate_energies_and_forces(
      atomic_numbers=atoms.get_atomic_numbers(),
      positions=atoms.get_positions(),
      dst_idx=dst_idx,
      src_idx=src_idx
    )
    # The model predicts kcal/mol and kcal/mol/Å, but ASE expects eV and eV/Å.
    self.results['energy'] = energy * ase.units.kcal/ase.units.mol
    self.results['forces'] = forces * ase.units.kcal/ase.units.mol
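Each call to calculate rebuilds the pair indices with e3x.ops.sparse_pairwise_indices(len(atoms)). The NumPy sketch below builds such a fully connected index list without self-pairs, which is what we assume the E3x helper returns; the ordering of pairs inside E3x may differ, and the function name all_pairs_without_self is hypothetical:

```python
import numpy as np

def all_pairs_without_self(num_atoms):
  """Return (dst_idx, src_idx) for all ordered pairs (i, j) with i != j."""
  dst, src = np.meshgrid(np.arange(num_atoms), np.arange(num_atoms), indexing="ij")
  mask = dst != src  # Drop self-interactions (i, i).
  return dst[mask], src[mask]

dst_idx, src_idx = all_pairs_without_self(3)
print(dst_idx)  # [0 0 1 1 2 2]
print(src_idx)  # [1 2 0 2 0 1]
```

A molecule with N atoms yields N(N-1) pairs (72 for the 9 atoms of ethanol). This quadratic scaling is harmless for small molecules; larger systems would typically use a cutoff-based neighbor list instead.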

First, let’s find an optimized (local minimum-energy) structure for ethanol. We initialize the structure with an arbitrary conformation from the training data (here, the first one) and then use ASE’s BFGS optimizer to relax it.
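In the next cell, run(fmax=0.05) sets the convergence criterion: ASE's optimizers stop once the largest force magnitude on any single atom drops below fmax (in eV/Å), and this is also the quantity printed in the fmax column of the optimizer log. A minimal NumPy sketch of that test (the helper names max_atomic_force and is_converged are hypothetical, not part of ASE):

```python
import numpy as np

def max_atomic_force(forces):
  """Largest per-atom force norm; forces has shape (num_atoms, 3)."""
  return float(np.sqrt((forces ** 2).sum(axis=1).max()))

def is_converged(forces, fmax=0.05):
  """The optimization stops once no atom feels a force of fmax or larger."""
  return max_atomic_force(forces) < fmax

# Toy forces in eV/Å for three atoms.
forces = np.array([[0.03, 0.0, 0.0],
                   [0.0, 0.02, 0.02],
                   [0.0, 0.0, -0.04]])
print(f"{max_atomic_force(forces):.3f}")  # 0.040
print(is_converged(forces))               # True
```

Using the per-atom maximum rather than an average means that a single strained atom is enough to keep the optimizer running.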

[12]:
# Initialize atoms object and attach calculator.
atoms = ase.Atoms(train_data['atomic_numbers'], train_data['positions'][0])
atoms.calc = MessagePassingCalculator()  # Preferred over the deprecated atoms.set_calculator().

# Run structure optimization with BFGS (stop once the maximum atomic force is below 0.05 eV/Å).
_ = ase_opt.BFGS(atoms).run(fmax=0.05)
      Step     Time          Energy         fmax
BFGS:    0 08:51:53        0.048474        3.0194
BFGS:    1 08:51:53       -0.251758        1.6836
BFGS:    2 08:51:53       -0.396961        1.4604
BFGS:    3 08:51:53       -0.464969        0.9650
BFGS:    4 08:51:53       -0.499852        0.4686
BFGS:    5 08:51:53       -0.516434        0.3922
BFGS:    6 08:51:53       -0.542414        0.4775
BFGS:    7 08:51:53       -0.552331        0.2690
BFGS:    8 08:51:53       -0.558380        0.2286
BFGS:    9 08:51:53       -0.563047        0.2404
BFGS:   10 08:51:53       -0.567419        0.2482
BFGS:   11 08:51:53       -0.570238        0.1684
BFGS:   12 08:51:53       -0.571843        0.1195
BFGS:   13 08:51:53       -0.573060        0.1111
BFGS:   14 08:51:53       -0.574244        0.1020
BFGS:   15 08:51:53       -0.575149        0.0860
BFGS:   16 08:51:53       -0.575878        0.1010
BFGS:   17 08:51:53       -0.576817        0.1314
BFGS:   18 08:51:53       -0.578436        0.1580
BFGS:   19 08:51:53       -0.580875        0.1718
BFGS:   20 08:51:53       -0.583505        0.1523
BFGS:   21 08:51:53       -0.585652        0.1337
BFGS:   22 08:51:53       -0.587525        0.1372
BFGS:   23 08:51:53       -0.589606        0.1612
BFGS:   24 08:51:53       -0.591745        0.1458
BFGS:   25 08:51:53       -0.593183        0.1071
BFGS:   26 08:51:53       -0.593988        0.0820
BFGS:   27 08:51:53       -0.594515        0.0670
BFGS:   28 08:51:53       -0.594925        0.0593
BFGS:   29 08:51:53       -0.595153        0.0346

A promising start! Note that ASE uses electron volts (eV) as energy units, so the optimized structure is lower in energy than the starting structure by roughly 0.64 eV, or about 15 kcal/mol. Let’s check what the optimized structure looks like.
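The conversion between the two energy units is easy to check by hand. The sketch below uses the exact SI values of the elementary charge and the Avogadro constant (ASE's ase.units may be based on a slightly older CODATA set, which only changes the last digits) and applies them to the first and last energies of the BFGS log above:

```python
# Exact SI constants and the thermochemical kilocalorie.
ELEMENTARY_CHARGE = 1.602176634e-19  # C, so 1 eV = 1.602176634e-19 J
AVOGADRO = 6.02214076e23             # 1/mol
JOULES_PER_KCAL = 4184.0

# 1 eV per molecule expressed in kcal/mol.
KCAL_PER_MOL_PER_EV = ELEMENTARY_CHARGE * AVOGADRO / JOULES_PER_KCAL

e_initial = 0.048474   # eV, BFGS step 0
e_final = -0.595153    # eV, BFGS step 29
delta_ev = e_final - e_initial
print(f"1 eV = {KCAL_PER_MOL_PER_EV:.4f} kcal/mol")
print(f"energy change: {delta_ev:.3f} eV = {delta_ev * KCAL_PER_MOL_PER_EV:.1f} kcal/mol")
```

The inverse factor, roughly 0.0434 eV per kcal/mol, is the ase.units.kcal/ase.units.mol factor used in the calculator to convert the model output to eV.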

[13]:
# Write structure to xyz file.
xyz = io.StringIO()
ase_io.write(xyz, atoms, format='xyz')

# Visualize the structure with py3Dmol.
view = py3Dmol.view()
view.addModel(xyz.getvalue(), 'xyz')
view.setStyle({'stick': {'radius': 0.15}, 'sphere': {'scale': 0.25}})
view.show()
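ase_io.write(..., format='xyz') serializes the structure into the plain-text XYZ format that py3Dmol reads: the number of atoms, a comment line, and then one line per atom with its element symbol and Cartesian coordinates in Å. A hand-written sketch of that layout (the function to_xyz is hypothetical; ASE's writer may put extra information into the comment line and format numbers differently):

```python
def to_xyz(symbols, positions, comment=""):
  """Format atoms as an XYZ string: count, comment, then 'symbol x y z' rows."""
  lines = [str(len(symbols)), comment]
  for symbol, (x, y, z) in zip(symbols, positions):
    lines.append(f"{symbol:<2} {x:12.6f} {y:12.6f} {z:12.6f}")
  return "\n".join(lines) + "\n"

# Illustrative water-like geometry (not taken from MD17).
print(to_xyz(["O", "H", "H"], [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]))
```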

Now let’s run a molecular dynamics simulation. We draw initial momenta from a Maxwell-Boltzmann distribution at 300 K, remove center-of-mass translation and overall rotation, and then use the velocity Verlet algorithm to integrate the equations of motion for 2000 steps with a timestep of 0.5 fs (1 ps in total). For visualization, we also save every frame and keep track of the potential, kinetic, and total energies.

[14]:
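# Parameters.
temperature = 300   # Initial temperature in K.
timestep_fs = 0.5   # Integration timestep in fs.
num_steps = 2000    # Number of MD steps (2000 * 0.5 fs = 1 ps).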

# Draw initial momenta.
MaxwellBoltzmannDistribution(atoms, temperature_K=temperature)
Stationary(atoms)  # Remove center of mass translation.
ZeroRotation(atoms)  # Remove rotations.

# Initialize Velocity Verlet integrator.
integrator = VelocityVerlet(atoms, timestep=timestep_fs*ase.units.fs)

# Run molecular dynamics.
frames = np.zeros((num_steps, len(atoms), 3))
potential_energy = np.zeros((num_steps,))
kinetic_energy = np.zeros((num_steps,))
total_energy = np.zeros((num_steps,))
for i in range(num_steps):
  # Run 1 time step.
  integrator.run(1)
  # Save current frame and keep track of energies.
  frames[i] = atoms.get_positions()
  potential_energy[i] = atoms.get_potential_energy()
  kinetic_energy[i] = atoms.get_kinetic_energy()
  total_energy[i] = atoms.get_total_energy()
  # Occasionally print progress.
  if i % 100 == 0:
    print(f"step {i:5d} epot {potential_energy[i]: 5.3f} ekin {kinetic_energy[i]: 5.3f} etot {total_energy[i]: 5.3f}")
step     0 epot -0.588 ekin  0.166 etot -0.422
step   100 epot -0.498 ekin  0.077 etot -0.421
step   200 epot -0.476 ekin  0.055 etot -0.421
step   300 epot -0.514 ekin  0.092 etot -0.422
step   400 epot -0.507 ekin  0.085 etot -0.422
step   500 epot -0.499 ekin  0.078 etot -0.421
step   600 epot -0.502 ekin  0.080 etot -0.422
step   700 epot -0.518 ekin  0.097 etot -0.422
step   800 epot -0.504 ekin  0.083 etot -0.421
step   900 epot -0.523 ekin  0.102 etot -0.422
step  1000 epot -0.498 ekin  0.076 etot -0.422
step  1100 epot -0.500 ekin  0.078 etot -0.422
step  1200 epot -0.528 ekin  0.107 etot -0.422
step  1300 epot -0.487 ekin  0.065 etot -0.422
step  1400 epot -0.497 ekin  0.075 etot -0.422
step  1500 epot -0.527 ekin  0.105 etot -0.422
step  1600 epot -0.523 ekin  0.102 etot -0.422
step  1700 epot -0.514 ekin  0.092 etot -0.422
step  1800 epot -0.495 ekin  0.074 etot -0.422
step  1900 epot -0.501 ekin  0.080 etot -0.421
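For reference, drawing Maxwell-Boltzmann velocities amounts to sampling each Cartesian component from a normal distribution with variance k_B T / m, after which Stationary removes the center-of-mass motion. The NumPy sketch below illustrates these two steps in SI units under stated assumptions: the helper name maxwell_boltzmann_velocities is hypothetical, the removal of overall rotation performed by ZeroRotation is omitted, and, depending on the ASE version, Stationary may additionally rescale the velocities to preserve the temperature, which this sketch does not do.

```python
import numpy as np

BOLTZMANN = 1.380649e-23   # J/K
AMU = 1.66053906660e-27    # kg

def maxwell_boltzmann_velocities(masses_amu, temperature, rng):
  """Sample velocities (m/s) at `temperature` (K) and remove center-of-mass motion."""
  masses = np.asarray(masses_amu) * AMU                      # kg, shape (num_atoms,)
  sigma = np.sqrt(BOLTZMANN * temperature / masses)          # m/s, one value per atom
  velocities = rng.normal(size=(len(masses), 3)) * sigma[:, None]
  # Subtract the center-of-mass velocity so that the total momentum is zero.
  v_com = (masses[:, None] * velocities).sum(axis=0) / masses.sum()
  return velocities - v_com

# Ethanol composition C2H6O with approximate atomic masses in amu.
masses = [12.011] * 2 + [1.008] * 6 + [15.999]
rng = np.random.default_rng(0)
v = maxwell_boltzmann_velocities(masses, 300.0, rng)

# Instantaneous temperature from equipartition (3 center-of-mass degrees of freedom removed).
kinetic = 0.5 * (np.asarray(masses)[:, None] * AMU * v**2).sum()  # J
dof = 3 * len(masses) - 3
print(f"instantaneous temperature: {2 * kinetic / (dof * BOLTZMANN):.0f} K")
```

With only 9 atoms, the sampled temperature scatters widely around the target value from one random draw to the next.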

Let’s have a look at the trajectory.

[15]:
view.getModel().setCoordinates(frames, 'array')
view.animate({'loop': 'forward', 'interval': 0.1})
view.show()

This looks reasonable. As a final check, let’s plot the potential, kinetic, and total energies.

[16]:
%matplotlib inline
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.xlabel('time [fs]')
plt.ylabel('energy [eV]')
time = np.arange(num_steps) * timestep_fs
plt.plot(time, potential_energy, label='potential energy')
plt.plot(time, kinetic_energy, label='kinetic energy')
plt.plot(time, total_energy, label='total energy')
plt.legend()
plt.grid()
[Figure: potential, kinetic, and total energy (eV) versus time (fs) over the MD trajectory.]

As expected, the potential and kinetic energies fluctuate during the trajectory, while their sum, the total energy, stays constant, as it should (energy conservation). In the printed log, the total energy only varies between -0.421 and -0.422 eV. These tiny deviations stem primarily from the finite timestep of the velocity Verlet integration and can be reduced by choosing a smaller timestep.
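To illustrate why a smaller timestep helps, here is a self-contained sketch of the velocity Verlet update (half kick, drift, half kick) applied to a one-dimensional harmonic oscillator with unit mass and force constant, a toy stand-in for the molecule. Velocity Verlet is a second-order method, so halving the timestep should shrink the energy fluctuations roughly fourfold:

```python
import numpy as np

def velocity_verlet_energies(dt, num_steps, x0=1.0, v0=0.0, k=1.0, m=1.0):
  """Integrate a harmonic oscillator with velocity Verlet and return total energies."""
  def force(x):
    return -k * x
  x, v = x0, v0
  f = force(x)
  energies = np.empty(num_steps)
  for i in range(num_steps):
    v_half = v + 0.5 * dt * f / m   # Half kick with the old force.
    x = x + dt * v_half             # Drift with the half-step velocity.
    f = force(x)                    # New force at the updated position.
    v = v_half + 0.5 * dt * f / m   # Second half kick with the new force.
    energies[i] = 0.5 * m * v**2 + 0.5 * k * x**2
  return energies

e0 = 0.5  # Exact initial energy for x0=1, v0=0, k=m=1.
max_errors = {}
for dt in (0.1, 0.05):
  energies = velocity_verlet_energies(dt, num_steps=int(round(20.0 / dt)))
  max_errors[dt] = float(np.max(np.abs(energies - e0)))
  print(f"dt={dt}: max |E - E0| = {max_errors[dt]:.2e}")
print(f"ratio: {max_errors[0.1] / max_errors[0.05]:.2f}")  # Roughly 4.
```

For this toy system the energy error oscillates but stays bounded instead of drifting, a property of symplectic integrators such as velocity Verlet.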