Source code for InnerEye.ML.models.architectures.classification.bit

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

#  ------------------------------------------------------------------------------------------
#  This is a re-implementation of the Big Transfer model described in
#  Kolesnikov, Alexander, et al. "Big transfer (BiT): General visual representation learning."
#  arXiv preprint arXiv:1912.11370 (2019).

#  The original source code (https://github.com/google-research/big_transfer) is licensed
#  under the Apache License 2.0 which can be found in the THIRDPARTYNOTICES.md file in this
#  repository.
#  ------------------------------------------------------------------------------------------

import torch

from torch import nn
from typing import Tuple, Optional

from InnerEye.ML.models.layers.weight_standardization import WeightStandardizedConv2d


[docs]class ResNetV2Block(nn.Module): """ ResNetV2 (https://arxiv.org/pdf/1603.05027.pdf) uses pre activation in the ResNet blocks. Big Transfer replaces BatchNorm with GroupNorm """ def __init__(self, in_channels: int, out_channels: int, bottleneck_channels: int, num_groups: int, downsample_stride: int = 1): super().__init__() self.gn1 = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) self.conv1 = WeightStandardizedConv2d(in_channels=in_channels, out_channels=bottleneck_channels, kernel_size=(1, 1), stride=(1, 1), bias=False) self.gn2 = nn.GroupNorm(num_groups=num_groups, num_channels=bottleneck_channels) self.conv2 = WeightStandardizedConv2d(in_channels=bottleneck_channels, out_channels=bottleneck_channels, kernel_size=(3, 3), stride=(downsample_stride, downsample_stride), padding=(1, 1), bias=False) self.gn3 = nn.GroupNorm(num_groups=num_groups, num_channels=bottleneck_channels) self.conv3 = WeightStandardizedConv2d(in_channels=bottleneck_channels, out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), bias=False) self.relu = nn.ReLU(inplace=True) if in_channels != out_channels: self.downsample: Optional[WeightStandardizedConv2d] = \ WeightStandardizedConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1), stride=(downsample_stride, downsample_stride), bias=False) else: self.downsample = None
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore input_normed = self.relu(self.gn1(x)) out = self.conv1(input_normed) out = self.conv2(self.relu(self.gn2(out))) out = self.conv3(self.relu(self.gn3(out))) # add residual residual = self.downsample(input_normed) if self.downsample else input_normed return out + residual
[docs]class ResNetV2Layer(nn.Module): """ Single layer of ResNetV2 """ def __init__(self, in_channels: int, out_channels: int, bottleneck_channels: int, num_groups: int, downsample_stride: int, num_blocks: int): super().__init__() _layers = [ResNetV2Block(in_channels=in_channels if i == 0 else out_channels, out_channels=out_channels, bottleneck_channels=bottleneck_channels, num_groups=num_groups, downsample_stride=downsample_stride if i == 0 else 1) for i in range(num_blocks)] self.layer = nn.Sequential(*_layers)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore return self.layer(x)
[docs]class BiTResNetV2(nn.Module): """ Implements the Big Transfer (BiT) model https://arxiv.org/pdf/1912.11370.pdf https://github.com/google-research/big_transfer """ def __init__(self, num_groups: int = 32, num_classes: int = 21843, num_blocks_in_layer: Tuple[int, int, int, int] = (3, 4, 23, 3), width_factor: int = 1): super().__init__() self.initial = nn.Sequential( WeightStandardizedConv2d(in_channels=3, out_channels=64 * width_factor, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False), nn.ConstantPad2d(padding=1, value=0), nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1)) self.conv_stack = nn.Sequential( ResNetV2Layer(in_channels=64 * width_factor, out_channels=256 * width_factor, bottleneck_channels=64 * width_factor, num_groups=num_groups, downsample_stride=1, num_blocks=num_blocks_in_layer[0]), ResNetV2Layer(in_channels=256 * width_factor, out_channels=512 * width_factor, bottleneck_channels=128 * width_factor, num_groups=num_groups, downsample_stride=2, num_blocks=num_blocks_in_layer[1]), ResNetV2Layer(in_channels=512 * width_factor, out_channels=1024 * width_factor, bottleneck_channels=256 * width_factor, num_groups=num_groups, downsample_stride=2, num_blocks=num_blocks_in_layer[2]), ResNetV2Layer(in_channels=1024 * width_factor, out_channels=2048 * width_factor, bottleneck_channels=512 * width_factor, num_groups=num_groups, downsample_stride=2, num_blocks=num_blocks_in_layer[3])) self.linear = nn.Sequential( nn.GroupNorm(num_groups=num_groups, num_channels=2048 * width_factor), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(output_size=1), nn.Conv2d(in_channels=2048 * width_factor, out_channels=num_classes, kernel_size=(1, 1), bias=True))
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore x = self.initial(x) x = self.conv_stack(x) x = self.linear(x) return x