# Source code for InnerEye.ML.models.losses.cross_entropy

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from typing import Any, Optional

import torch
import torch.nn.functional as F

from InnerEye.ML.utils.image_util import get_class_weights
from InnerEye.ML.utils.supervised_criterion import SupervisedLearningCriterion


class CrossEntropyLoss(SupervisedLearningCriterion):
    """
    Multi-class cross entropy loss for one-hot encoded targets, with optional class re-weighting,
    label smoothing and focal-loss pixel weighting.
    """

    def __init__(self,
                 class_weight_power: Optional[float] = None,
                 smoothing_eps: float = 0.0,
                 focal_loss_gamma: Optional[float] = None,
                 ignore_index: int = -100):
        """
        :param class_weight_power: if 1.0, weights the cross-entropy term for each class equally.
            Class weights are inversely proportional to the number of pixels belonging to each class,
            raised to class_weight_power. `None` or 0.0 disables class weighting. Must be >= 0.0.
        :param smoothing_eps: Label-smoothing epsilon; it is forwarded unchanged to the base class constructor.
        :param focal_loss_gamma: Gamma term used in focal loss to weight negative log-likelihood term:
            https://arxiv.org/pdf/1708.02002.pdf equation(4-5). When gamma equals to zero, it is equivalent
            to standard CE with no class balancing. `None` disables focal weighting. (Gamma >= 0.0)
        :param ignore_index: Specifies a target value that is ignored and does not contribute to the
            input gradient. NOTE(review): the value is stored on the instance but is not read by any
            method of this class.
        :raises ValueError: If `class_weight_power` or `focal_loss_gamma` is negative.
        """
        super().__init__(smoothing_eps)
        if class_weight_power is not None and class_weight_power < 0.0:
            raise ValueError("Class-weight power should be equal to or larger than zero")
        # Validate every non-None gamma (ints included, not only floats). The `not (x >= 0)` form
        # is kept so that NaN is rejected as well.
        if focal_loss_gamma is not None and not (focal_loss_gamma >= 0.0):
            raise ValueError("Gamma should be equal to or larger than zero")
        self.class_weight_power = class_weight_power
        self.focal_loss_gamma = focal_loss_gamma
        self.ignore_index = ignore_index
        # Small constant used to stabilise the focal-loss weights and their normalisation.
        self.eps = 1e-6

    @staticmethod
    def _get_class_weights(target_labels: torch.Tensor, num_classes: int) -> torch.Tensor:
        """
        Returns class weights inversely proportional to the number of pixels in each class.

        :param target_labels: Tensor of class ids. The unique values are used as indices into the weight
            vector, so they must be integer ids in the range [0, num_classes).
        :param num_classes: Length of the returned weight vector.
        :return: A float32 tensor of length `num_classes` on the same device as `target_labels`. Classes that
            do not occur in `target_labels` get weight zero.
        """
        # Allocate directly on the device of the labels, so that the indexed assignment below never mixes
        # devices (e.g. a non-default GPU).
        class_weight = torch.zeros(num_classes, dtype=torch.float32, device=target_labels.device)
        # Check which classes are available
        class_ids, class_counts = torch.unique(target_labels, sorted=False, return_counts=True)  # type: ignore
        class_weight[class_ids] = 1.0 / class_counts.float()
        return class_weight

    @torch.no_grad()
    def get_focal_loss_pixel_weights(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes a focal-loss weight for each pixel/sample: `(1 - p + eps) ** gamma`, where `p` is the softmax
        posterior of the target class. The weights are then rescaled so that they sum (approximately) to the
        number of pixels. Runs without gradient tracking.

        :param logits: Logits tensor, with classes along dimension 1.
        :param target: Target label tensor in one-hot encoding, with classes along dimension 1.
        :return: Weights tensor with the class dimension summed out.
        :raises ValueError: If the number of entries equal to 1.0 in `target` does not match the number of
            pixels, i.e. the target is not one-hot encoded.
        """
        # A one-hot target has exactly one entry equal to 1.0 per pixel.
        if not (torch.sum(target == 1.0) == (target.nelement() / target.shape[1])):
            raise ValueError("Focal loss is supported only for one-hot encoded targets")

        posteriors = torch.nn.functional.softmax(logits, dim=1)
        # noinspection PyUnresolvedReferences,PyTypeChecker
        pixel_weights: torch.Tensor = (1 - posteriors + self.eps).pow(self.focal_loss_gamma)  # type: ignore
        # Keep only the weight of the target class for each pixel.
        pixel_weights = torch.sum(pixel_weights * target, dim=1)

        # Normalise the pixel weights
        scaling = pixel_weights.nelement() / (torch.sum(pixel_weights) + self.eps)
        return pixel_weights * scaling

    @staticmethod
    def _verify_inputs(logits: torch.Tensor, target: torch.Tensor) -> None:
        """
        Function that verifies input data types and dimensions.

        :param logits: Logits tensor, expected to be a floating point tensor of shape Batch x Class x ...
        :param target: Target tensor with the same batch size as `logits`.
        :raises TypeError: If either argument is not a tensor, or if `logits` is not floating point.
        :raises ValueError: If `logits` has fewer than 2 dimensions, if the batch sizes differ, or, for logits
            with more than 2 dimensions, if the non-batch dimensions of `logits` and `target` differ.
        """
        # Check input data types
        if not isinstance(logits, torch.Tensor) or not isinstance(target, torch.Tensor):
            raise TypeError("Logits and target must be torch.Tensors")
        if not logits.is_floating_point():
            raise TypeError("Logits must be a float tensor")

        # Check input tensor dimensions
        if len(logits.shape) < 2:
            raise ValueError("The shape of logits must be at least 2, Batch x Class x ... "
                             "(logits.shape: {})".format(logits.shape))
        if logits.shape[0] != target.shape[0]:
            raise ValueError("The logits and target must have the same batch size (logits: {}, target: {})".
                             format(logits.shape, target.shape))
        # The full shape comparison is only applied when there are spatial dimensions.
        if len(logits.shape) > 2 and logits.shape[1:] != target.shape[1:]:
            raise ValueError("The logits and target must have same shape (logits: {}, target: {})".
                             format(logits.shape, target.shape))

    def forward_minibatch(self, output: torch.Tensor, target: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Multi-class cross entropy built on PyTorch primitives. The implementation supports tensors with
        arbitrary spatial dimension. Input logits are normalised internally with `F.log_softmax`.

        :param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD or in 1D BxC
        :param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC
        :param kwargs: Accepted for interface compatibility; not used here.
        :return: The mean of the (optionally class- and focal-weighted) loss over all of its elements.
        """
        # Check input tensors
        self._verify_inputs(output, target)

        # Determine class weights for unbalanced datasets
        class_weight = None
        if self.class_weight_power is not None and self.class_weight_power != 0.0:
            class_weight = get_class_weights(target, class_weight_power=self.class_weight_power)

        # Compute negative log-likelihood
        log_prob = F.log_softmax(output, dim=1)
        if self.smoothing_eps > 0.0:
            # Use the target values directly as per-class weights of the log-probabilities.
            loss = -1.0 * log_prob * target
            if class_weight is not None:
                # Weight each class and sum over the class dimension.
                loss = torch.einsum('bc...,c->b...', loss, class_weight)
            # NOTE(review): without class weights the per-class terms are not summed here, so `loss` keeps
            # its class dimension and the final mean also averages over classes - confirm that this
            # scaling is intended.
        else:
            # Convert the one-hot encoded target to class indices for the hard-label NLL.
            loss = F.nll_loss(log_prob, torch.argmax(target, dim=1), weight=class_weight, reduction='none')

        # If focal loss is specified, apply pixel weighting
        if self.focal_loss_gamma is not None:
            pixel_weights = self.get_focal_loss_pixel_weights(output, target)
            loss = loss * pixel_weights

        return torch.mean(loss)