# Source code for InnerEye.ML.configs.segmentation.HeadAndNeckPaper

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import logging
import random
from typing import Any, Callable, Dict, List, Optional

from InnerEye.ML.config import SliceExclusionRule, SummedProbabilityRule, equally_weighted_classes
from InnerEye.ML.configs.segmentation.HeadAndNeckBase import HeadAndNeckBase
from InnerEye.ML.utils.model_metadata_util import generate_random_colours_list

# Structures the model can segment. The ordering is significant: the HeadAndNeckPaper constructor takes
# a prefix of this list whose length is given by its num_structures argument.
STRUCTURE_LIST = [
    "external",
    "parotid_l", "parotid_r",
    "smg_l", "smg_r",
    "spinal_cord",
    "brainstem",
    "globe_l", "globe_r",
    "mandible",
    "spc_muscle", "mpc_muscle",
    "cochlea_l", "cochlea_r",
    "lens_l", "lens_r",
    "optic_chiasm",
    "optic_nerve_l", "optic_nerve_r",
    "pituitary_gland",
    "lacrimal_gland_l", "lacrimal_gland_r",
]
# Random generator with a fixed seed (0), presumably so the default colours are reproducible across runs.
RANDOM_COLOUR_GENERATOR = random.Random(0)
# Default colours, generated for all len(STRUCTURE_LIST) structures; the constructor slices a prefix of this.
COLOURS = generate_random_colours_list(RANDOM_COLOUR_GENERATOR, len(STRUCTURE_LIST))


def _pop_or_default(kwargs: Dict[str, Any], name: str, make_default: Callable[[], Any]) -> Any:
    """
    Removes and returns kwargs[name] if the caller supplied it. Otherwise logs that the default is used
    and returns make_default(). The default is built lazily, so it is only computed when actually needed.

    :param kwargs: Keyword-argument dictionary; it is mutated (the entry is popped) when name is present.
    :param name: The keyword argument to look up.
    :param make_default: Zero-argument factory producing the default value.
    :return: The caller-supplied value, or the freshly built default.
    """
    if name in kwargs:
        return kwargs.pop(name)
    logging.info('Using default %s', name)
    return make_default()


def _default_slice_exclusion_rules(ground_truth_ids: List[str]) -> List[SliceExclusionRule]:
    """
    Builds the default slice exclusion rules for the given structures. A rule is only added when both
    structures it refers to are being predicted.

    In case of vertical overlap between brainstem and spinal_cord, we separate them by converting brainstem
    voxels to cord, as the latter is clinically more sensitive. We do the same to separate SPC and MPC; in
    this case, the direction of change is unimportant, so we choose SPC-to-MPC arbitrarily.
    """
    present = set(ground_truth_ids)
    rules: List[SliceExclusionRule] = []
    if {"brainstem", "spinal_cord"} <= present:
        rules.append(SliceExclusionRule("brainstem", "spinal_cord", False))
    if {"spc_muscle", "mpc_muscle"} <= present:
        rules.append(SliceExclusionRule("spc_muscle", "mpc_muscle", False))
    if {"optic_chiasm", "pituitary_gland"} <= present:
        rules.append(SliceExclusionRule("optic_chiasm", "pituitary_gland", True))
    return rules


def _default_summed_probability_rules(ground_truth_ids: List[str]) -> List[SummedProbabilityRule]:
    """
    Builds the default summed-probability rules for the given structures. Each rule is only added when
    both structures of the pair and "external" are being predicted.
    """
    present = set(ground_truth_ids)
    rules: List[SummedProbabilityRule] = []
    if {"brainstem", "spinal_cord", "external"} <= present:
        rules.append(SummedProbabilityRule("spinal_cord", "brainstem", "external"))
    if {"spc_muscle", "mpc_muscle", "external"} <= present:
        rules.append(SummedProbabilityRule("mpc_muscle", "spc_muscle", "external"))
    if {"optic_chiasm", "pituitary_gland", "external"} <= present:
        rules.append(SummedProbabilityRule("optic_chiasm", "pituitary_gland", "external"))
    return rules


class HeadAndNeckPaper(HeadAndNeckBase):
    """
    Head and Neck model, as used in the paper.
    """

    def __init__(self, num_structures: Optional[int] = None, **kwargs: Any) -> None:
        """
        Creates a new instance of the class.

        :param num_structures: number of structures from STRUCTURE_LIST to predict (default: all structures).
            Must be between 1 and len(STRUCTURE_LIST); the corresponding prefix of STRUCTURE_LIST is predicted.
        :param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
            If any of ground_truth_ids_display_names, colours, fill_holes, num_feature_channels, class_weights,
            slice_exclusion_rules or summed_probability_rules is present, it is used instead of the default
            computed here.
        :raises ValueError: if num_structures is given but is not between 1 and len(STRUCTURE_LIST).
        """
        num_all_structures = len(STRUCTURE_LIST)
        if num_structures is None:
            logging.info('Setting num_structures to: %d', num_all_structures)
            num_structures = num_all_structures
        elif num_structures <= 0 or num_structures > num_all_structures:
            # Zero is rejected as well, so the valid range starts at 1.
            raise ValueError(f"num structures must be between 1 and {num_all_structures}")
        ground_truth_ids = STRUCTURE_LIST[:num_structures]

        ground_truth_ids_display_names = _pop_or_default(
            kwargs, "ground_truth_ids_display_names", lambda: [f"zz_{x}" for x in ground_truth_ids])
        colours = _pop_or_default(kwargs, "colours", lambda: COLOURS[:num_structures])
        fill_holes = _pop_or_default(kwargs, "fill_holes", lambda: [True] * num_structures)
        # The amount of GPU memory required increases with both the number of structures and the
        # number of feature channels. The following is a sensible default to avoid out-of-memory,
        # but you can override it by passing num_feature_channels explicitly (for example from a subclass).
        num_feature_channels = _pop_or_default(
            kwargs, "num_feature_channels", lambda: 32 if num_structures <= 20 else 26)

        def _default_class_weights() -> Any:
            # Background weight for the default class weights: 0.02 with several structures, 0.25 with one.
            bg_weight = 0.02 if len(ground_truth_ids) > 1 else 0.25
            return equally_weighted_classes(ground_truth_ids, background_weight=bg_weight)

        class_weights = _pop_or_default(kwargs, "class_weights", _default_class_weights)
        slice_exclusion_rules = _pop_or_default(
            kwargs, "slice_exclusion_rules", lambda: _default_slice_exclusion_rules(ground_truth_ids))
        summed_probability_rules = _pop_or_default(
            kwargs, "summed_probability_rules", lambda: _default_summed_probability_rules(ground_truth_ids))

        super().__init__(
            ground_truth_ids=ground_truth_ids,
            ground_truth_ids_display_names=ground_truth_ids_display_names,
            colours=colours,
            fill_holes=fill_holes,
            class_weights=class_weights,
            slice_exclusion_rules=slice_exclusion_rules,
            summed_probability_rules=summed_probability_rules,
            num_feature_channels=num_feature_channels,
            **kwargs)