ANN assisted TPS on capped alanine dipeptide 1#

In this notebook you will learn:

  • how to set up an ANN assisted TPS simulation on a simple molecular system

  • how to transform atomistic input coordinates to the descriptor space in which the ANN learns

This notebook assumes some familiarity with openpathsampling and aimmd; please try the Toy notebooks first.

%matplotlib inline
import os
import aimmd
import torch
import numpy as np
import mdtraj as md
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = 9, 6  # make the figures a bit bigger
import openpathsampling as paths
import openpathsampling.engines.openmm as peng_omm
import simtk.openmm as mm
import simtk.unit as unit
from simtk.openmm import app
from openmmtools.integrators import VVVRIntegrator
# setup logging
# executing this file sets the variable LOGCONFIG, which is a dictionary of logging presets 
%run ../resources/logconf.py
cur_dir = os.path.abspath(os.getcwd())  # needed for relative paths to initial TP
# change to the working directory of choice
wdir = '/home/tb/hejung/DATA/aimmd_scratch/SimData_pytorch_ala/'
#wdir = None
if wdir is not None:
    if not os.path.isdir(wdir):
        os.mkdir(wdir)
    os.chdir(wdir)

# setup logging in that directory
import logging.config
logging.config.dictConfig(LOGCONFIG)

initial TP and states#

Generating the initial TP is outside the scope of this tutorial; please consult the openpathsampling examples for that.

# Load the initial transition path for alanine dipeptide (generated at high
# temperature, 400K). We read the h5 trajectory and define the states from
# scratch, which gives better machine interoperability and a smaller
# footprint in the repository.
initialTP_md = md.load(os.path.join(cur_dir, 'ala_400K_TP_low_barrier.h5'))
initialTP = peng_omm.trajectory_from_mdtraj(initialTP_md)
template = peng_omm.snapshot_from_pdb(os.path.join(cur_dir, "AD_initial_frame.pdb"))

# collective variables: the two dihedral angles, given by their atom indices
psi = paths.MDTrajFunctionCV("psi", md.compute_dihedrals, template.topology, indices=[[6,8,14,16]])
phi = paths.MDTrajFunctionCV("phi", md.compute_dihedrals, template.topology, indices=[[4,6,8,14]])

# state definitions
# deg is the number of degrees per radian, so x/deg converts x from degrees to radians
deg = 180.0/np.pi
# both dihedrals are periodic on [-pi, pi]
dihedral_period = dict(period_min=-np.pi, period_max=np.pi)

# C_7eq: phi in [-180, 0] degrees and psi in [120, 200] degrees
C_7eq = (
    paths.PeriodicCVDefinedVolume(phi, lambda_min=-180/deg, lambda_max=0/deg, **dihedral_period)
    & paths.PeriodicCVDefinedVolume(psi, lambda_min=120/deg, lambda_max=200/deg, **dihedral_period)
).named("C_7eq")

# alpha_R: phi in [-180, 0] degrees and psi in [-50, 30] degrees
alpha_R = (
    paths.PeriodicCVDefinedVolume(phi, lambda_min=-180/deg, lambda_max=0/deg, **dihedral_period)
    & paths.PeriodicCVDefinedVolume(psi, lambda_min=-50/deg, lambda_max=30/deg, **dihedral_period)
).named("alpha_R")

engine setup#

# OpenMM system built from the same PDB frame that serves as OPS template
pdb = app.PDBFile(os.path.join(cur_dir, "AD_initial_frame.pdb"))
forcefield = app.ForceField('amber99sbildn.xml', 'tip3p.xml')
system_options = dict(
    nonbondedMethod=app.PME,
    nonbondedCutoff=1.0*unit.nanometers,
    constraints=app.HBonds,
    rigidWater=True,
    ewaldErrorTolerance=0.0005,
)
system = forcefield.createSystem(pdb.topology, **system_options)

# integrator settings, passed positionally as (T, fric, dt)
sim_temperature = 300*unit.kelvin
sim_friction = 1.0/unit.picoseconds
sim_timestep = 2.0*unit.femtoseconds
integrator = VVVRIntegrator(sim_temperature, sim_friction, sim_timestep)
integrator.setConstraintTolerance(1e-5)

# OPS engine wrapping system and integrator
engine_options = dict(n_frames_max=20000, nsteps_per_frame=10)
engine = peng_omm.Engine(
    template.topology,
    system,
    integrator,
    #openmm_properties={'CudaDeviceIndex': '0', 'CudaPrecision': 'single'},
    options=engine_options,
)
engine.name = '300K'

Transforming atomistic coordinates to training descriptors: Symmetry functions and internal coordinates#

# The slow part: symmetry functions

# symmetry function cutoff first
# (consider G5 worst case scenarios when choosing the cutoff!)
cutoff = 0.6

# g2_parms is expected to be a list of lists, each sublist contains [eta, r_s]
# (further r_s candidates that are currently not used: 0.55, 0.7, 0.85)
g2_parms = [[200., r_s] for r_s in (0.1, 0.25, 0.4)]

# g5_parms is also expected to be a list of lists, each sublist contains
# [eta, r_s, zeta, lambda].
# Here we only list the [eta, r_s] part, zeta and lambda are added below because
#  - the same zetas, which influence the sharpness of the angular peaks,
#    are used for all G5s at the different probing radii
#  - the same two lambda values [-1, +1], which influence the location of the
#    maximum (i.e. at angle=0 or at angle=\pi), are used everywhere
# (further r_s candidates that are currently not used: 0.55, 0.7, 0.85)
g5_etas_rs = [[120., r_s] for r_s in (0.1, 0.25, 0.4)]

# all zetas for all G5, a high zeta means sharp angular resolution
# Note from the original author: zeta must be an even number, using these
# powers of 2 empirically works well
# NOTE(review): 1 is odd but part of the list, confirm the constraint
zetas = [1, 2, 4, 16, 64]

# construct g5_parms from the values defined above:
# every [eta, r_s] is combined with all zetas and both possible lambda values
g5_parms = []
for eta, r_s in g5_etas_rs:
    for zeta in zetas:
        for lamb in (+1., -1.):
            g5_parms.append([eta, r_s, zeta, lamb])

# G2 and G5 parameters go into one list for the SF transformation function
g_parms = [g2_parms, g5_parms]

# Atom indices for the symmetry functions: 'HOH' (elements O and H) is treated
# as solvent, the reactant is selected with 'not resname HOH'
mol_idxs, solv_idxs = aimmd.coords.symmetry.generate_indices(
    template.topology.mdtraj,
    ['HOH'],
    solvent_atelements=[['O', 'H']],
    reactant_selection='not resname HOH',
)

# all keyword arguments of the symmetry function transformation in one dict
sf_parms = dict(
    mol_idxs=mol_idxs,
    solv_idxs=solv_idxs,
    g_parms=g_parms,
    cutoff=cutoff,
    n_per_solv=[[1., 2.]],
    rho_solv=[33.],
)

# transform is an alias for sf
sf_transform = paths.MDTrajFunctionCV(
    'sf_transform',
    aimmd.coords.symmetry.transform,
    template.topology,
    **sf_parms
)
# The fast part: internal coordinates

# index tuples of the internal coordinates, generated from the template topology
pairs, triples, quadruples = aimmd.coords.internal.generate_indices(
    template.topology.mdtraj,
    source_idx=0,
)

# all keyword arguments of the internal coordinate transformation in one dict
ic_parms = dict(pairs=pairs, triples=triples, quadruples=quadruples)

# transform is an alias for ic
ic_transform = paths.MDTrajFunctionCV(
    'ic_transform',
    aimmd.coords.internal.transform,
    template.topology,
    **ic_parms
)
# Switch between the two descriptor sets:
#   True  -> full setup, symmetry functions plus internal coordinates (much slower)
#   False -> no symmetry functions are calculated, we train on internal coords only
# Setting this to True will increase the runtime by a factor of approximately 10
# (depending on the choice of symmetry function cutoff)
we_have_time = True

# NOTE(review): the imports live inside the helper functions, presumably so that
# they stay self-contained when OPS stores the CV; confirm before moving them
if we_have_time:
    def transform_func(mdtra, sf_parms, ic_parms):
        """Return symmetry functions and internal coordinates, concatenated along axis 1."""
        import mdtraj as md
        from aimmd.coords.symmetry import sf
        from aimmd.coords.internal import ic
        import numpy as np
        sf_values = sf(mdtra, **sf_parms)
        ic_values = ic(mdtra, **ic_parms)
        return np.concatenate([sf_values, ic_values], axis=1)

    transform_kwargs = {'sf_parms': sf_parms, 'ic_parms': ic_parms}
else:
    def transform_func(mdtra, ic_parms):
        """Return the internal coordinates only."""
        import mdtraj as md
        from aimmd.coords.internal import ic
        import numpy as np
        return ic(mdtra, **ic_parms)

    transform_kwargs = {'ic_parms': ic_parms}

# wrap whichever helper was defined above into an OPS CV with disk cache
descriptor_transform = paths.MDTrajFunctionCV(
    'descriptor_transform',
    transform_func,
    template.topology,
    cv_scalarize_numpy_singletons=False,
    **transform_kwargs
).with_diskcache()
# Let's have a look at the transformed coordinate values of the template,
# they should all be approximately in [0, 1]
trans_coords = descriptor_transform(template)
# the length of the descriptor vector is the number of network inputs,
# remember it as cv_ndim
cv_ndim = len(trans_coords)
plt.hist(trans_coords);
../../../_images/a2136793d4eeefc572af7d4acfd6ebedc3a29106b48ab6323961e2a3bee9070f.png

aimmd setup: create an ANN, RCModel, etc#

import torch.nn.functional as F

# Architecture: a pyramidal feed-forward network with a ResNet top part
n_lay_pyramid = 4  # number of layers in the pyramid
n_unit_top = 10  # number of units per layer in the top ResNet part
n_lay_top = 2  # number of ResUnits in the top part, results in n_lay_top * residual_n_skip layers
# number of inputs to the NN, i.e. number of units in the first layer
n_unit_base = descriptor_transform(template).shape[0]
print('number of input descriptors: ', n_unit_base)

# factor by which the number of units shrinks from layer to layer in the
# pyramidal part, chosen such that the last pyramid layer has n_unit_top units
fact = (n_unit_top / n_unit_base)**(1./(n_lay_pyramid-1))
# layer widths of the pyramid, never narrower than the top part
pyramid_widths = [max(n_unit_top, int(n_unit_base * fact**layer))
                  for layer in range(n_lay_pyramid)]

# 4 hidden layer pyramidal network
ffnet = aimmd.pytorch.networks.FFNet(n_in=cv_ndim,
                                     n_hidden=pyramid_widths,
                                     activation=F.elu,
                                     )

resnet = aimmd.pytorch.networks.ResNet(n_units=n_unit_top, n_blocks=n_lay_top)

# With a single output we predict only p_B and use a binomial loss.
# We could also have used n_out=n_states to use a multinomial loss and predict
# all states, but this is probably only worthwhile if n_states > 2 as it would
# increase the number of free parameters in the NN.
# modules is a list of initialized torch.nn.Modules from aimmd.pytorch.networks
torch_model = aimmd.pytorch.networks.ModuleStack(n_out=1, modules=[ffnet, resnet])

# the model lives on the GPU whenever CUDA is available
if torch.cuda.is_available():
    torch_model = torch_model.to('cuda')

# optimizer used to train the model
optimizer = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
number of input descriptors:  1507
# check whether CUDA is available to torch
torch.cuda.is_available()
True
# Open the aimmd storage in the current working directory. The notebook has
# already changed into `wdir` if one was set, so os.getcwd() points to the
# right place in both cases, while joining on `wdir` itself raises a TypeError
# for wdir = None.
aimmd_store = aimmd.Storage(os.path.join(os.getcwd(), 'aimmd_storage.h5'), mode='w')
# we take an EEScalePytorchRCModel,
# this RCmodel scales the learning rate by the expected efficiency factor (1 - n_TP_true / n_TP_expected)**2
model = aimmd.pytorch.EEScalePytorchRCModel(nnet=torch_model,
                                            optimizer=optimizer,
                                            states=[C_7eq, alpha_R],
                                            ee_params={'lr_0': 1e-3,
                                                       'lr_min': 5e-5,  # lr_min = lr_0 / 20 is a good choice empirically
                                                       'epochs_per_train': 5,
                                                       'interval': 5,
                                                       'window': 75,
                                                       },
                                            descriptor_transform=descriptor_transform,
                                            cache_file=aimmd_store,
                                            )
# training set for the two states and the hooks that get attached to the OPS sampler
trainset = aimmd.TrainSet(n_states=2)
trainhook = aimmd.ops.TrainingHook(model, trainset)
storehook = aimmd.ops.AimmdStorageHook(aimmd_store, model, trainset)
densityhook = aimmd.ops.DensityCollectionHook(model)
# shooting point selector that uses the model
selector = aimmd.ops.RCModelSelector(model=model,
                                     states=[C_7eq, alpha_R],
                                     distribution='lorentzian',
                                     scale=1.0,
                                     )

OPS setup: TPS strategy and sampled transitions#

# TPS network with transitions between all states, and an empty move scheme for it
network = paths.TPSNetwork.from_states_all_to_all([C_7eq, alpha_R])
move_scheme = paths.MoveScheme(network=network)

# inverse temperature, taken from the integrator temperature of the engine
kB_T = engine.integrator.getTemperature() * unit.BOLTZMANN_CONSTANT_kB
beta = 1.0 / kB_T
modifier = paths.RandomVelocities(beta=beta, engine=engine)

# two way shooting with the model based selector
tw_strategy = paths.strategies.TwoWayShootingStrategy(
    modifier=modifier,
    selector=selector,
    engine=engine,
    group='TwoWayShooting',
)
for strategy in (tw_strategy, paths.strategies.OrganizeByMoveGroupStrategy()):
    move_scheme.append(strategy)
move_scheme.build_move_decision_tree()

# the initial transition path provides the initial conditions
initial_conditions = move_scheme.initial_conditions_from_trajectories(initialTP)
No missing ensembles.
No extra ensembles.
# OPS storage file for the sampled paths
storage = paths.Storage('ala_low_barrier_SF+IC_pytorch.nc', mode='w', template=template)

# path sampling simulation with the aimmd hooks attached
sampler = paths.PathSampling(
    storage=storage,
    sample_set=initial_conditions,
    move_scheme=move_scheme,
)
for hook in (trainhook, storehook, densityhook):
    sampler.attach_hook(hook)

# run 2000 Monte Carlo cycles
sampler.run(2000)
Working on Monte Carlo cycle number 2000
Running for 3 hours 39 minutes 25 seconds -  6.59 seconds per step
Estimated time remaining: 6 seconds
DONE! Completed 2000 Monte Carlo cycles.
# Training decisions logged by the model: column 0 is the train flag,
# column 1 the learning rate
log_train = np.array(model.log_train_decision)
lr = log_train[:, 1]
plt.plot(lr, label='lr')
# see where we really trained: everywhere where train=True
trained = log_train[:, 0].astype(bool)
# Work on a copy: `lr` is a view into `log_train`, so masking it directly would
# overwrite the logged learning rates (and the curve plotted above) with NaN.
lr_true = lr.astype(float)
# set lr_true to NaN wherever we did not train to have a nice plot
lr_true[~trained] = np.nan
plt.plot(lr_true, '+', label='True learning')
# lr_min as a guide to the eye
plt.axhline(model.ee_params['lr_min'], label='lr_min', color='lime')
plt.legend()
plt.xlabel('MCStep', size=15);
plt.ylabel('Learning rate', size=15);
../../../_images/5c1669425fcc6b062fbd280aa5fb2f70ebf809ef57e443ded2bac3e3ca570875.png
# Re-sort the logged losses such that there is one entry per MCStep:
# the logged losses where we trained, a row of NaNs where we did not
n_epochs = model.ee_params['epochs_per_train']
logged_losses = model.log_train_loss
train_loss = []
n_trained = 0  # index of the next unused entry in logged_losses
for did_train in log_train[:, 0]:
    if not did_train:
        train_loss.append([np.nan] * n_epochs)
        continue
    train_loss.append(logged_losses[n_trained])
    n_trained += 1

plt.plot(train_loss, '+', label='training loss')
plt.legend();
plt.ylabel('loss per training point', size=15)
plt.xlabel('MCStep', size=15)
plt.tight_layout()
../../../_images/c4e4c1d395e3f99d9259e1cf437d99ed8aef4aea8d587b32e2d3eefa03f94b61.png
# acceptance record from the OPS storage: 1.0 for every accepted step, 0.0 otherwise
accepts = [1. if step.change.canonical.accepted else 0. for step in storage.steps]
p_ex = np.array(model.expected_p)

# cumulative counts of generated, accepted and expected transitions;
# the expected curve sums 2 * p * (1 - p) with p = model.expected_p
generated_cum = np.cumsum(trainset.transitions)
accepted_cum = np.cumsum(accepts)
expected_cum = np.cumsum(2*p_ex*(1 - p_ex))

l, = plt.plot(generated_cum, label='generated');
# accepted and expected share the color of the generated curve
plt.plot(accepted_cum, c=l.get_color(), ls='--', label='accepted');
plt.plot(expected_cum, c=l.get_color(), ls=':', label='expected');
plt.plot(expected_cum - generated_cum, label='diff (expected - generated)')
# straight line from 0 to half the number of entries in the trainset
plt.plot(np.linspace(0., len(trainset)/2., len(trainset)), c='k', ls='--', label='maximal', lw=2)
plt.legend(fontsize=12);
plt.ylabel('Cummulative count of TPs', size=15)
plt.xlabel('# MC Step', size=15);
../../../_images/13e50640373df6ca421b1d8571d14c95cc2c27017f7b434c922187f33b74936a.png

HIPR#

# HIPR analysis of the trained model over the trainset
# NOTE(review): 25 is presumably the number of repetitions, confirm against aimmd
hipr = aimmd.analysis.HIPRanalysis(model, trainset)
hipr_plus_losses, hipr_plus_stds = hipr.do_hipr_plus(25)

# the last entry is the reference loss over the unaltered trainset,
# all other entries are compared against it
reference_loss = hipr_plus_losses[-1]
loss_diffs = hipr_plus_losses[:-1] - reference_loss

coordinate_idxs = np.arange(len(loss_diffs))
plt.bar(coordinate_idxs, loss_diffs, yerr=hipr_plus_stds[:-1])
plt.xlabel('Coordinate index', size=15)
plt.ylabel('Relative importance', size=15);
../../../_images/f2117291b42d64283b8e663c7763730d79d62c9c3d8e6c6cd44b74e15095e015.png
# What are the most important contributors?
# sort the coordinate indices by loss difference, largest first
max_idxs = np.argsort(loss_diffs)[::-1]
sf_parms = descriptor_transform.kwargs['sf_parms']
ic_parms = descriptor_transform.kwargs['ic_parms']
# arguments needed to map a descriptor index back to the involved coordinate
involved_kwargs = dict(
    sf_parms=sf_parms,
    ic_parms=ic_parms,
    solvent_atoms=[['O', 'H']],
    solvent_resname=['HOH'],
)

print('reference loss:', hipr_plus_losses[-1])
# report the 40 coordinates with the largest loss increase
for idx in max_idxs[:40]:
    print()
    print('loss for idx {:d}: '.format(idx), hipr_plus_losses[idx])
    print(aimmd.coords.get_involved(idx, **involved_kwargs))
reference loss: 0.3632467651367188

loss for idx 1500:  0.9590144882965089
('SF', ('G5', [120.0, 0.4, 2, -1.0], 21, 'HOH', 'H'))

loss for idx 758:  0.4083429124450684
('SF', ('G5', [120.0, 0.1, 16, -1.0], 21, 'HOH', 'O'))

loss for idx 593:  0.3996204382324219
('SF', ('G5', [120.0, 0.1, 16, -1.0], 16, 'HOH', 'O'))

loss for idx 98:  0.3973171133422852
('SF', ('G5', [120.0, 0.1, 16, -1.0], 1, 'HOH', 'O'))

loss for idx 890:  0.3949944214630128
('SF', ('G5', [120.0, 0.1, 16, -1.0], 3, 'HOH', 'H'))

loss for idx 1493:  0.39363210617065436
('SF', ('G5', [120.0, 0.25, 16, 1.0], 21, 'HOH', 'H'))

loss for idx 1498:  0.3931823140716552
('SF', ('G5', [120.0, 0.4, 1, -1.0], 21, 'HOH', 'H'))

loss for idx 233:  0.39307228401184074
('SF', ('G5', [120.0, 0.25, 1, 1.0], 5, 'HOH', 'O'))

loss for idx 824:  0.39306454734802243
('SF', ('G5', [120.0, 0.1, 16, -1.0], 1, 'HOH', 'H'))

loss for idx 560:  0.3924108724212647
('SF', ('G5', [120.0, 0.1, 16, -1.0], 15, 'HOH', 'O'))

loss for idx 517:  0.3911073852539062
('SF', ('G2', [200.0, 0.1], 14, 'HOH', 'O'))

loss for idx 1501:  0.39014006805419915
('SF', ('G5', [120.0, 0.4, 4, 1.0], 21, 'HOH', 'H'))

loss for idx 583:  0.38994347434997556
('SF', ('G2', [200.0, 0.1], 16, 'HOH', 'O'))

loss for idx 232:  0.38986561683654786
('SF', ('G5', [120.0, 0.1, 64, -1.0], 5, 'HOH', 'O'))

loss for idx 657:  0.389015143737793
('SF', ('G5', [120.0, 0.1, 4, -1.0], 18, 'HOH', 'O'))

loss for idx 1:  0.38901496284484866
('IC', [1, 4])

loss for idx 67:  0.3868617095184327
('SF', ('G5', [120.0, 0.1, 64, -1.0], 0, 'HOH', 'O'))

loss for idx 1451:  0.3855172144317627
('SF', ('G5', [120.0, 0.1, 16, -1.0], 20, 'HOH', 'H'))

loss for idx 857:  0.38538858467102044
('SF', ('G5', [120.0, 0.1, 16, -1.0], 2, 'HOH', 'H'))

loss for idx 32:  0.38455263587951655
('IC', [8, 14, 16])

loss for idx 1187:  0.3843810806274414
('SF', ('G5', [120.0, 0.1, 16, -1.0], 12, 'HOH', 'H'))

loss for idx 100:  0.3841659732055664
('SF', ('G5', [120.0, 0.1, 64, -1.0], 1, 'HOH', 'O'))

loss for idx 1154:  0.38303435478210446
('SF', ('G5', [120.0, 0.1, 16, -1.0], 11, 'HOH', 'H'))

loss for idx 1494:  0.382580640335083
('SF', ('G5', [120.0, 0.25, 16, -1.0], 21, 'HOH', 'H'))

loss for idx 426:  0.38248811836242674
('SF', ('G5', [120.0, 0.1, 4, -1.0], 11, 'HOH', 'O'))

loss for idx 1319:  0.3819457453918457
('SF', ('G5', [120.0, 0.1, 16, -1.0], 16, 'HOH', 'H'))

loss for idx 725:  0.38182475280761724
('SF', ('G5', [120.0, 0.1, 16, -1.0], 20, 'HOH', 'O'))

loss for idx 197:  0.3810885425567627
('SF', ('G5', [120.0, 0.1, 16, -1.0], 4, 'HOH', 'O'))

loss for idx 1352:  0.38021873443603516
('SF', ('G5', [120.0, 0.1, 16, -1.0], 17, 'HOH', 'H'))

loss for idx 1385:  0.3784497678375245
('SF', ('G5', [120.0, 0.1, 16, -1.0], 18, 'HOH', 'H'))

loss for idx 662:  0.37842965339660645
('SF', ('G5', [120.0, 0.25, 1, 1.0], 18, 'HOH', 'O'))

loss for idx 659:  0.37840117805480955
('SF', ('G5', [120.0, 0.1, 16, -1.0], 18, 'HOH', 'O'))

loss for idx 1502:  0.37827670089721677
('SF', ('G5', [120.0, 0.4, 4, -1.0], 21, 'HOH', 'H'))

loss for idx 397:  0.37811344436645505
('SF', ('G5', [120.0, 0.1, 64, -1.0], 10, 'HOH', 'O'))

loss for idx 845:  0.3779867031860352
('SF', ('G5', [120.0, 0.4, 64, 1.0], 1, 'HOH', 'H'))

loss for idx 1121:  0.3773052994537353
('SF', ('G5', [120.0, 0.1, 16, -1.0], 10, 'HOH', 'H'))

loss for idx 791:  0.37707365890502925
('SF', ('G5', [120.0, 0.1, 16, -1.0], 0, 'HOH', 'H'))

loss for idx 459:  0.377033796081543
('SF', ('G5', [120.0, 0.1, 4, -1.0], 12, 'HOH', 'O'))

loss for idx 692:  0.3770057695007325
('SF', ('G5', [120.0, 0.1, 16, -1.0], 19, 'HOH', 'O'))

loss for idx 626:  0.3768632215118408
('SF', ('G5', [120.0, 0.1, 16, -1.0], 17, 'HOH', 'O'))
# sync the OPS storage (presumably writing pending data to disk, confirm)
# before closing it, then close the aimmd storage as well
storage.sync_all()
storage.close()
aimmd_store.close()