Source code for InnerEye.ML.configs.segmentation.ProstatePaper

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import logging
from typing import Any

from InnerEye.ML.config import equally_weighted_classes
from InnerEye.ML.configs.segmentation.ProstateBase import ProstateBase

# Foreground structure identifiers; ProstatePaper uses this list as its ground_truth_ids.
fg_classes = [
    "external",
    "femur_r",
    "femur_l",
    "rectum",
    "prostate",
    "bladder",
    "seminalvesicles",
]
# Display-name counterparts of the identifiers above, listed in the same order.
# ProstatePaper prefixes each with "zz_" to build its default display names.
fg_display_names = [
    "External",
    "Femur_R",
    "Femur_L",
    "Rectum",
    "Prostate",
    "Bladder",
    "SeminalVesicles",
]


class ProstatePaper(ProstateBase):
    """
    Prostate radiotherapy image segmentation model, as in the paper.

    The structures to segment are taken from the module-level ``fg_classes`` list. Several
    settings get a default value here unless the caller supplies them as keyword arguments.
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Creates a new instance of the class.

        The following keys are removed from ``kwargs`` if present and used in place of the defaults:
        ``ground_truth_ids_display_names``, ``colours``, ``fill_holes``, ``class_weights`` and
        ``largest_connected_component_foreground_classes``. Whenever a default is used instead,
        a message is logged at INFO level.

        :param kwargs: Additional arguments that will be passed through to the parent class constructor.
        """
        structures = fg_classes

        if "ground_truth_ids_display_names" not in kwargs:
            logging.info('Using default ground_truth_ids_display_names')
            display_names = [f"zz_{name}" for name in fg_display_names]
        else:
            display_names = kwargs.pop("ground_truth_ids_display_names")

        if "colours" not in kwargs:
            logging.info('Using default colours')
            # Every structure gets the same colour tuple by default.
            structure_colours = [(255, 0, 0) for _ in structures]
        else:
            structure_colours = kwargs.pop("colours")

        if "fill_holes" not in kwargs:
            logging.info('Using default fill_holes')
            # One flag per entry of fg_classes; only the sixth position is False.
            # NOTE(review): presumably positions line up with ground_truth_ids (sixth = "bladder") - confirm.
            hole_filling = [True, True, True, True, True, False, True]
        else:
            hole_filling = kwargs.pop("fill_holes")

        if "class_weights" not in kwargs:
            logging.info('Using default class_weights')
            weights = equally_weighted_classes(structures, background_weight=0.02)
        else:
            weights = kwargs.pop("class_weights")

        if "largest_connected_component_foreground_classes" not in kwargs:
            logging.info('Using default largest_connected_component_foreground_classes')
            # All structures except "seminalvesicles".
            lcc_classes = [structure for structure in structures if structure != "seminalvesicles"]
        else:
            lcc_classes = kwargs.pop("largest_connected_component_foreground_classes")

        # The keys handled above have been popped, so the remaining kwargs cannot clash with them.
        super().__init__(
            ground_truth_ids=structures,
            ground_truth_ids_display_names=display_names,
            colours=structure_colours,
            fill_holes=hole_filling,
            class_weights=weights,
            largest_connected_component_foreground_classes=lcc_classes,
            **kwargs
        )