Source code for InnerEye.ML.models.losses.mixture

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from typing import Any, List, Tuple

import torch

from InnerEye.ML.utils.supervised_criterion import SupervisedLearningCriterion


[docs]class MixtureLoss(SupervisedLearningCriterion): def __init__(self, components: List[Tuple[float, SupervisedLearningCriterion]]): """ Loss function defined as a weighted mixture (interpolation) of other loss functions. :param components: a non-empty list of weights and loss function instances. """ super().__init__() if not components: raise ValueError("At least one (weight, loss_function) pair must be supplied.") self.components = components
[docs] def forward_minibatch(self, output: torch.Tensor, target: torch.Tensor, **kwargs: Any) -> torch.Tensor: """ Wrapper for mixture loss function implemented in PyTorch. Arguments should be suitable for the component loss functions, typically: :param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD or in 1D BxC :param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC """ result = None for (weight, loss_function) in self.components: loss = weight * loss_function(output, target, **kwargs) if result is None: result = loss else: result = result + loss assert result is not None torch.cuda.empty_cache() return result