Source code for edafm.layers


import torch
import torch.nn as nn
import torch.nn.functional as F

def _get_padding(kernel_size, nd):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, )*nd
    padding = []
    for i in range(nd):
        padding += [(kernel_size[i]-1) // 2]
    return tuple(padding)

class _ConvNdBlock(nn.Module):

    def __init__(self,
        in_channels,
        out_channels,
        nd,
        kernel_size=3,
        depth=2,
        padding_mode='zeros',
        res_connection=True,
        activation=None,
        last_activation=True
    ):

        assert depth >= 1

        if nd == 2:
            conv = nn.Conv2d
        elif nd == 3:
            conv = nn.Conv3d
        else:
            raise ValueError(f'Invalid convolution dimensionality {nd}.')
        
        super().__init__()
        
        self.res_connection = res_connection
        if not activation:
            self.act = nn.ReLU()
        else:
            self.act = activation

        if last_activation:
            self.acts = [self.act] * depth
        else:
            self.acts = [self.act] * (depth-1) + [self._identity]
        
        padding = _get_padding(kernel_size, nd)
        self.convs = nn.ModuleList([conv(in_channels, out_channels, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode)])
        for i in range(depth-1):
            self.convs.append(conv(out_channels, out_channels, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode))
        if res_connection and in_channels != out_channels:
            self.res_conv = conv(in_channels, out_channels, kernel_size=1)
        else:
            self.res_conv = None

    def _identity(self, x):
        return x
            
    def forward(self, x_in):
        x = x_in
        for conv, act in zip(self.convs, self.acts):
            x = act(conv(x))
        if self.res_connection:
            if self.res_conv:
                x = x + self.res_conv(x_in)
            else:
                x = x + x_in
        return x

[docs]class Conv3dBlock(_ConvNdBlock): ''' Pytorch 3D convolution block module. Arguments: in_channels: int. Number of channels entering the first convolution layer. out_channels: int. Number of output channels in each layer of the block. kernel_size: int or tuple. Size of convolution kernel. depth: int >= 1. Number of convolution layers in the block. padding_mode: str. Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'. res_connection: Boolean. Whether to use residual connection over the block (f(x) = h(x) + x). If in_channels != out_channels, a 1x1x1 convolution is applied to the res connection to make the channel numbers match. activation: torch.nn.Module. Activation function to use after every layer in block. If None, defaults to ReLU. last_activation: Bool. Whether to apply the activation after the last conv layer (before res connection). ''' def __init__(self, in_channels, out_channels, kernel_size=3, depth=2, padding_mode='zeros', res_connection=True, activation=None, last_activation=True ): super().__init__(in_channels, out_channels, 3, kernel_size, depth, padding_mode, res_connection, activation, last_activation)
[docs]class Conv2dBlock(_ConvNdBlock): ''' Pytorch 2D convolution block module. Arguments: in_channels: int. Number of channels entering the first convolution layer. out_channels: int. Number of output channels in each layer of the block. kernel_size: int or tuple. Size of convolution kernel. depth: int >= 1. Number of convolution layers in the block. padding_mode: str. Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'. res_connection: Boolean. Whether to use residual connection over the block (f(x) = h(x) + x). If in_channels != out_channels, a 1x1 convolution is applied to the res connection to make the channel numbers match. activation: torch.nn.Module. Activation function to use after every layer in block. If None, defaults to ReLU. last_activation: Bool. Whether to apply the activation after the last conv layer (before res connection). ''' def __init__(self, in_channels, out_channels, kernel_size=3, depth=2, padding_mode='zeros', res_connection=True, activation=None, last_activation=True ): super().__init__(in_channels, out_channels, 2, kernel_size, depth, padding_mode, res_connection, activation, last_activation)
[docs]class UNetAttentionConv(nn.Module): ''' Pytorch attention layer for Attention U-net model upsampling stage. Arguments: in_channels: int. Number of channels in the attended feature map. query_channels: int. Number of channels in query feature map. query_channels: int. Number of channels in hidden convolution layer before computing attention. kernel_size: int. Size of convolution kernel. padding_mode: str. Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'. conv_activation: nn.Module. Activation function to use after convolution layers attention_activation: str. Type of activation to use for attention map. 'sigmoid' or 'softmax'. upsample_mode: str. Algorithm for upsampling query feature map to the attended feature map size. For options see torch.nn.functional.interpolate. References: https://arxiv.org/abs/1804.03999 ''' def __init__(self, in_channels, query_channels, attention_channels, kernel_size, padding_mode='zeros', conv_activation=nn.ReLU(), attention_activation='softmax', upsample_mode='bilinear' ): super().__init__() if attention_activation == 'softmax': self.attention_activation = self._softmax elif attention_activation == 'sigmoid': self.attention_activation = self._sigmoid else: raise ValueError(f'Unrecognized attention map activation {attention_activation}.') padding = _get_padding(kernel_size, 2) self.x_conv = nn.Conv2d(in_channels, attention_channels, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode) self.q_conv = nn.Conv2d(query_channels, attention_channels, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode) self.a_conv = nn.Conv2d(attention_channels, 1, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode) self.softmax = nn.Softmax(dim=1) self.sigmoid = nn.Sigmoid() self.upsample_mode = upsample_mode self.conv_activation = conv_activation def _softmax(self, a): shape = a.shape return self.softmax(a.reshape(shape[0], -1)).reshape(shape) def _sigmoid(self, a): return self.sigmoid(a)
[docs] def forward(self, x, q): # Upsample query q to the size of input x and convolve q = F.interpolate(q, size=x.size()[2:], mode=self.upsample_mode, align_corners=False) q = self.conv_activation(self.q_conv(q)) # Convolve input x and sum with q a = self.conv_activation(self.x_conv(x)) a = self.conv_activation(a + q) # Get attention map and mix it with x a = self.attention_activation(self.a_conv(a)) x = a * x return x, a.squeeze(dim=1)