Source code for InnerEye.ML.models.layers.pooling_layers

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from typing import Any

import torch
import torch.nn.functional as TF

from InnerEye.ML.utils.layer_util import initialise_layer_weights


class AveragePooling(torch.nn.Module):
    """
    Global average pooling operation across all spatial dimensions (e.g. 2D and 3D image grids).

    The input is a batch of size [B, C, *spatial] with one, two or three spatial dimensions. The output keeps
    the batch and channel dimensions and collapses every spatial dimension to size 1, e.g. [B, C, 1, 1, 1]
    for a 3D input and [B, C, 1, 1] for a 2D input.
    """

    def forward(self, *input: Any, **kwargs: Any) -> Any:
        """
        :param input: The first positional argument is the batch to pool, of size [B, C, *spatial].
            Any further positional or keyword arguments are ignored.
        :return: The spatially averaged batch, of size [B, C, 1, ..., 1].
        :raises ValueError: If the batch does not have 1, 2 or 3 spatial dimensions.
        """
        item = input[0]
        kernel_size = item.shape[2:]
        # avg_pool3d alone cannot handle 2D images: its kernel must have exactly three entries.
        # Pick the pooling function that matches the number of spatial dimensions instead.
        pooling_functions = {1: TF.avg_pool1d, 2: TF.avg_pool2d, 3: TF.avg_pool3d}
        pool = pooling_functions.get(len(kernel_size))
        if pool is None:
            raise ValueError(f"AveragePooling supports 1, 2 or 3 spatial dimensions, but got an input of "
                             f"size {tuple(item.shape)}")
        # A kernel covering the whole spatial extent yields one mean value per (batch item, channel).
        return pool(item, kernel_size=kernel_size)


class MaxPooling(torch.nn.Module):
    """
    Global max pooling operation across all spatial dimensions (e.g. 2D and 3D image grids).

    The input is a batch of size [B, C, *spatial] with one, two or three spatial dimensions. The output keeps
    the batch and channel dimensions and collapses every spatial dimension to size 1, e.g. [B, C, 1, 1, 1]
    for a 3D input and [B, C, 1, 1] for a 2D input.
    """

    def forward(self, *input: Any, **kwargs: Any) -> Any:
        """
        :param input: The first positional argument is the batch to pool, of size [B, C, *spatial].
            Any further positional or keyword arguments are ignored.
        :return: The spatial maximum of the batch, of size [B, C, 1, ..., 1].
        :raises ValueError: If the batch does not have 1, 2 or 3 spatial dimensions.
        """
        item = input[0]
        kernel_size = item.shape[2:]
        # max_pool3d alone cannot handle 2D images: its kernel must have exactly three entries.
        # Pick the pooling function that matches the number of spatial dimensions instead.
        pooling_functions = {1: TF.max_pool1d, 2: TF.max_pool2d, 3: TF.max_pool3d}
        pool = pooling_functions.get(len(kernel_size))
        if pool is None:
            raise ValueError(f"MaxPooling supports 1, 2 or 3 spatial dimensions, but got an input of "
                             f"size {tuple(item.shape)}")
        # A kernel covering the whole spatial extent yields one maximum per (batch item, channel).
        return pool(item, kernel_size=kernel_size)


class MixPooling(torch.nn.Module):
    """
    Compute a learnable mixture of global average pooling and global max pooling:

        feature = a * avg_3d + (1 - a) * max_3d,  with a = sigmoid(mixing_weight), so a lies in (0, 1)

    The mixing weight is a learnable scalar parameter. It is initialised to 0, which gives a = 0.5,
    i.e. an equal mix of both pooling operations.
    """

    def __init__(self) -> None:
        super().__init__()
        # Unconstrained scalar; forward squashes it through a sigmoid to get the mixing coefficient.
        # noinspection PyArgumentList
        self.mixing_weight = torch.nn.Parameter(torch.zeros(1))  # type: ignore

    def forward(self, *input: Any, **kwargs: Any) -> Any:
        """
        :param input: batch of size [B, C, Z, X, Y] (only the first positional argument is used)
        :return: pooled batch of size [B, C, 1, 1, 1]
        """
        item = input[0]
        kernel_size = item.shape[2:]
        f_avg = TF.avg_pool3d(item, kernel_size)  # B, C, 1, 1, 1
        f_max = TF.max_pool3d(item, kernel_size)  # B, C, 1, 1, 1
        # Compute the mixing coefficient once and reuse it for both terms.
        alpha = torch.sigmoid(self.mixing_weight)
        return alpha * f_avg + (1 - alpha) * f_max


class Gated3dPoolingLayer(torch.nn.Module):
    """
    Gated pooling. Each channel volume of each image is flattened to a vector of size Z*X*Y and fed through a
    one-layer network (linear layer + sigmoid). That network is shared by all channels, and it yields one gating
    weight per (image, channel). The weight is used as the mixing proportion between average pooling and max
    pooling features, similar to what is done in MixPooling:

        feature = g * avg_3d + (1 - g) * max_3d
    """

    def __init__(self, in_features: int) -> None:
        """
        :param in_features: should be the size of the flattened volume X*Y*Z
        """
        super().__init__()
        self.in_features = in_features
        # Create layers
        self.gate = torch.nn.Sequential(
            torch.nn.Linear(in_features=in_features, out_features=1),
            torch.nn.Sigmoid())
        # Initialise the trainable parameters
        self.apply(initialise_layer_weights)

    def forward(self, *input: Any, **kwargs: Any) -> torch.Tensor:
        """
        :param input: batch of size [B, C, Z, X, Y] (only the first positional argument is used)
        :return: pooled batch of size [B, C, 1, 1, 1]
        :raises ValueError: If Z*X*Y differs from the in_features value the layer was constructed with.
        """
        item = input[0]
        batch_size, channels = item.shape[:2]
        kernel_size = item.shape[2:]
        # Check explicitly: a mismatch would otherwise surface as an obscure reshape error.
        if kernel_size.numel() != self.in_features:
            raise ValueError(f"Expected the spatial dimensions to contain {self.in_features} elements in total, "
                             f"but got an input of size {tuple(item.shape)}")
        # The same gate is applied to every channel, giving a separate weight per (image, channel).
        gating_weights = self.gate(item.reshape(batch_size, channels, self.in_features))  # B, C, 1
        f_avg = TF.avg_pool3d(item, kernel_size)  # B, C, 1, 1, 1
        f_max = TF.max_pool3d(item, kernel_size)  # B, C, 1, 1, 1
        gating_weights = gating_weights.reshape_as(f_avg)
        # noinspection PyTypeChecker
        final = gating_weights * f_avg + (1 - gating_weights) * f_max  # type: ignore
        return final


class ZAdaptive3dAvgLayer(torch.nn.Module):
    """
    3D average pooling with a learnable weighting along the Z dimension.

    Every B-scan (Z slice) is first reduced to a single value per channel by averaging over its two in-plane
    dimensions. These per-slice averages are then combined with softmax-normalised, learnable per-slice
    weights, so the weights over all B-scans sum to one.
    """

    def __init__(self, in_features: int) -> None:
        """
        :param in_features: number of B-scans, i.e. the size of the Z dimension of the input
        """
        super().__init__()
        self.in_features = in_features
        # One unnormalised weight per B-scan, stored as a column vector for the matrix product in forward.
        # noinspection PyArgumentList
        self.scan_weight = torch.nn.Parameter(torch.zeros(in_features, 1))  # type: ignore
        # Initialise the trainable parameters
        self.apply(initialise_layer_weights)

    def forward(self, *input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        :param input: batch of size [B, C, Z, X, Y] (only the first positional argument is used)
        :return: pooled batch of size [B, C, 1, 1, 1]
        """
        volume = input[0]
        # Unpacking exactly five sizes also rejects inputs of any other rank.
        n_batch, n_channels, n_slices, n_x, n_y = volume.shape
        # In-plane average, leaving the Z dimension intact: B, C, Z, 1, 1
        slice_means = torch.nn.functional.avg_pool3d(volume, [1, n_x, n_y])
        # Normalise the per-slice weights over Z: Z, 1
        slice_weights = TF.softmax(self.scan_weight, dim=0)
        per_channel = slice_means.reshape(n_batch, n_channels, n_slices)
        weighted_mean = torch.matmul(per_channel, slice_weights)  # B, C, 1
        return weighted_mean.reshape((n_batch, n_channels, 1, 1, 1))