import os
import shutil
import tarfile
import numpy as np
import matplotlib.pyplot as plt
from urllib.request import urlretrieve
import torch
from .visualization import _calc_plot_dim
TRAINED_WEIGHT_URLS = {
'base' : 'https://www.dropbox.com/s/hgtud62vg65g5ax/base.pth?dl=1',
'single-channel': 'https://www.dropbox.com/s/cl84b7flx9rguu4/single-channel.pth?dl=1',
'CO-Cl' : 'https://www.dropbox.com/s/pbp6fektz02emvh/CO-Cl.pth?dl=1',
'Xe-Cl' : 'https://www.dropbox.com/s/sc7bd78ybwfj31r/Xe-Cl.pth?dl=1',
'constant-noise': 'https://www.dropbox.com/s/uqwwm9tzm59lyf6/constant-noise.pth?dl=1',
'uniform-noise' : 'https://www.dropbox.com/s/ic4v0f1vc11v988/uniform_noise.pth?dl=1',
'no-gradient' : 'https://www.dropbox.com/s/1gleijnn89itjqt/no-gradient.pth?dl=1',
'matched-tips' : 'https://www.dropbox.com/s/fv16hpl4c9a09xo/matched-tips.pth?dl=1'
}
ELEMENTS = ['H' , 'He',
'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
'K', 'Ca',
'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn',
'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
'Rb', 'Sr',
'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd',
'In', 'Sn', 'Sb', 'Te', 'I', 'Xe'
]
[docs]def download_molecules(save_path='./Molecules', verbose=1):
'''
Download database of molecules.
Arguments:
save_path: str. Path where the molecule xyz files will be saved.
verbose: int 0 or 1. Whether to print progress information.
'''
if not os.path.exists(save_path):
download_url = 'https://www.dropbox.com/s/z4113upq82puzht/Molecules_rebias_210611.tar.gz?dl=1'
temp_file = '.temp_molecule.tar'
if verbose: print('Downloading molecule tar archive...')
temp_file, info = urlretrieve(download_url, temp_file)
if verbose: print('Extracting tar archive...')
with tarfile.open(temp_file, 'r') as f:
base_dir = os.path.normpath(f.getmembers()[0].name).split(os.sep)[0]
f.extractall()
if verbose: print('Done extracting.')
shutil.move(base_dir, save_path)
os.remove(temp_file)
if verbose: print(f'Moved files to {save_path}.')
else:
if verbose: print(f'Target folder {save_path} already exists. Skipping downloading molecules.')
[docs]def download_weights(weights_type='base', weights_dir='./weights', verbose=1):
'''
Download pretrained weights for EDAFMNet model.
Arguments:
weights_type: str. Type of weights to download. One of 'base', 'single-channel', 'CO-Cl', 'Xe-Cl',
'constant-noise', 'uniform-noise', 'no-gradient', or 'matched-tips'. See README at
https://github.com/SINGROUP/ED-AFM for explanations for the different options.
weights_dir: str. Directory where the weight will be downloaded into.
verbose: int 0 or 1. Whether to print information.
'''
weights_path = os.path.join(weights_dir, f'{weights_type}.pth')
if not os.path.exists(weights_path):
try:
download_url = TRAINED_WEIGHT_URLS[weights_type]
except KeyError:
raise ValueError(f'Invalid trained weights type "{weights_type}". '
+ f'Has to be one of {", ".join(TRAINED_WEIGHT_URLS.keys())}.')
if not os.path.exists(weights_dir):
os.makedirs(weights_dir)
if verbose: print(f'Downloading pretrained weights of type "{weights_type}" into {weights_path}.')
weights_path, info = urlretrieve(download_url, weights_path)
else:
if verbose: print(f'Target path {weights_path} already exists. Skipping downloading weights.')
return weights_path
[docs]def count_parameters(module):
'''
Count pytorch module parameters.
Arguments:
module: torch.nn.Module.
'''
return sum(p.numel() for p in module.parameters() if p.requires_grad)
[docs]class LossLogPlot:
'''
Log and plot model training loss history.
Arguments:
log_path: str. Path where loss log is saved.
plot_path: str. Path where plot of loss history is saved.
loss_labels: list of str. Labels for different loss components.
loss_weights: list of int or str. Weights for different loss components.
Empty string for no weight (e.g. Total loss).
'''
def __init__(self, log_path, plot_path, loss_labels, loss_weights=None):
self.log_path = log_path
self.plot_path = plot_path
self.loss_labels = loss_labels
if not loss_weights:
self.loss_weights = [''] * len(self.loss_labels)
else:
assert len(loss_weights) == len(loss_labels)
self.loss_weights = loss_weights
self.train_losses = np.empty((0, len(loss_labels)))
self.val_losses = np.empty((0, len(loss_labels)))
self.epoch = 0
self._init_log()
def _init_log(self):
if not(os.path.isfile(self.log_path)):
with open(self.log_path, 'w') as f:
f.write('epoch')
for i, label in enumerate(self.loss_labels):
label = f';train_{label}'
if self.loss_weights[i]:
label += f' (x {self.loss_weights[i]})'
f.write(label)
for i, label in enumerate(self.loss_labels):
label = f';val_{label}'
if self.loss_weights[i]:
label += f' (x {self.loss_weights[i]})'
f.write(label)
f.write('\n')
print(f'Created log at {self.log_path}')
else:
with open(self.log_path, 'r') as f:
header = f.readline().rstrip('\r\n').split(';')
hl = (len(header)-1) // 2
if len(self.loss_labels) != hl:
raise ValueError(f'The length of the given list of loss names and the length of the header of the existing log at {self.log_path} do not match.')
for line in f:
line = line.rstrip('\n').split(';')
self.train_losses = np.append(self.train_losses, [[float(s) for s in line[1:hl+1]]], axis=0)
self.val_losses = np.append(self.val_losses, [[float(s) for s in line[hl+1:]]], axis=0)
self.epoch += 1
print(f'Using existing log at {self.log_path}')
[docs] def add_losses(self, train_loss, val_loss):
'''
Add losses to log.
Arguments:
train_loss: list of floats of length len(self.loss_labels). Training losses for the epoch.
val_loss: list of floats of length len(self.loss_labels). Validation losses for the epoch.
'''
self.epoch += 1
self.train_losses = np.append(self.train_losses, [train_loss], axis=0)
self.val_losses = np.append(self.val_losses, [val_loss], axis=0)
with open(self.log_path, 'a') as f:
f.write(str(self.epoch))
for l in train_loss:
f.write(f';{l}')
for l in val_loss:
f.write(f';{l}')
f.write('\n')
[docs] def plot_history(self, show=False, verbose=1):
'''
Plot and save history of current losses into plot_path.
Arguments:
show: Bool. Whether to show the plot on screen.
verbose: int 0 or 1. Whether to print output information.
'''
x = range(1, self.epoch+1)
n_rows, n_cols = _calc_plot_dim(len(self.loss_labels), f=0)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(7*n_cols, 6*n_rows))
if n_rows == 1 and n_cols == 1:
axes = np.expand_dims(axes, axis=0)
for i, (label, ax) in enumerate(zip(self.loss_labels, axes.flatten())):
ax.semilogy(x, self.train_losses[:,i],'-bx')
ax.semilogy(x, self.val_losses[:,i],'-gx')
ax.legend(['Training', 'Validation'])
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
if self.loss_weights[i]:
label = f'{label} (x {self.loss_weights[i]})'
ax.set_title(label)
fig.tight_layout()
plt.savefig(self.plot_path)
if verbose: print(f'Loss history plot saved to {self.plot_path}')
if show:
plt.show()
else:
plt.close()
[docs]def save_checkpoint(model, optimizer, epoch, save_dir, lr_scheduler=None, verbose=1):
'''
Save pytorch checkpoint.
Arguments:
model: torch.nn.Module. Model whose state to save.
optimizer: torch.optim.Optimizer. Optimizer whose state to save
epoch: int. Training epoch.
save_dir: str. Directory to save in.
lr_scheduler: torch.optim.lr_scheduler or None. If not None, save state of this scheduler.
verbose: int 0 or 1. Whether to print information.
'''
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if hasattr(model, 'module'):
model = model.module
state = {
'model_params': model.state_dict(),
'optim_params': optimizer.state_dict(),
}
if lr_scheduler is not None:
state['scheduler_params'] = lr_scheduler.state_dict()
torch.save(state, os.path.join(save_dir, f'model_{epoch}.pth'))
if verbose: print(f'Model, optimizer weights on epoch {epoch} saved to {save_dir}')
[docs]def load_checkpoint(model, optimizer=None, file_name='./model.pth', lr_scheduler=None, verbose=1):
'''
Load pytorch checkpoint.
Arguments:
model: torch.nn.Module. Model where parameters are loaded to.
optimizer: torch.optim.Optimizer or None. If not None, load state to this optimizer.
file_name: str. Checkpoint file to load from.
lr_scheduler: torch.optim.lr_scheduler or None. If not None, try loading state to this scheduler.
verbose: int 0 or 1. Whether to print information.
'''
state = torch.load(file_name)
model.load_state_dict(state['model_params'])
if optimizer:
optimizer.load_state_dict(state['optim_params'])
msg = f'Model, optimizer weights loaded from {file_name}'
else:
msg = f'Model weights loaded from {file_name}'
if lr_scheduler is not None:
try:
lr_scheduler.load_state_dict(state['scheduler_params'])
except:
print('Learning rate scheduler parameters could not be loaded.')
if verbose: print(msg)
[docs]def read_xyzs(file_paths, return_comment=False):
'''
Read molecule xyz files.
Arguments:
file_paths: list of str. Paths to xyz files
return_comment: bool. If True, also return the comment string on second line of file.
Returns: list of np.array of shape (num_atoms, 4) or (num_atoms, 5).
Each row corresponds to one atom with [x, y, z, element] or [x, y, z, charge, element].
'''
mols = []
comments = []
for file_path in file_paths:
with open(file_path, 'r') as f:
N = int(f.readline().strip())
comments.append(f.readline())
atoms = []
for line in f:
line = line.strip().split()
try:
elem = int(line[0])
except ValueError:
elem = ELEMENTS.index(line[0]) + 1
posc = [float(p) for p in line[1:]]
atoms.append(posc + [elem])
mols.append(np.array(atoms))
if return_comment:
mols = mols, comments
return mols