# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
from typing import List, Tuple
import numpy as np
from InnerEye.Common.common_util import any_pairwise_larger
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.dataset.sample import Sample
def random_select_patch_center(sample: Sample, class_weights: List[float] = None) -> np.ndarray:
    """
    Samples a point to use as the coordinates of the patch center. First samples one
    class among the available classes, then samples a center point among the voxels of the sampled
    class that are also foreground in the mask. A class that has no such voxel is discarded and
    another class is drawn from the remaining ones.

    :param sample: A set of Image channels, ground truth labels and mask to randomly crop.
    :param class_weights: A weighting vector with values [0, 1] to influence the class the center crop voxel belongs
        to (must sum to 1), uniform distribution assumed if none provided.
    :return: numpy int array containing the patch center coordinates, one entry per dimension of the mask.
    :raises ValueError: If class weights are given and, after discarding classes without usable voxels,
        the weights of the remaining classes do not sum to a positive value.
    :raises Exception: If the number of weights differs from the number of classes, or if no class has a
        voxel that is foreground in both the label map and the mask.
    """
    num_classes = sample.labels.shape[0]

    if class_weights is not None:
        if len(class_weights) != num_classes:
            raise Exception("A weight must be provided for each class, found weights:{}, expected:{}"
                            .format(len(class_weights), num_classes))
        SegmentationModelBase.validate_class_weights(class_weights)

    # If class weights are not provided, random.choices selects with equal probability among the classes
    available_classes = list(range(num_classes))
    original_class_weights = class_weights
    while available_classes:
        selected_label_class = random.choices(population=available_classes, weights=class_weights, k=1)[0]
        # Coordinates of all voxels where mask and label map of the selected class are both foreground
        indices = np.argwhere(np.logical_and(sample.labels[selected_label_class] == 1.0, sample.mask == 1))
        # Check the number of hits, not the coordinate values: np.any(indices) is False when the only
        # hit is the voxel at the origin (all coordinates zero), which would wrongly discard the class.
        if len(indices) > 0:
            break
        available_classes.remove(selected_label_class)
        if original_class_weights is not None:
            # Restrict the weights to the classes that can still be drawn
            class_weights = [original_class_weights[i] for i in available_classes]
            if sum(class_weights) <= 0.0:
                raise ValueError("Cannot sample a class: no class present in the sample has a positive weight")

    # Raise an exception if none of the foreground classes are overlapping with the mask
    if not available_classes:
        raise Exception("No non-mask voxels found, please check your mask and labels map")

    # noinspection PyUnboundLocalVariable
    choice = random.randint(0, len(indices) - 1)

    return indices[choice].astype(int)
def slicers_for_random_crop(sample: Sample,
                            crop_size: TupleInt3,
                            class_weights: List[float] = None) -> Tuple[List[slice], np.ndarray]:
    """
    Builds the array slicers for a random crop of the given crop_size.
    The crop is placed around a center voxel drawn by random_select_patch_center (the class_weights are
    passed through to it). If the crop would extend beyond an image boundary, it is shifted back so that
    it lies inside the image.

    :param sample: A set of Image channels, ground truth labels and mask to randomly crop.
    :param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
    :param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
        voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
    :return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
        indices of the center point that was sampled (it is returned as sampled, without any shifting).
    :raises ValueError: If any_pairwise_larger reports that the crop size exceeds the spatial shape of the image.
    """
    # The first image dimension is skipped, only the remaining (spatial) dimensions are cropped
    spatial_shape = sample.image.shape[1:]

    if any_pairwise_larger(crop_size, spatial_shape):
        raise ValueError("The crop_size across each dimension should be greater than zero and less than or equal "
                         "to the current value (crop_size: {}, spatial shape: {})"
                         .format(crop_size, spatial_shape))

    # Sample a center voxel location for patch extraction.
    patch_center = random_select_patch_center(sample, class_weights)

    slicers: List[slice] = []
    for axis in range(3):
        extent = crop_size[axis]
        # Place the crop such that the center has extent/2 (rounded down) voxels to its left
        start = patch_center[axis] - int(extent / 2)
        # Number of voxels by which the crop would stick out over the upper image boundary
        overshoot = start + extent - spatial_shape[axis]
        if overshoot > 0:
            start = start - overshoot
        # Crops that stick out over the lower boundary start at the first voxel instead
        if start < 0:
            start = 0
        slicers.append(slice(start, start + extent))

    return slicers, patch_center
def random_crop(sample: Sample,
                crop_size: TupleInt3,
                class_weights: List[float] = None) -> Tuple[Sample, np.ndarray]:
    """
    Randomly crops the image, labels, and mask arrays of a sample according to the crop_size argument.
    The crop location is determined by slicers_for_random_crop, which receives the class_weights.
    The metadata of the sample is carried over unchanged.

    :param sample: A set of Image channels, ground truth labels and mask to randomly crop.
    :param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
    :param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
        voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
    :return: Tuple item 1: The cropped images, labels, and mask, as a new Sample. Tuple item 2: The center that
        was chosen for the crop, as returned by slicers_for_random_crop.
    :raises ValueError: Propagated from slicers_for_random_crop, for example if the crop size is larger than
        the image.
    """
    slicers, center = slicers_for_random_crop(sample, crop_size, class_weights)

    # Mask has spatial dimensions only; image and labels have one additional leading dimension that is kept whole
    spatial_index = (slicers[0], slicers[1], slicers[2])
    with_leading_dim = (slice(None),) + spatial_index

    cropped = Sample(
        image=sample.image[with_leading_dim],
        labels=sample.labels[with_leading_dim],
        mask=sample.mask[spatial_index],
        metadata=sample.metadata
    )
    return cropped, center