Source code for InnerEye.ML.visualizers.patch_sampling

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import param

from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import slicers_for_random_crop
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
from InnerEye.ML.dataset.sample import Sample
from InnerEye.ML.plotting import resize_and_save, scan_with_transparent_overlay
from InnerEye.ML.utils import io_util
# PATCH_SAMPLING_FOLDER (defined below) is the name of the folder inside the default outputs folder that will hold
# plots that show the effect of sampling random patches.
from InnerEye.ML.utils.image_util import get_unit_image_header

PATCH_SAMPLING_FOLDER = "patch_sampling"


class CheckPatchSamplingConfig(GenericConfig):
    """
    Config class to store settings for the patch sampling visualization script.
    """
    model_name: str = param.String("Lung", doc="InnerEye model name e.g. Lung")
    # The default is None, so the annotation is Optional.
    local_dataset: Optional[str] = param.String(None, doc="Path to the local dataset (e.g. dataset folder name)")
    output_folder: Path = param.ClassSelector(class_=Path, default=Path("patch_sampling_visualisations"),
                                              doc="Output folder where heatmaps and sampled images are saved")
    # A count of images: declared as an Integer parameter (matching the `int` annotation), so that
    # fractional values such as 2.5 are rejected instead of silently accepted.
    number_samples: int = param.Integer(10, bounds=(1, None), doc="Number of images sampled")
def visualize_random_crops(sample: Sample, config: SegmentationModelBase, output_folder: Path) -> np.ndarray:
    """
    Simulate the effect of sampling random crops (as is done for training segmentation models), and store the
    results as a Nifti heatmap and as 3 axial/sagittal/coronal slices. The heatmap and the slices are stored in
    the given output folder, with filenames that contain the patient ID as the prefix.
    For images with a degenerate first dimension (effectively 2D), no Nifti files are written and only a single
    thumbnail is produced.

    :param sample: The patient information from the dataset, with scans and ground truth labels.
    :param config: The model configuration.
    :param output_folder: The folder into which the heatmap and thumbnails should be written. It is created if
        it does not exist.
    :return: A numpy array of dtype uint16 that has the same size as the (possibly padded) image, containing how
        often each voxel was contained in a randomly sampled crop.
    """
    output_folder.mkdir(exist_ok=True, parents=True)
    sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
        sample=sample,
        crop_size=config.crop_size,
        padding_mode=config.padding_mode)
    logging.info(f"Processing sample: {sample.patient_id}")
    # Repeatedly sample random crops and count how often each voxel is covered by a crop.
    image_channel0 = sample.image[0]
    heatmap = np.zeros(image_channel0.shape, dtype=np.uint16)
    # Number of repeats should fit into the range of UInt16, because we will later save the heatmap as an integer
    # Nifti file of that datatype.
    repeats = 200
    for _ in range(repeats):
        slicers, _ = slicers_for_random_crop(sample=sample,
                                             crop_size=config.crop_size,
                                             class_weights=config.class_weights)
        heatmap[slicers[0], slicers[1], slicers[2]] += 1
    is_3dim = heatmap.shape[0] > 1
    header = sample.metadata.image_header
    if not header:
        logging.warning(f"No image header found for patient {sample.patient_id}. Using default header.")
        header = get_unit_image_header()
    if is_3dim:
        ct_output_name = str(output_folder / f"{sample.patient_id}_ct.nii.gz")
        heatmap_output_name = str(output_folder / f"{sample.patient_id}_sampled_patches.nii.gz")
        io_util.store_as_nifti(image=heatmap,
                               header=header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=image_channel0,
                               header=header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)
    # np.float was deprecated in NumPy 1.20 and removed in 1.24; np.float64 is the type it aliased.
    heatmap_scaled = heatmap.astype(dtype=np.float64) / heatmap.max()
    # If the incoming image is effectively a 2D image with degenerate Z dimension, then only plot a single
    # axial thumbnail. Otherwise, plot thumbnails for all 3 dimensions.
    dimensions = list(range(3)) if is_3dim else [0]
    # Center the 3 thumbnails at one of the points where the heatmap attains a maximum. This should ensure that
    # the thumbnails are in an area where many of the organs of interest are located.
    max_heatmap_index = np.unravel_index(heatmap.argmax(), heatmap.shape) if is_3dim else (0, 0, 0)
    for dimension in dimensions:
        plt.clf()
        scan_with_transparent_overlay(scan=image_channel0,
                                      overlay=heatmap_scaled,
                                      dimension=dimension,
                                      position=max_heatmap_index[dimension],
                                      spacing=header.spacing)
        # Construct a filename that has a dimension suffix if we are generating 3 of them. For 2dim images, skip
        # the suffix.
        thumbnail = f"{sample.patient_id}_sampled_patches"
        if is_3dim:
            thumbnail += f"_dim{dimension}"
        thumbnail += ".png"
        resize_and_save(width_inch=5, height_inch=5, filename=output_folder / thumbnail)
    return heatmap
def visualize_random_crops_for_dataset(config: SegmentationModelBase,
                                       output_folder: Optional[Path] = None) -> None:
    """
    For segmentation models only: visualize the effect of sampling random patches for training.

    The first `config.show_patch_sampling` samples of the training split (or all of them, if the training split
    is smaller) are each passed to `visualize_random_crops`, which writes its visualizations into the output
    folder.

    :param config: The model configuration.
    :param output_folder: The folder in which the visualizations should be written. When not provided, the
        subfolder PATCH_SAMPLING_FOLDER of the model's default outputs folder is used.
    """
    splits = config.get_dataset_splits()
    # Samples are read via the full image data loader.
    train_images = FullImageDataset(config, splits.train)
    if not output_folder:
        output_folder = config.outputs_folder / PATCH_SAMPLING_FOLDER
    num_to_visualize = min(config.show_patch_sampling, len(train_images))
    for index in range(num_to_visualize):
        samples_at_index = train_images.get_samples_at_index(index=index)
        visualize_random_crops(samples_at_index[0], config, output_folder=output_folder)