Source code for InnerEye.ML.utils.layer_util

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from contextlib import contextmanager
from typing import Generator, Iterable, Sized, Tuple, Union

import torch
from torch.nn import init

from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import PaddingMode

IntOrTuple3 = Union[int, TupleInt3, Iterable]


def initialise_layer_weights(module: torch.nn.Module) -> None:
    """
    Re-initialises the parameters of 3D convolution and 3D batch norm layers in place.
    Torch kernel initialisations for conv and batch_norm are based on leaky_relu activation.
    This function applies Kaiming initialisation adapted for relu activation instead.
    Modules of any other type are left untouched, which makes the function suitable
    for use with `torch.nn.Module.apply`.

    * `Conv3d`: weights are drawn with `kaiming_normal_` (fan_in, relu), the bias (if any) is zeroed.
    * `BatchNorm3d`: the scale is drawn from N(1.0, 0.01) and the shift is set to 0. Layers
      created with `affine=False` have no learnable scale/shift and are skipped.

    :param module: Torch nn module
    """
    if isinstance(module, torch.nn.Conv3d):
        init.kaiming_normal_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            init.zeros_(module.bias)
    if isinstance(module, torch.nn.BatchNorm3d):
        # With affine=False both parameters are None; initialising them would raise.
        if module.weight is not None:
            init.normal_(module.weight, mean=1.0, std=0.01)
        if module.bias is not None:
            init.constant_(module.bias, val=0.0)
def get_padding_from_kernel_size(padding: PaddingMode,
                                 kernel_size: Union[Iterable[int], int],
                                 dilation: Union[Iterable[int], int] = 1,
                                 num_dimensions: int = 3) -> Tuple[int, ...]:
    """
    Computes the per-dimension padding for a convolution layer from its kernel size and dilation.

    The number of entries in the result is taken from `kernel_size` if it has a length, otherwise
    from `dilation` if that has a length, and otherwise from `num_dimensions`. Scalar arguments are
    repeated to that many entries. When both arguments are sequences, they are paired up
    element-wise and pairing stops at the shorter of the two.

    NOTE(review): the per-dimension value is `(k - 1) // 2 + d - 1` (and 0 for k == 1). This agrees
    with `d * (k - 1) // 2` when k == 3 or d == 1, but not for larger kernels combined with
    dilation > 1 - confirm that this is intended.

    :param padding: Padding type (Enum) {`zero`, `no_padding`}. Option `zero` is intended to
        preserve the tensor shape. In `no_padding` option, padding is not applied and the function
        returns only zeros.
    :param kernel_size: Spatial support of the convolution kernel. It is used to determine the
        padding size. This can be a scalar, tuple or array.
    :param dilation: Dilation of convolution kernel. It is used to determine the padding size.
        This can be a scalar, tuple or array.
    :param num_dimensions: The number of dimensions that the returned padding tuple should have,
        if both kernel_size and dilation are scalars.
    :return: padding value required for convolution layers based on input kernel size and dilation.
    """
    # The first argument that has a length dictates the dimensionality; kernel_size wins.
    for candidate in (kernel_size, dilation):
        if isinstance(candidate, Sized):
            num_dimensions = len(candidate)
            break

    # Broadcast scalar arguments to one entry per dimension.
    kernels = kernel_size if isinstance(kernel_size, Iterable) else [kernel_size] * num_dimensions
    dilations = dilation if isinstance(dilation, Iterable) else [dilation] * num_dimensions

    if padding == PaddingMode.NoPadding:
        return (0,) * num_dimensions

    per_dimension = []
    for kernel, dilate in zip(kernels, dilations):
        if kernel == 1:
            # A kernel of size 1 never needs padding, whatever the dilation.
            per_dimension.append(0)
        else:
            per_dimension.append((kernel - 1) // 2 + dilate - 1)
    return tuple(per_dimension)
def get_upsampling_kernel_size(downsampling_factor: IntOrTuple3, num_dimensions: int) -> TupleInt3:
    """
    Returns the kernel size that should be used in the transpose convolution in the decoding blocks
    of the UNet. Use a value that is a multiple of the downsampling factor to avoid checkerboard
    artefacts, see https://distill.pub/2016/deconv-checkerboard/

    Each component maps to twice the downsampling factor, except that a factor of 1 (no
    downsampling) maps to a kernel size of 1.

    :param downsampling_factor: downsampling factor use for each dimension of the kernel. Can be
        either an iterable of length num_dimensions with one factor per dimension, or an int in which
        case the same factor will be applied for all dimensions.
    :param num_dimensions: number of dimensions of the kernel
    :return: upsampling_kernel_size, a tuple with num_dimensions entries.
    :raises ValueError: If any factor is smaller than 1, or if the number of factors does not match
        num_dimensions.
    """
    if isinstance(downsampling_factor, int):
        downsampling_factor = (downsampling_factor,) * num_dimensions
    elif not isinstance(downsampling_factor, Sized):
        # One-shot iterables (e.g. generators) must be materialised, otherwise the validation
        # below would exhaust them and len() would fail.
        downsampling_factor = tuple(downsampling_factor)

    if any(down < 1 for down in downsampling_factor):
        raise ValueError(f"The downsampling_factor must be >= 1 in each component, but got: {downsampling_factor}")
    if len(downsampling_factor) != num_dimensions:  # type: ignore
        raise ValueError(f"The downsampling_factor must be an integer or an Iterable of length {num_dimensions}, "
                         f"but got: {downsampling_factor}")

    # Build one entry per requested dimension rather than hard-coding three components.
    upsampling_kernel_size = tuple(2 * down if down > 1 else 1 for down in downsampling_factor)
    return upsampling_kernel_size  # type: ignore
@contextmanager
def set_model_to_eval_mode(model: torch.nn.Module) -> Generator:
    """
    Puts the given torch model into eval mode. At the end of the context, resets the state of the
    training flag to what it was before the call. The flag is restored even if the code inside the
    context raises an exception.

    :param model: The model to modify.
    """
    old_mode = model.training
    model.eval()
    try:
        yield
    finally:
        # Restore in `finally` so that a failure inside the `with` body does not leave the model
        # stuck in eval mode.
        model.train(old_mode)