Source code for InnerEye.ML.dataset.cropping_dataset

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import random_crop
from InnerEye.Common.common_util import any_pairwise_larger
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import PaddingMode, SegmentationModelBase
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
from InnerEye.ML.dataset.sample import CroppedSample, Sample
from InnerEye.ML.utils import image_util
from InnerEye.ML.utils.image_util import pad_images
from InnerEye.ML.utils.io_util import ImageDataType
from InnerEye.ML.utils.transforms import Compose3D


[docs]class CroppingDataset(FullImageDataset): """ Dataset class that creates random cropped samples from full 3D images from a given pd.DataFrame. The following are the operations performed to generate a sample from this dataset. The crops extracted are of size crop_size which is defined in the model config, and the crop center class population is distributed as per the class_weights vector in the model config (which by default weights all classes equally) """ def __init__(self, args: SegmentationModelBase, data_frame: pd.DataFrame, cropped_sample_transforms: Optional[Compose3D[CroppedSample]] = None, full_image_sample_transforms: Optional[Compose3D[Sample]] = None): super().__init__(args, data_frame, full_image_sample_transforms) self.cropped_sample_transforms = cropped_sample_transforms
[docs] def __getitem__(self, i: int) -> Dict[str, Any]: sample = CroppingDataset.create_possibly_padded_sample_for_cropping( sample=super().get_samples_at_index(index=i)[0], crop_size=self.args.crop_size, padding_mode=self.args.padding_mode ) sample = self.create_random_cropped_sample( sample=sample, crop_size=self.args.crop_size, center_size=self.args.center_size, class_weights=self.args.class_weights ) return Compose3D.apply(self.cropped_sample_transforms, sample).get_dict()
[docs] @staticmethod def create_possibly_padded_sample_for_cropping(sample: Sample, crop_size: TupleInt3, padding_mode: PaddingMode) -> Sample: """ Pad the original sample such the the provided images has the same (or slightly larger in case of uneven difference) shape to the output_size, using the provided padding mode. :param sample: Sample to pad. :param crop_size: Crop size to match. :param padding_mode: The padding scheme to apply. :return: padded sample """ image_spatial_shape = sample.image.shape[-3:] if any_pairwise_larger(crop_size, image_spatial_shape): if padding_mode == PaddingMode.NoPadding: raise ValueError( "The crop_size across each dimension should be greater than zero and less than or equal " f"to the current value (crop_size: {crop_size}, spatial shape: {image_spatial_shape}) " "or padding_mode must be set to enable padding") else: sample = sample.clone_with_overrides( image=pad_images(sample.image, crop_size, padding_mode), mask=pad_images(sample.mask, crop_size, padding_mode), labels=pad_images(sample.labels, crop_size, padding_mode) ) logging.debug(f"Padded sample for patient: {sample.patient_id}, from spatial dimensions: " f"{image_spatial_shape}, to: {sample.image.shape[-3:]}") return sample
[docs] @staticmethod def create_random_cropped_sample(sample: Sample, crop_size: TupleInt3, center_size: TupleInt3, class_weights: Optional[List[float]] = None) -> CroppedSample: """ Creates an instance of a cropped sample extracted from full 3D images. :param sample: the full size 3D sample to use for extracting a cropped sample. :param crop_size: the size of the crop to extract. :param center_size: the size of the center of the crop (this should be the same as the spatial dimensions of the posteriors that the model produces) :param class_weights: the distribution to use for the crop center class. :return: CroppedSample """ # crop the original raw sample sample, center_point = random_crop( sample=sample, crop_size=crop_size, class_weights=class_weights ) # crop the mask and label centers if required if center_size == crop_size: mask_center_crop = sample.mask labels_center_crop = sample.labels else: mask_center_crop = image_util.get_center_crop(image=sample.mask, crop_shape=center_size) labels_center_crop = np.zeros(shape=[len(sample.labels)] + list(center_size), # type: ignore dtype=ImageDataType.SEGMENTATION.value) for c in range(len(sample.labels)): # type: ignore labels_center_crop[c] = image_util.get_center_crop( image=sample.labels[c], crop_shape=center_size ) return CroppedSample( image=sample.image, mask=sample.mask, labels=sample.labels, mask_center_crop=mask_center_crop, labels_center_crop=labels_center_crop, center_indices=center_point, metadata=sample.metadata )