# Source code for InnerEye.ML.augmentations.transform_pipeline

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from typing import Any, Callable, List, Union

import PIL
import torch

from torchvision.transforms import CenterCrop, ColorJitter, Compose, RandomAffine, RandomErasing, \
    RandomHorizontalFlip, RandomResizedCrop, Resize
from torchvision.transforms.functional import to_tensor
from yacs.config import CfgNode

from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma

ImageData = Union[PIL.Image.Image, torch.Tensor]


class ImageTransformationPipeline:
    """
    Base class for data augmentation pipelines applied to 2D or 3D image inputs (tensor or PIL.Image).

    For 3D images the transformations are applied slice by slice along the Z dimension (the same
    transformation is applied to every slice). The user can choose whether the whole image is passed through
    the pipeline in a single call (``use_different_transformation_per_channel=False``) or whether the pipeline
    is called once per channel (``use_different_transformation_per_channel=True``), in which case random
    transforms draw new parameters for each channel.
    """

    # noinspection PyMissingConstructor
    def __init__(self,
                 transforms: Union[Callable, List[Callable]],
                 use_different_transformation_per_channel: bool = False):
        """
        :param transforms: A single callable, or a list (or tuple) of callables that is wrapped in a
            torchvision ``Compose``. Built-in torchvision transforms are supported as they accept data of
            arbitrary leading dimensions. Custom transforms receive a tensor of shape [Z, C, H, W] and must
            apply the same transformation to each Z slice.
        :param use_different_transformation_per_channel: if True, call the pipeline separately for each
            channel. If False, call the pipeline once on the full image.
        """
        self.use_different_transformation_per_channel = use_different_transformation_per_channel
        # A sequence of transforms is chained with Compose; a lone callable is used as the pipeline directly.
        self.pipeline = Compose(transforms) if isinstance(transforms, (list, tuple)) else transforms

    def transform_image(self, image: ImageData) -> torch.Tensor:
        """
        Apply the transformation pipeline to a 2D image, or slice by slice to a 3D image.

        :param image: either a 3D image as a tensor of shape [C, Z, H, W], or a 2D image as a tensor of shape
            [C, H, W] or as a PIL image (converted with ``to_tensor``).
        :return: the transformed image as a tensor, laid out like the input ([C, Z, H, W] or [C, H, W]).
            The dtype is whatever the pipeline produced; no cast is applied.
        :raises ValueError: if the input, once converted to a tensor, is neither 3- nor 4-dimensional.
        """

        def _convert_to_tensor_if_necessary(data: ImageData) -> torch.Tensor:
            # Transforms may hand back PIL images; everything downstream works on tensors.
            return data if isinstance(data, torch.Tensor) else to_tensor(data)

        image = _convert_to_tensor_if_necessary(image)
        original_input_is_2d = image.dim() == 3
        if original_input_is_2d:
            # A 2D image [C, H, W] becomes [Z=1, C, H, W]; built-in torchvision transforms accept 4D inputs.
            image = image.unsqueeze(0)
        else:
            if image.dim() != 4:
                raise ValueError(f"ScalarDataset should load images as 4D tensor [C, Z, H, W]. The input tensor here "
                                 f"was of shape {image.shape}. This is unexpected.")
            # Some transforms assume the dimension order [..., C, H, W], so swap the first two dimensions
            # to go from [C, Z, H, W] to [Z, C, H, W].
            image = torch.transpose(image, 1, 0)

        if not self.use_different_transformation_per_channel:
            image = _convert_to_tensor_if_necessary(self.pipeline(image))
        else:
            # One pipeline call per channel, each on a [Z, 1, H, W] view, re-assembled along the channel axis.
            channels = [
                _convert_to_tensor_if_necessary(self.pipeline(image[:, channel, :, :].unsqueeze(1)))
                for channel in range(image.shape[1])
            ]
            image = torch.cat(channels, dim=1)

        # Back to [C, Z, H, W]
        image = torch.transpose(image, 1, 0)
        if original_input_is_2d:
            # Drop the Z axis that was added above.
            image = image.squeeze(1)
        # NOTE(review): this used to end with `image.to(dtype=image.dtype)`, a cast to the tensor's own dtype
        # that has no effect. If the intent was to restore the dtype of the input, that needs to be decided
        # explicitly, as it would change the output for non-float inputs.
        return image

    def __call__(self, data: ImageData) -> torch.Tensor:
        """
        Make the pipeline usable as a transform: forwards to :meth:`transform_image`.
        """
        return self.transform_image(data)
def create_transforms_from_config(config: CfgNode,
                                  apply_augmentations: bool,
                                  expand_channels: bool = True) -> ImageTransformationPipeline:
    """
    Defines the image transformations pipeline from a config file. It has been designed for Chest X-Ray
    images but it can be used for other types of images data, type of augmentations to use and strength are
    expected to be defined in the config. The channel expansion is needed for gray images.

    With augmentations enabled, the transforms are added in this order (each optional one only if its
    ``config.augmentation.use_*`` / ``add_*`` flag is set): random affine, random resized crop (or a plain
    resize when random crop is disabled), random horizontal flip, random gamma, color jitter, elastic
    transform, center crop (always), random erasing, Gaussian noise.

    :param config: config yaml file fixing strength and type of augmentation to apply
    :param apply_augmentations: if True return transformation pipeline with augmentations. Else,
        disable augmentations i.e. only resize and center crop the image.
    :param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
        will be added to the transformation passed through the config. This is needed for single channel images as CXR.
    :return: an ImageTransformationPipeline wrapping the list of selected transforms.
    """
    transforms: List[Any] = []
    if expand_channels:
        transforms.append(ExpandChannels())

    if not apply_augmentations:
        # No augmentation: deterministic resize followed by a center crop only.
        transforms += [Resize(size=config.preprocess.resize),
                       CenterCrop(config.preprocess.center_crop_size)]
        return ImageTransformationPipeline(transforms)

    augmentation = config.augmentation
    if augmentation.use_random_affine:
        transforms.append(RandomAffine(
            degrees=augmentation.random_affine.max_angle,
            translate=(augmentation.random_affine.max_horizontal_shift,
                       augmentation.random_affine.max_vertical_shift),
            shear=augmentation.random_affine.max_shear
        ))
    if augmentation.use_random_crop:
        transforms.append(RandomResizedCrop(
            scale=augmentation.random_crop.scale,
            size=config.preprocess.resize
        ))
    else:
        # The image is always brought to the configured size, with or without random cropping.
        transforms.append(Resize(size=config.preprocess.resize))
    if augmentation.use_random_horizontal_flip:
        transforms.append(RandomHorizontalFlip(p=augmentation.random_horizontal_flip.prob))
    if augmentation.use_gamma_transform:
        transforms.append(RandomGamma(scale=augmentation.gamma.scale))
    if augmentation.use_random_color:
        transforms.append(ColorJitter(
            brightness=augmentation.random_color.brightness,
            contrast=augmentation.random_color.contrast,
            saturation=augmentation.random_color.saturation
        ))
    if augmentation.use_elastic_transform:
        transforms.append(ElasticTransform(
            alpha=augmentation.elastic_transform.alpha,
            sigma=augmentation.elastic_transform.sigma,
            p_apply=augmentation.elastic_transform.p_apply
        ))
    # The center crop is unconditional; random erasing and Gaussian noise are added after it.
    transforms.append(CenterCrop(config.preprocess.center_crop_size))
    if augmentation.use_random_erasing:
        transforms.append(RandomErasing(
            scale=augmentation.random_erasing.scale,
            ratio=augmentation.random_erasing.ratio
        ))
    if augmentation.add_gaussian_noise:
        transforms.append(AddGaussianNoise(
            p_apply=augmentation.gaussian_noise.p_apply,
            std=augmentation.gaussian_noise.std
        ))
    return ImageTransformationPipeline(transforms)