training module
- class edafm.training.ImgDataset(data_dir, preproc_fn=None, print_timings=False)
Bases: torch.utils.data.dataset.Dataset
PyTorch dataset for loading AFM data.
- Parameters
data_dir – str. Path to the directory where the data is saved.
preproc_fn – Python function. Preprocessing function that is applied to every batch.
print_timings – bool. Whether to print the loading time of each batch.
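To show how a dataset with these three parameters could behave, here is a minimal sketch, assuming each item of the dataset is one pre-saved batch. The storage layout (one pickle file per batch, named so that sorted order equals batch order) and the class name ImgDatasetSketch are hypothetical; the source does not say how ImgDataset stores or reads its files. The sketch omits the torch.utils.data.Dataset base class so it runs without PyTorch; a class that defines __len__ and __getitem__ like this can still be indexed by a PyTorch DataLoader.

```python
import glob
import os
import pickle
import tempfile
import time


class ImgDatasetSketch:
    """Map-style dataset in which item i is the i-th saved batch in data_dir.

    Hypothetical layout: one pickle file per batch (e.g. batch_0000.pkl),
    so lexicographic file order matches batch order.
    """

    def __init__(self, data_dir, preproc_fn=None, print_timings=False):
        self.paths = sorted(glob.glob(os.path.join(data_dir, "*.pkl")))
        self.preproc_fn = preproc_fn
        self.print_timings = print_timings

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        t0 = time.perf_counter()
        with open(self.paths[index], "rb") as f:
            batch = pickle.load(f)
        # Apply the optional preprocessing function to the whole batch.
        if self.preproc_fn is not None:
            batch = self.preproc_fn(batch)
        if self.print_timings:
            print(f"Batch {index}: loaded in {time.perf_counter() - t0:.4f} s")
        return batch


def halve_inputs(batch):
    """Hypothetical preprocessing function: scale every input value by 0.5."""
    return {**batch, "X": [0.5 * x for x in batch["X"]]}


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for i in range(3):
            with open(os.path.join(tmp, f"batch_{i:04d}.pkl"), "wb") as f:
                pickle.dump({"X": [i, i + 1], "Y": [2 * i]}, f)
        ds = ImgDatasetSketch(tmp, preproc_fn=halve_inputs, print_timings=True)
        print(len(ds), ds[2]["X"])
```

Keeping preproc_fn as a plain callable means any per-batch transformation (normalization, noise, cropping) can be swapped in without subclassing the dataset.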
- class edafm.training.ImgLoss(loss_factors)
Bases: torch.nn.modules.module.Module
Weighted mean squared loss for images.
- Parameters
loss_factors – list of int. Loss weights.
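One plausible reading of "weighted mean squared loss", assuming pred and ref each hold K image outputs, loss_factors supplies one weight $w_k$ per output, $B$ is the batch size and $N_k$ the number of pixels in output $k$ (none of which the source states explicitly), is

```latex
\mathcal{L} = \frac{1}{B} \sum_{b=1}^{B} \sum_{k=1}^{K} w_k \, \frac{1}{N_k} \sum_{i=1}^{N_k} \left( \hat{y}_{k,b,i} - y_{k,b,i} \right)^2
```

where $\hat{y}$ is the prediction and $y$ the reference. Under the same assumption, separate_batch_items=True would return the $B$ per-item terms $\mathcal{L}_b = \sum_{k} w_k \frac{1}{N_k} \sum_{i} (\hat{y}_{k,b,i} - y_{k,b,i})^2$ instead of their mean.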
- forward(pred, ref, separate_batch_items=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
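The arithmetic of a weighted image MSE with a separate_batch_items switch can be sketched in pure Python, with nested lists standing in for tensors so the example runs without PyTorch. This is an illustration under assumptions, not edafm's implementation: it assumes pred and ref are lists with one entry per output, each entry a batch of 2-D images, that loss_factors has one weight per output, and that the default reduction is the mean over the batch. The function names image_mse and weighted_image_loss are hypothetical. With the real ImgLoss module, the loss is computed by calling the instance, e.g. criterion(pred, ref), rather than criterion.forward(pred, ref).

```python
def image_mse(a, b):
    """Mean squared error over all pixels of two equally sized 2-D images."""
    pixels_a = [x for row in a for x in row]
    pixels_b = [y for row in b for y in row]
    if not pixels_a or len(pixels_a) != len(pixels_b):
        raise ValueError("images must be non-empty and have the same size")
    return sum((x - y) ** 2 for x, y in zip(pixels_a, pixels_b)) / len(pixels_a)


def weighted_image_loss(pred, ref, loss_factors, separate_batch_items=False):
    """Weighted MSE over several image outputs (hypothetical sketch).

    pred, ref: one entry per output; each entry is a batch (list) of 2-D images.
    loss_factors: one weight per output.
    Returns the batch-mean loss, or a list with one loss per batch item.
    """
    if not (len(pred) == len(ref) == len(loss_factors)):
        raise ValueError("pred, ref and loss_factors must have the same length")
    batch_size = len(pred[0])
    per_item = [0.0] * batch_size
    for weight, pred_batch, ref_batch in zip(loss_factors, pred, ref):
        if len(pred_batch) != batch_size or len(ref_batch) != batch_size:
            raise ValueError("every output must have the same batch size")
        for b in range(batch_size):
            # Weighted per-image MSE, accumulated over outputs for batch item b.
            per_item[b] += weight * image_mse(pred_batch[b], ref_batch[b])
    if separate_batch_items:
        return per_item
    return sum(per_item) / batch_size


if __name__ == "__main__":
    # One output, a batch of two 2x2 images, weight 2.0.
    pred = [[[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]]]
    ref = [[[[1.0, 2.0], [3.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]]]]
    print(weighted_image_loss(pred, ref, [2.0]))  # 2.0
    print(weighted_image_loss(pred, ref, [2.0], separate_batch_items=True))  # [2.0, 2.0]
```

Returning per-item losses instead of the mean is useful when evaluating a test set, where one may want to inspect which samples the model predicts poorly.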
- training: bool
- edafm.training.make_dataloader(datadir, preproc_fn, print_timings=False, num_workers=8)
Produce a dataset and dataloader from a data directory.
- Parameters
datadir – str. Path to the directory containing the data.
preproc_fn – Python function. Preprocessing function to apply to each batch.
print_timings – bool. Whether to print the loading time of each batch.
num_workers – int. Number of parallel processes for data loading.
- Returns
tuple (dataset, dataloader)
dataset – ImgDataset.
dataloader – torch.utils.data.DataLoader.
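The dataset/dataloader pairing can be sketched without PyTorch or edafm installed. This is a hypothetical illustration: SimpleLoader, make_dataloader_sketch and its dataset_cls argument are not part of edafm; dataset_cls is an injection point standing in for ImgDataset so the example is self-contained. Because each dataset item appears to be a complete batch, the PyTorch equivalent would presumably be torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=num_workers), where batch_size=None disables automatic batching; the source does not confirm this configuration. The sketch uses threads, whereas DataLoader workers are separate processes.

```python
from concurrent.futures import ThreadPoolExecutor


class SimpleLoader:
    """Hypothetical stand-in for DataLoader(dataset, batch_size=None).

    Fetches items on worker threads and yields them in index order. Unlike a
    real DataLoader, pool.map submits every index at once (no prefetch limit)
    and uses threads rather than worker processes.
    """

    def __init__(self, dataset, num_workers=8):
        self.dataset = dataset
        # num_workers=0 falls back to a single worker thread.
        self.num_workers = max(1, num_workers)

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        with ThreadPoolExecutor(max_workers=self.num_workers) as pool:
            yield from pool.map(self.dataset.__getitem__, range(len(self.dataset)))


def make_dataloader_sketch(dataset_cls, datadir, preproc_fn, print_timings=False, num_workers=8):
    """Build and return (dataset, dataloader), mirroring make_dataloader's outputs."""
    dataset = dataset_cls(datadir, preproc_fn=preproc_fn, print_timings=print_timings)
    dataloader = SimpleLoader(dataset, num_workers=num_workers)
    return dataset, dataloader
```

Returning the dataset alongside the loader lets the caller query dataset properties, such as its length, without going through the loader.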