Source code for InnerEye.ML.models.architectures.base_model

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import abc
from abc import ABC
from collections import OrderedDict
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch

from InnerEye.Common.common_util import any_pairwise_larger, initialize_instance_variables
from InnerEye.Common.type_annotations import IntOrTuple3, TupleInt2, TupleInt3
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
from InnerEye.ML.visualizers.model_summary import ModelSummary, forward_preserve_state


class CropSizeConstraints:
    """
    Stores minimum size and other conditions that a training crop size must satisfy.
    """

    def __init__(self,
                 multiple_of: Optional[IntOrTuple3] = None,
                 minimum_size: Optional[IntOrTuple3] = None,
                 num_dimensions: int = 3):
        """
        Stores minimum size and other conditions that a training crop size must satisfy.

        :param multiple_of: Training crops must have a size that is a multiple of this value, along each dimension.
            For example, if set to (1, 16, 16), the crop size has to be a multiple of 16 along X and Y, and a
            multiple of 1 (i.e., any number) along the Z dimension. An integer is expanded to that value along
            every dimension.
        :param minimum_size: Training crops must have a size that is at least this value. An integer is expanded
            to that value along every dimension. If not provided, it defaults to multiple_of.
        :param num_dimensions: The number of dimensions that a crop size (and the two arguments above, when given
            as tuples) must have.
        :raises ValueError: If multiple_of or minimum_size is a sequence whose length is not num_dimensions, or if
            minimum_size is not at least as large as multiple_of along every dimension.
        """
        self.multiple_of = multiple_of
        self.minimum_size = minimum_size
        self.num_dimensions = num_dimensions

        def to_tuple(o: Optional[IntOrTuple3]) -> Optional[TupleInt3]:
            # Expands an integer to one value per dimension; checks the length of sequences.
            # "type ignore" directives below are because mypy is not clever enough
            if o is None:
                return None
            if isinstance(o, int):
                # noinspection PyTypeChecker
                return (o,) * self.num_dimensions  # type: ignore
            if len(o) != self.num_dimensions:  # type: ignore
                raise ValueError("Object must have length {}, but got: {}"
                                 .format(self.num_dimensions, o))
            return o  # type: ignore

        self.multiple_of = to_tuple(self.multiple_of)
        self.minimum_size = to_tuple(self.minimum_size)
        if self.minimum_size is None:
            self.minimum_size = self.multiple_of
        else:
            if self.multiple_of is not None and any_pairwise_larger(self.multiple_of, self.minimum_size):
                raise ValueError(f"Invalid arguments: The minimum size must be at least as large as the multiple_of. "
                                 f"minimum_size: {self.minimum_size}, multiple_of: {self.multiple_of}")

    def validate(self, crop_size: TupleInt3, message_prefix: Optional[str] = None) -> None:
        """
        Checks if the given crop size is a valid crop size for the present model.
        If it is not valid, throw a ValueError.

        :param crop_size: The crop size that should be checked.
        :param message_prefix: A string prefix for the error message if the crop size is found to be invalid.
        :raises ValueError: If crop_size has the wrong number of dimensions, is smaller than minimum_size along
            any dimension, or is not a multiple of multiple_of along every dimension.
        """
        message_prefix = message_prefix + ": " if message_prefix else ""
        if len(crop_size) != self.num_dimensions:
            raise ValueError(f"{message_prefix}Crop size must have length {self.num_dimensions}, but got: {crop_size}")
        if self.minimum_size is not None:
            # __init__ has already expanded integers to tuples; the assert only narrows the type for mypy.
            assert not isinstance(self.minimum_size, int)
            if any_pairwise_larger(self.minimum_size, crop_size):  # type: ignore
                raise ValueError(f"{message_prefix}Crop size is not valid. The required minimum is {self.minimum_size},"
                                 f" but got: {crop_size}")
        if self.multiple_of is not None:
            assert not isinstance(self.multiple_of, int)
            if any(crop % mult != 0 for (crop, mult) in zip(crop_size, self.multiple_of)):
                raise ValueError(f"{message_prefix}Crop size is not valid. Crop size is should be a multiple of "
                                 f"{self.multiple_of}, but got: {crop_size}")

    def restrict_crop_size_to_image(self,
                                    image_shape: TupleInt3,
                                    crop_size: TupleInt3,
                                    stride_size: TupleInt3) -> Tuple[TupleInt3, TupleInt3]:
        """
        Computes an adjusted crop and stride size for cases where the image is smaller than the chosen crop size
        (at test time).

        If the image is at least as large as the crop along every dimension, crop and stride are returned
        unchanged. Otherwise, along each dimension the crop is first clipped to the image size and then rounded
        UP to the next multiple of self.multiple_of. The new crop size can therefore exceed the image size by less
        than one multiple along a dimension (presumably the caller pads the image in that case - TODO confirm).
        A missing multiple_of is treated as 1, i.e. no rounding.
        The stride size keeps the stride-to-crop ratio from before adjustment (rounded down, and at least 1).

        :param image_shape: The shape of the image to process.
        :param crop_size: The present test crop size.
        :param stride_size: The present inference stride size.
        :return: A tuple of (crop_size, stride_size)
        :raises ValueError: If the image is smaller than self.minimum_size along any dimension.
        """
        shape = np.array(image_shape)
        crop = np.array(crop_size)
        stride = np.array(stride_size)
        # np.array(None) cannot take part in arithmetic or comparisons, so missing constraints are handled
        # explicitly: no multiple_of means "multiple of 1" (broadcast scalar), no minimum_size means no minimum check.
        multiple_of = np.array(self.multiple_of if self.multiple_of is not None else 1)
        if self.minimum_size is not None and np.any(shape < np.array(self.minimum_size)):
            raise ValueError("The input image must have at least a size of {}, but got: {}"
                             .format(self.minimum_size, image_shape))
        if np.all(shape >= crop):
            return crop_size, stride_size
        stride_to_crop = stride / crop
        crop_new = np.ceil(np.minimum(crop, shape) / multiple_of) * multiple_of
        stride_new = np.maximum(np.floor(stride_to_crop * crop_new), 1)

        def to_tuple(a: np.ndarray) -> TupleInt3:
            # Convert the numpy float array back to plain Python ints, one per dimension.
            return tuple(int(x) for x in a)  # type: ignore

        return to_tuple(crop_new), to_tuple(stride_new)
class BaseSegmentationModel(DeviceAwareModule, ABC):
    """
    Base neural network segmentation model.
    """

    @initialize_instance_variables
    def __init__(self,
                 name: str,
                 input_channels: int,
                 crop_size_constraints: Optional[CropSizeConstraints] = None
                 ):
        """
        Creates a new instance of the base model class.

        :param name: A human readable name of the model.
        :param input_channels: The number of image input channels.
        :param crop_size_constraints: The size constraints for the training crop size. If not provided,
            a minimum crop size of 1 is assumed.
        """
        super().__init__()
        self.num_dimensions = 3
        self.name = name
        self.input_channels = input_channels
        # State cached by generate_model_summary: the summarizer, its result, and the crop size it was built for.
        self.summarizer: Optional[ModelSummary] = None
        self.summary: Optional[OrderedDict] = None
        self.summary_crop_size: Optional[TupleInt3] = None
        if crop_size_constraints is None:
            # Allow any size. With this initialization, both multiple_of and minimum_size will be populated.
            crop_size_constraints = CropSizeConstraints(multiple_of=1)
        self.crop_size_constraints = crop_size_constraints

    def get_output_shape(self, input_shape: Union[TupleInt2, TupleInt3]) -> Tuple[int, ...]:
        """
        Computes model's output tensor shape for given input tensor shape.
        The argument is expected to be either a 2-tuple or a 3-tuple. A batch dimension (1)
        and the number of channels are added as the first dimensions. The result tuple has batch and channel
        dimension stripped off.

        :param input_shape: A tuple (2D or 3D) representing incoming tensor shape.
        :return: The spatial shape of the model output (all dimensions after batch and channel).
        :raises ValueError: If input_shape does not have 2 or 3 elements.
        """
        # Create a sample tensor for inference
        batch_size = 1
        if len(input_shape) not in [2, 3]:
            raise ValueError("Input shape has to be in 2D or 3D, found {}".format(len(input_shape)))
        input_tensors = \
            [torch.zeros(batch_size, self.input_channels, *input_shape, dtype=torch.float)]

        # Perform a forward pass then restore the state of the module
        output_shape = forward_preserve_state(module=self, inputs=input_tensors).size()
        return tuple(output_shape[2:])

    def partition_model(self, devices: Optional[List[torch.device]] = None) -> None:
        """A method to partition a neural network model across multiple devices.
        If no list of devices is given, use all available GPU devices.

        The base implementation does nothing; subclasses that support partitioning override it."""
        pass

    def validate_crop_size(self, crop_size: TupleInt3, message_prefix: Optional[str] = None) -> None:
        """
        Checks if the given crop size is a valid crop size for the present model.
        If it is not valid, throw a ValueError.

        :param crop_size: The crop size that should be checked.
        :param message_prefix: A string prefix for the error message if the crop size is found to be invalid.
        """
        if self.crop_size_constraints is not None:
            self.crop_size_constraints.validate(crop_size, message_prefix)

    def generate_model_summary(self, crop_size: Optional[TupleInt3] = None,
                               log_summaries_to_files: bool = False) -> None:
        """
        Stores a model summary, containing information about layers, memory consumption and runtime
        in the model.summary field.
        When called again with the same crop_size, the summary is not created again.

        :param crop_size: The crop size for which the summary should be created. If not provided,
            the minimum allowed crop size is used.
        :param log_summaries_to_files: whether to write the summary to a file
        """
        if crop_size is None:
            crop_size = self.crop_size_constraints.minimum_size  # type: ignore
        assert crop_size is not None
        # Normalize to a tuple so the cache check compares a tuple with a tuple. Comparing the stored crop size
        # with a list (e.g. [crop_size]) would never be equal and would rebuild the summary on every call.
        crop_size = tuple(crop_size)  # type: ignore
        if self.summary is None or self.summary_crop_size != crop_size:
            self.summarizer = ModelSummary(self)
            self.summary = self.summarizer.generate_summary(
                input_sizes=[(self.input_channels, *crop_size)],
                log_summaries_to_files=log_summaries_to_files)
            self.summary_crop_size = crop_size

    @abc.abstractmethod
    def forward(self, input: Any) -> Any:  # type: ignore
        raise NotImplementedError("forward must be implemented by subclasses")

    def get_all_child_layers(self) -> List[torch.nn.Module]:
        raise NotImplementedError("get_all_child_layers must be implemented by subclasses")