models module
- class edafm.models.AttentionUNet(conv3d_in_channels, conv2d_in_channels, conv3d_out_channels, n_in=1, n_out=3, merge_block_channels=[8], merge_block_depth=2, conv3d_block_channels=[8, 16, 32], conv3d_block_depth=2, conv3d_dropouts=[0.0, 0.0, 0.0], conv2d_block_channels=[128], conv2d_block_depth=3, conv2d_dropouts=[0.1], attention_channels=[32, 32, 32], upscale2d_block_channels=[16, 16, 16], upscale2d_block_depth=1, upscale2d_block_channels2=[16, 16, 16], upscale2d_block_depth2=2, split_conv_block_channels=[16], split_conv_block_depth=[3], res_connections=True, out_convs_channels=1, out_relus=True, pool_type='avg', pool_z_strides=[2, 1, 2], padding_mode='zeros', activation='lrelu', attention_activation='softmax', device='cuda')
Bases: torch.nn.modules.module.Module
PyTorch 3D-to-2D U-net model with attention.
3D conv -> concatenate -> 3D conv/pool/dropout -> 3D-to-2D flattening -> 2D conv/dropout -> 2D upsampling/conv with skip connections and attention. For multiple inputs, each input is first processed through its own 3D conv block, and the results are merged by concatenation along the channel axis.
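The flattening step in this pipeline is what ties conv3d_out_channels to the input z size. The sketch below shows that shape bookkeeping under stated assumptions; it is not the library's code. It assumes tensors are laid out as (batch, channels, x, y, z), that flattening folds the z axis into the channel axis, that each block's output is flattened before that block's z pooling, and that pooling (kernel size equal to stride) reduces z by floor division. The helper names flatten_shape and flattened_channels are hypothetical.

```python
from typing import List, Sequence, Tuple


def flatten_shape(shape: Tuple[int, int, int, int, int]) -> Tuple[int, int, int, int]:
    """Hypothetical: shape of a 3D feature map after folding z into channels.

    Assumes layout (batch, channels, x, y, z). On a PyTorch tensor t this could be
    done with t.permute(0, 1, 4, 2, 3).reshape(b, c * z, x, y).
    """
    b, c, x, y, z = shape
    return (b, c * z, x, y)


def flattened_channels(z_in: int, block_channels: Sequence[int],
                       pool_z_strides: Sequence[int]) -> List[int]:
    """Hypothetical: candidate conv3d_out_channels for a given input z size.

    Assumes each block's output is flattened before its z pooling, and that
    pooling shrinks z by floor division with the stride.
    """
    if len(block_channels) != len(pool_z_strides):
        raise ValueError("block_channels and pool_z_strides must have the same length")
    out = []
    z = z_in
    for channels, stride in zip(block_channels, pool_z_strides):
        out.append(channels * z)  # z folded into the channel axis at this block
        z = z // stride           # assumed z pooling before the next block
    return out


# Default block channels and z strides from the signature, hypothetical z size of 8.
print(flattened_channels(8, [8, 16, 32], [2, 1, 2]))  # [64, 64, 128]
print(flatten_shape((4, 8, 128, 128, 8)))             # (4, 64, 128, 128)
```

If the actual implementation flattens after pooling rather than before, each entry would use the post-pool z size instead, so these numbers should be checked against the source before being relied on.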
- Parameters
conv3d_in_channels – int. Number of channels in input.
conv2d_in_channels – int. Number of channels in first 2D conv layer after flattening 3D to 2D.
conv3d_out_channels – list of int of same length as conv3d_block_channels. Number of channels after flattening the output of each 3D conv block from 3D to 2D. Depends on the input z size.
n_in – int. Number of input 3D images.
n_out – int. Number of output 2D maps.
merge_block_channels – list of int. Number of channels in input merging 3D conv blocks.
merge_block_depth – int. Number of layers in each merge conv block.
conv3d_block_channels – list of int. Number of channels in 3D conv blocks.
conv3d_block_depth – int. Number of layers in each 3D conv block.
conv3d_dropouts – list of float of same length as conv3d_block_channels. Dropout rates after each 3D conv block.
conv2d_block_channels – list of int. Number of channels in 2D conv blocks.
conv2d_block_depth – int. Number of layers in each 2D conv block.
conv2d_dropouts – list of float of same length as conv2d_block_channels. Dropout rates after each 2D conv block.
attention_channels – list of int of same length as conv3d_block_channels. Number of channels in the conv layer within each attention block.
upscale2d_block_channels – list of int of same length as conv3d_block_channels. Number of channels in each 2D conv block after upscaling and before the skip connection.
upscale2d_block_depth – int. Number of layers in each 2D conv block after upscaling and before the skip connection.
upscale2d_block_channels2 – list of int of same length as conv3d_block_channels. Number of channels in each 2D conv block after the skip connection.
upscale2d_block_depth2 – int. Number of layers in each 2D conv block after the skip connection.
split_conv_block_channels – list of int. Number of channels in 2D conv blocks after splitting outputs.
split_conv_block_depth – int. Number of layers in each 2D conv block after splitting outputs.
res_connections – bool. Whether to use residual connections in conv blocks.
out_convs_channels – int or list of int. Number of channels in the split outputs.
out_relus – bool or list of bool of length n_out. Whether to apply ReLU activation to the output 2D maps.
pool_type – str (‘max’ or ‘avg’). Type of pooling to use.
pool_z_strides – list of int of same length as conv3d_block_channels. Stride of pool layers in z direction.
padding_mode – str. Type of padding in each convolution layer. ‘zeros’, ‘reflect’, ‘replicate’ or ‘circular’.
activation – str (‘relu’, ‘lrelu’, or ‘elu’) or nn.Module. Activation to use after every layer except the last one.
attention_activation – str. Type of activation to use for attention map. ‘sigmoid’ or ‘softmax’.
device – str. Device to load model onto.
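Several of the parameters above must have the same length as conv3d_block_channels, conv2d_dropouts must match conv2d_block_channels, and out_relus must have length n_out when given as a list. A pre-flight check of these documented constraints, together with the documented string options, might look like the following. The function check_attention_unet_args is hypothetical and not part of edafm; it only mirrors the parameter descriptions.

```python
from typing import List, Sequence, Union


def check_attention_unet_args(
    conv3d_out_channels: Sequence[int],
    conv3d_block_channels: Sequence[int] = (8, 16, 32),
    conv3d_dropouts: Sequence[float] = (0.0, 0.0, 0.0),
    conv2d_block_channels: Sequence[int] = (128,),
    conv2d_dropouts: Sequence[float] = (0.1,),
    attention_channels: Sequence[int] = (32, 32, 32),
    upscale2d_block_channels: Sequence[int] = (16, 16, 16),
    upscale2d_block_channels2: Sequence[int] = (16, 16, 16),
    pool_z_strides: Sequence[int] = (2, 1, 2),
    n_out: int = 3,
    out_relus: Union[bool, Sequence[bool]] = True,
    pool_type: str = "avg",
    padding_mode: str = "zeros",
    activation: object = "lrelu",
    attention_activation: str = "softmax",
) -> List[str]:
    """Hypothetical helper: return a list of problems (empty if none are found)."""
    errors = []
    n_blocks = len(conv3d_block_channels)
    # Parameters documented as "same length as conv3d_block_channels".
    per_block = {
        "conv3d_out_channels": conv3d_out_channels,
        "conv3d_dropouts": conv3d_dropouts,
        "attention_channels": attention_channels,
        "upscale2d_block_channels": upscale2d_block_channels,
        "upscale2d_block_channels2": upscale2d_block_channels2,
        "pool_z_strides": pool_z_strides,
    }
    for name, values in per_block.items():
        if len(values) != n_blocks:
            errors.append(f"{name}: length {len(values)}, expected {n_blocks}")
    if len(conv2d_dropouts) != len(conv2d_block_channels):
        errors.append("conv2d_dropouts: length must match conv2d_block_channels")
    if not isinstance(out_relus, bool) and len(out_relus) != n_out:
        errors.append(f"out_relus: length {len(out_relus)}, expected n_out={n_out}")
    if pool_type not in ("max", "avg"):
        errors.append(f"pool_type: unknown value {pool_type!r}")
    if padding_mode not in ("zeros", "reflect", "replicate", "circular"):
        errors.append(f"padding_mode: unknown value {padding_mode!r}")
    # A non-string activation is assumed to be an nn.Module and is not checked here.
    if isinstance(activation, str) and activation not in ("relu", "lrelu", "elu"):
        errors.append(f"activation: unknown value {activation!r}")
    if attention_activation not in ("sigmoid", "softmax"):
        errors.append(f"attention_activation: unknown value {attention_activation!r}")
    return errors


print(check_attention_unet_args([64, 64, 128]))  # []
print(check_attention_unet_args([64, 64], pool_type="min"))
```

Returning a list of messages rather than raising on the first problem lets all mismatches in a configuration be reported at once.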
- forward(x, return_attention=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
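The note above is why the model should be called as model(x) rather than model.forward(x). The toy class below is a simplified, hypothetical stand-in for torch.nn.Module, not PyTorch's actual implementation: its __call__ runs the registered forward hooks, while calling forward directly bypasses them. The hook signature (module, inputs, output) loosely follows PyTorch's forward hooks.

```python
class ToyModule:
    """Simplified stand-in for torch.nn.Module; illustration only."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks receive (module, inputs, output); a non-None return replaces the output.
        self._forward_hooks.append(hook)

    def __call__(self, *inputs):
        output = self.forward(*inputs)
        for hook in self._forward_hooks:
            result = hook(self, inputs, output)
            if result is not None:
                output = result
        return output

    def forward(self, *inputs):
        raise NotImplementedError("subclasses must override forward")


class Doubler(ToyModule):
    def forward(self, x):
        return 2 * x


log = []
model = Doubler()
# list.append returns None, so this hook records the output without replacing it.
model.register_forward_hook(lambda module, inputs, output: log.append(output))

model(3)          # returns 6 and the hook records it
model.forward(3)  # returns 6 but the hook is silently skipped
print(log)        # [6]
```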
- training: bool
- class edafm.models.EDAFMNet(device='cuda', trained_weights=None, weights_dir='./weights')
Bases: edafm.models.AttentionUNet
ED-AFM Attention U-net.
This is the model used in the ED-AFM paper for the task of predicting electrostatics from AFM images. It is a subclass of the AttentionUNet class with specific hyperparameters.
- Parameters
device – str. Device to load model onto.
trained_weights – str or None. If not None, load pretrained weights to the model. One of ‘base’, ‘single-channel’, ‘CO-Cl’, ‘Xe-Cl’, ‘constant-noise’, ‘uniform-noise’, ‘no-gradient’, or ‘matched-tips’. See README at https://github.com/SINGROUP/ED-AFM for explanations for the different options.
weights_dir – str. If trained_weights is not None, the directory into which the weights are downloaded.
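Only the option names listed above are documented for trained_weights, and weights_dir is only relevant when one of them is chosen. A check of the argument before constructing the model could look like the sketch below. PRETRAINED_WEIGHTS and check_trained_weights are hypothetical names, not part of edafm, and the download itself is not sketched because its details are not documented here.

```python
from typing import Optional

# Option names copied from the trained_weights parameter description.
PRETRAINED_WEIGHTS = (
    "base", "single-channel", "CO-Cl", "Xe-Cl",
    "constant-noise", "uniform-noise", "no-gradient", "matched-tips",
)


def check_trained_weights(trained_weights: Optional[str]) -> Optional[str]:
    """Hypothetical helper: return the option unchanged, or raise if it is unknown."""
    if trained_weights is None:
        return None  # no pretrained weights requested, so weights_dir is not needed
    if trained_weights not in PRETRAINED_WEIGHTS:
        raise ValueError(
            f"unknown trained_weights {trained_weights!r}; "
            f"expected one of {', '.join(PRETRAINED_WEIGHTS)}"
        )
    return trained_weights


print(check_trained_weights("CO-Cl"))  # CO-Cl
```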
- training: bool