# Source code for InnerEye.ML.surface_distance_heatmaps

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import numpy as np

from InnerEye.Azure.azure_config import AzureConfig
from InnerEye.Azure.azure_runner import create_runner_parser, parse_args_and_add_yaml_variables
from InnerEye.Azure.azure_util import download_run_outputs_by_prefix
from InnerEye.Common.metrics_constants import MetricsFileColumns
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.plotting import segmentation_and_groundtruth_plot, surface_distance_ground_truth_plot
from InnerEye.ML.utils import surface_distance_utils as sd_util
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.utils.csv_util import get_worst_performing_outliers, load_csv
from InnerEye.ML.utils.image_util import multi_label_array_to_binary
from InnerEye.ML.utils.io_util import load_nifti_image
from InnerEye.ML.utils.surface_distance_utils import SurfaceDistanceConfig, SurfaceDistanceRunType


@dataclass(frozen=True)
class Segmentation:
    """
    A single structure segmentation, either a model prediction or a human annotation.

    Every segmentation is identified by the path of its image file, the name of the structure it
    delineates (i.e. body part) and the id of the subject it belongs to. Optionally, it may also carry
    the name of the annotator that produced it and a Dice score calculated against the ground truth.
    Instances are immutable (frozen dataclass).
    """
    # Location of the segmentation image on disk.
    segmentation_path: Path
    # Name of the segmented structure.
    structure_name: str
    # Identifier of the subject the segmentation belongs to.
    subject_id: int
    # Name of the annotator (or model) that produced the segmentation, if known.
    annotator: Optional[str] = None
    # Dice score compared to ground truth, if it has been calculated.
    dice_score: Optional[float] = None
def load_predictions(run_type: SurfaceDistanceRunType, azure_config: AzureConfig,
                     model_config: SegmentationModelBase, execution_mode: ModelExecutionMode,
                     extended_annotators: List[str], outlier_range: float
                     ) -> List[Segmentation]:
    """
    For the given run type (IOV or outliers), instantiate a list of predicted Segmentations and return it.

    In OUTLIERS mode, the worst performing (subject, structure) pairs are read from the metrics CSV of the
    run, the outputs for each such subject are downloaded, and one Segmentation per pair is created.
    In IOV mode, one Segmentation is created for every (annotator, structure) pair for which a file exists
    in the local IOV folder. Segmentations whose file is missing are skipped with a warning in both modes.
    Any other run type yields an empty list.

    :param run_type: either SurfaceDistanceRunType.IOV or SurfaceDistanceRunType.OUTLIERS
    :param azure_config: AzureConfig used to locate the run and its outputs
    :param model_config: The segmentation model configuration
    :param execution_mode: ModelExecutionMode: Either Test, Train or Val
    :param extended_annotators: List of annotators plus model_name to load segmentations for
    :param outlier_range: The standard deviation from the mean which the points have to be below
        to be considered an outlier.
    :return: list of Segmentation objects (dice_score is set in OUTLIERS mode, annotator in IOV mode)
    """
    predictions: List[Segmentation] = []
    if run_type == SurfaceDistanceRunType.OUTLIERS:
        first_child_run = sd_util.get_first_child_run(azure_config)
        output_dir = sd_util.get_run_output_dir(azure_config, model_config)
        metrics_path = sd_util.get_metrics_path(azure_config, model_config)

        # Load the downloaded metrics CSV as dataframe and determine worst performing outliers
        # for the requested execution mode
        df = load_csv(metrics_path, [MetricsFileColumns.Patient.value, MetricsFileColumns.Structure.value])
        test_run_df = df[df['mode'] == execution_mode.value]
        # NOTE(review): max_n_outliers is negative here - confirm how get_worst_performing_outliers
        # interprets a negative count before changing it.
        worst_performers = get_worst_performing_outliers(test_run_df, outlier_range,
                                                         MetricsFileColumns.Dice.value,
                                                         max_n_outliers=-50)
        for (subject_id, structure_name, dice_score, _) in worst_performers:
            subject_prefix = sd_util.get_subject_prefix(model_config, execution_mode, subject_id)
            # if not already present, download data for subject
            download_run_outputs_by_prefix(
                blobs_prefix=subject_prefix,
                destination=output_dir,
                run=first_child_run
            )

            # check it has been downloaded: skip (as in IOV mode) rather than failing later when loading
            segmentation_path = output_dir / subject_prefix / f"{structure_name}.nii.gz"
            if not segmentation_path.is_file():
                logging.warning(f"No such file {segmentation_path}")
                continue
            predictions.append(Segmentation(structure_name=structure_name, subject_id=subject_id,
                                            segmentation_path=segmentation_path,
                                            dice_score=float(dice_score)))

    elif run_type == SurfaceDistanceRunType.IOV:
        # All IOV segmentations are stored under the same placeholder subject id
        subject_id = 0
        iov_dir = Path("outputs") / SurfaceDistanceRunType.IOV.value.lower()
        all_structs = model_config.class_and_index_with_background()
        structs_to_plot = [struct_name for struct_name in all_structs
                           if struct_name not in ('background', 'external')]
        for annotator in extended_annotators:
            for struct_name in structs_to_plot:
                # File names are the structure name directly followed by the annotator name
                segmentation_path = iov_dir / f"{struct_name + annotator}.nii.gz"
                if not segmentation_path.is_file():
                    logging.warning(f"No such file {segmentation_path}")
                    continue
                predictions.append(Segmentation(structure_name=struct_name, subject_id=subject_id,
                                                segmentation_path=segmentation_path, annotator=annotator))
    return predictions
def main() -> None:
    """
    Command line entry point: plot segmentation-vs-ground-truth images and surface distance heatmaps.

    Parses the runner arguments and the SurfaceDistanceConfig, loads the predicted segmentations for the
    configured run mode (IOV or OUTLIERS), and for each of them plots the prediction against the ground
    truth and the surface distances. In IOV mode, an additional plot over all structures is produced for
    each annotator. If no predictions could be loaded, a warning is logged and nothing is plotted.

    :raises ValueError: if no model name was given, or if the run mode is neither IOV nor OUTLIERS.
    """
    parser = create_runner_parser(SegmentationModelBase)
    parser_result = parse_args_and_add_yaml_variables(parser, fail_on_unknown_args=True)
    surface_distance_config = SurfaceDistanceConfig.parse_args()

    azure_config = AzureConfig(**parser_result.args)
    config_model = azure_config.model
    if config_model is None:
        raise ValueError("The name of the model to train must be given in the --model argument.")

    model_config = ModelConfigLoader().create_model_config_from_name(config_model)
    model_config.apply_overrides(parser_result.overrides, should_validate=True)
    execution_mode = surface_distance_config.execution_mode

    run_mode = surface_distance_config.run_mode
    if run_mode == SurfaceDistanceRunType.IOV:
        ct_path = Path("outputs") / SurfaceDistanceRunType.IOV.value.lower() / "ct.nii.gz"
        ct = load_nifti_image(ct_path).image
    else:
        # No CT image is loaded outside of IOV mode; the plotting function receives None
        ct = None
    annotators = [annotator.strip() for annotator in surface_distance_config.annotators]
    extended_annotators = annotators + [surface_distance_config.model_name]

    outlier_range = surface_distance_config.outlier_range
    predictions = load_predictions(run_mode, azure_config, model_config, execution_mode, extended_annotators,
                                   outlier_range)
    if not predictions:
        # Image shape and spacing are taken from the first segmentation, so there is nothing to do
        # (and nothing to index into) when no segmentation was found.
        logging.warning(f"No predictions found for run mode {run_mode}: no plots will be created.")
        return

    segmentations = [load_nifti_image(Path(pred_seg.segmentation_path)) for pred_seg in predictions]
    img_shape = segmentations[0].image.shape
    # transpose spacing to match image which is transposed in io_util
    voxel_spacing = segmentations[0].header.spacing[::-1]

    overall_gold_standard = np.zeros(img_shape)
    sds_for_annotator = sd_util.initialise_surface_distance_dictionary(extended_annotators, img_shape)

    plane = surface_distance_config.plane
    output_img_dir = Path(surface_distance_config.output_img_dir)

    subject_id: Optional[int] = None
    for prediction, pred_seg_w_header in zip(predictions, segmentations):
        subject_id = prediction.subject_id
        structure_name = prediction.structure_name
        annotator = prediction.annotator
        pred_segmentation = pred_seg_w_header.image

        if run_mode == SurfaceDistanceRunType.OUTLIERS:
            try:
                ground_truth = sd_util.load_ground_truth_from_run(model_config, surface_distance_config,
                                                                  subject_id, structure_name)
            except FileNotFoundError as e:
                # Without ground truth there is nothing to compare against: skip this prediction
                logging.warning(e)
                continue
        elif run_mode == SurfaceDistanceRunType.IOV:
            ground_truth = sd_util.get_annotations_and_majority_vote(model_config, annotators, structure_name)
        else:
            raise ValueError(f'Unrecognised run mode: {run_mode}. Expected either IOV or OUTLIERS')

        binary_prediction_mask = multi_label_array_to_binary(pred_segmentation, 2)[1]
        # For comparison, plot gold standard vs predicted segmentation
        segmentation_and_groundtruth_plot(binary_prediction_mask, ground_truth, subject_id, structure_name,
                                          plane, output_img_dir, annotator=annotator)

        if run_mode == SurfaceDistanceRunType.IOV:
            overall_gold_standard += ground_truth

        # Calculate and plot surface distance
        sds_full = sd_util.calculate_surface_distances(ground_truth, binary_prediction_mask, list(voxel_spacing))
        surface_distance_ground_truth_plot(ct, ground_truth, sds_full, subject_id, structure_name, plane,
                                           output_img_dir, annotator=annotator)

        if annotator is not None:
            sds_for_annotator[annotator] += sds_full

    # Plot all structures SDs for each annotator
    if run_mode == SurfaceDistanceRunType.IOV and subject_id is not None:
        # The accumulated gold standard is the same for every annotator, so binarise it only once
        num_classes = int(np.amax(overall_gold_standard))
        binarised_gold_standard = multi_label_array_to_binary(overall_gold_standard, num_classes)[1:].sum(axis=0)
        for annotator_name, sds in sds_for_annotator.items():
            surface_distance_ground_truth_plot(ct, binarised_gold_standard, sds, subject_id, 'All', plane,
                                               output_img_dir, annotator=annotator_name)
# Run the heatmap generation only when this module is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()