Source code for edafm.training


import os
import re
import glob
import time
import random
import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

class ImgDataset(Dataset):
    '''
    Pytorch dataset for loading AFM data.

    Every item of the dataset is one pre-made batch, read from a file named
    ``batch_<index>.npz`` in ``data_dir``. Batches are ordered by ascending
    numeric index.

    Arguments:
        data_dir: str. Path to directory where data is saved.
        preproc_fn: Python function. Preprocessing function which is applied to every batch.
            If None, the default preprocessing unpacks the batch into three items and
            returns the first two.
        print_timings: Bool. Whether to print timings for each loaded batch.

    Raises:
        ValueError: If a file matching ``batch_*.npz`` has no numeric index in its name.
    '''

    def __init__(self, data_dir, preproc_fn=None, print_timings=False):
        paths = glob.glob(os.path.join(data_dir, 'batch_*.npz'))
        self.batch_paths = sorted(paths, key=self._batch_index)
        if preproc_fn is not None:
            # The instance attribute shadows the default static method below.
            self._preproc_fn = preproc_fn
        self.print_timings = print_timings

    @staticmethod
    def _batch_index(path):
        '''Return the integer index encoded in the file name of a batch path.'''
        # Only the file name is inspected, so that a directory component which
        # happens to look like 'batch_<n>' cannot hijack the sort key.
        name = os.path.basename(path)
        match = re.search(r'batch_(\d+)', name)
        if match is None:
            raise ValueError(f'Could not find a batch index in file name {name!r}')
        return int(match.group(1))

    @staticmethod
    def _preproc_fn(batch):
        '''Default preprocessing: keep the first two of the three batch items.'''
        # Declared static so that ``self._preproc_fn(batch)`` receives exactly one
        # argument, just like a user-supplied function stored on the instance.
        Xs, Ys, mols = batch
        return Xs, Ys

    def _load_batch(self, file_path):
        '''Read the array stored under the key 'arr_0' in an npz file.'''
        # NOTE: allow_pickle=True unpickles objects stored in the file, which can
        # execute arbitrary code. Only load data from trusted sources.
        # The context manager closes the underlying file handle once the array is read.
        with np.load(file_path, allow_pickle=True) as data:
            return data['arr_0']

    def __len__(self):
        return len(self.batch_paths)

    def __getitem__(self, idx):
        if not self.print_timings:
            return self._preproc_fn(self._load_batch(self.batch_paths[idx]))

        t0 = time.time()
        batch = self._load_batch(self.batch_paths[idx])
        t1 = time.time()
        batch = self._preproc_fn(batch)
        t2 = time.time()

        # Tag the message with the worker id when running inside a DataLoader worker.
        if wi := torch.utils.data.get_worker_info():
            msg = f'(Gen {wi.id}) '
        else:
            msg = '(Gen)'
        print(msg+f't0/load/preproc: {t0}/{t1-t0}/{t2-t1}')

        return batch

    def shuffle(self):
        '''Shuffle batch order'''
        random.shuffle(self.batch_paths)
class ImgLoss(nn.Module):
    '''
    Weighted mean squared loss for images.

    Arguments:
        loss_factors: list of int. Loss weights, one per predicted output.
    '''

    def __init__(self, loss_factors):
        super().__init__()
        self.loss_factors = loss_factors

    def forward(self, pred, ref, separate_batch_items=False):
        '''
        Compute the weighted sum of mean squared errors over several outputs.

        Arguments:
            pred: sequence of torch.Tensor. Predictions, one per output.
            ref: sequence of torch.Tensor. References, in the same order as pred.
            separate_batch_items: Bool. If True, the losses are only averaged over
                dimensions 1 and 2 and keep their first dimension. Otherwise
                every loss is further reduced to a scalar.

        Returns:
            list. The weighted total loss followed by the unweighted loss of each output.
        '''
        assert len(pred) == len(ref) == len(self.loss_factors)

        per_output = []
        weighted_sum = 0.0
        for weight, prediction, target in zip(self.loss_factors, pred, ref):
            squared_error = (prediction - target) ** 2
            mse = torch.mean(squared_error, dim=(1, 2))
            if not separate_batch_items:
                mse = mse.mean()
            per_output.append(mse)
            weighted_sum += weight * mse

        return [weighted_sum, *per_output]
def _collate_fn(batch):
    '''
    Convert a batch of numpy arrays into float torch tensors.

    Arguments:
        batch: tuple (X, Y) of sequences of np.ndarray.

    Returns:
        tuple (X, Y) of lists of torch.Tensor. Every tensor in X gets an extra
        dimension of size one inserted at position 1.
    '''
    raw_inputs, raw_targets = batch

    inputs = []
    for array in raw_inputs:
        inputs.append(torch.from_numpy(array).unsqueeze(1).float())

    targets = []
    for array in raw_targets:
        targets.append(torch.from_numpy(array).float())

    return inputs, targets


def _worker_init_fn(worker_id):
    '''Seed the numpy random generator from the current time and the worker id.'''
    time_component = int((time.time() % 1e5) * 1000)
    np.random.seed(time_component + worker_id)
def make_dataloader(datadir, preproc_fn, print_timings=False, num_workers=8):
    '''
    Produce a dataset and dataloader from data directory.

    The dataloader is created with ``batch_size=None`` and ``shuffle=False``, and
    with memory pinning enabled.

    Arguments:
        datadir: str. Path to directory with data.
        preproc_fn: Python function. Preprocessing function to apply to each batch.
        print_timings: Boolean. Whether to print timings for each batch.
        num_workers: int. Number of parallel processes for data loading.

    Returns:
        tuple (dataset, dataloader)

        | dataset: ImgDataset.
        | dataloader: torch.DataLoader.
    '''
    dataset = ImgDataset(datadir, preproc_fn, print_timings=print_timings)

    loader_options = dict(
        batch_size=None,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_fn,
        worker_init_fn=_worker_init_fn,
        pin_memory=True,
    )
    dataloader = DataLoader(dataset, **loader_options)

    return dataset, dataloader