Source code for InnerEye.ML.models.architectures.complex

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from typing import Any, List, Optional

import torch

from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel, CropSizeConstraints
from InnerEye.ML.models.blocks.residual import ResidualBlock
from InnerEye.ML.models.layers.basic import BasicLayer


class ComplexModel(BaseSegmentationModel):
    """
    A general class of feed-forward convolutional neural networks that is characterised by a network definition
    (list of lists of modules). It supports residual blocks, auto-focus and atrous spatial pyramid pooling layers.
    """

    # noinspection PyTypeChecker
    def __init__(self,
                 args: SegmentationModelBase,
                 full_channels_list: List[int],
                 dilations: List[int],
                 network_definition: List[List[torch.nn.Module]],
                 crop_size_constraints: Optional[CropSizeConstraints] = None):
        """
        Creates a new instance of the class.

        :param args: The full model configuration. Only ``args.kernel_size`` is read here.
        :param full_channels_list: A vector of channel sizes. First entry is the number of image channels, then all
            feature channels, then the number of classes. Every layer in ``network_definition`` consumes one entry,
            and the final 1x1x1 convolution needs two more, so the list must have at least
            (number of layers + 2) entries.
        :param dilations: Dilation factors, consumed in order: one per ``BasicLayer`` element, and one per inner
            layer of each list element.
        :param network_definition: The ordered network layout. Each element is either the class ``BasicLayer``
            (built as a single basic layer) or a list of layers (wrapped into a ``ResidualBlock``).
            NOTE(review): the annotation says ``List[List[torch.nn.Module]]``, but top-level elements may also be
            the ``BasicLayer`` class itself.
        :param crop_size_constraints: The size constraints for the training crop size.
        :raises ValueError: If an element of ``network_definition`` is neither ``BasicLayer`` nor a list, or if
            ``full_channels_list`` has too few entries for the network definition.
        """
        # Validate up front so a mis-sized channel list fails with a clear message instead of an IndexError
        # (or a silently truncated slice) somewhere inside layer construction.
        required_channels = self._count_layers(network_definition) + 2
        if len(full_channels_list) < required_channels:
            raise ValueError(f"full_channels_list has {len(full_channels_list)} entries, but the network definition "
                             f"requires at least {required_channels}: one per layer plus two for the final "
                             f"1x1x1 convolution.")
        super().__init__(name='ComplexModel',
                         input_channels=full_channels_list[0],
                         crop_size_constraints=crop_size_constraints)
        self.full_channels_list = full_channels_list
        self.kernel_size = args.kernel_size
        self.dilations = dilations
        self._layers = torch.nn.ModuleList()
        # Cursors into full_channels_list and dilations; every layer advances both by one.
        channel_i = dilation_i = 0
        for layer in network_definition:
            if isinstance(layer, list):
                n_layers = len(layer)
                # A block of n layers is given n + 1 consecutive channel sizes and n dilations.
                model_block = ResidualBlock(layers=layer,
                                            channels=full_channels_list[channel_i:channel_i + (n_layers + 1)],
                                            kernel_size=self.kernel_size,
                                            dilations=self.dilations[dilation_i:dilation_i + n_layers])
                channel_i += n_layers
                dilation_i += n_layers
            elif layer is BasicLayer:
                model_block = BasicLayer(full_channels_list[channel_i:channel_i + 2],  # type: ignore
                                         self.kernel_size,
                                         dilation=self.dilations[dilation_i],
                                         use_bias=True)
                channel_i += 1
                dilation_i += 1
            else:
                # Unreachable after _count_layers, kept as a defensive guard.
                raise ValueError(f"Unknown layer {layer}")
            self._layers.append(model_block)
        # Final 1x1x1 convolution from the last feature channel count to full_channels_list[channel_i + 1].
        fc = torch.nn.Conv3d(full_channels_list[channel_i], full_channels_list[channel_i + 1], kernel_size=1)
        self._layers.append(fc)

    @staticmethod
    def _count_layers(network_definition: List[Any]) -> int:
        """
        Return how many channel/dilation entries the elements of ``network_definition`` consume.

        A list element consumes one entry per inner layer; a ``BasicLayer`` element consumes one.

        :raises ValueError: If an element is neither a list nor ``BasicLayer``.
        """
        count = 0
        for layer in network_definition:
            if isinstance(layer, list):
                count += len(layer)
            elif layer is BasicLayer:
                count += 1
            else:
                raise ValueError(f"Unknown layer {layer}")
        return count

    def forward(self, x: Any) -> Any:  # type: ignore
        """Run ``x`` through all layers in registration order, feeding each layer's output into the next."""
        for layer in self._layers.children():
            x = layer(x)
        return x

    def get_all_child_layers(self) -> List[torch.nn.Module]:
        """Return the model's layers (all blocks plus the final 1x1x1 convolution) in registration order."""
        return list(self._layers.children())