Source code for InnerEye.ML.models.losses.ece

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from typing import Callable

import torch
import torch.nn.functional as F


[docs]class ECELoss(torch.nn.Module): """ Calculates the Expected Calibration Error of a model. Confidence outputs are divided into equally-sized interval bins. In each bin, we compute the confidence gap as: bin_gap = l1_norm(avg_confidence_in_bin - accuracy_in_bin) A weighted average of the gaps is then returned based on the number of samples in each bin. """ def __init__(self, n_bins: int = 15, activation: Callable = lambda x: F.softmax(x, dim=1)): """ :param n_bins: number of confidence interval bins. :param activation: callable function for logit normalisation. """ super(ECELoss, self).__init__() bin_boundaries = torch.linspace(0, 1, n_bins + 1) self.bin_lowers = bin_boundaries[:-1] self.bin_uppers = bin_boundaries[1:] self.activation = activation
[docs] def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: # type: ignore normalised_logits = self.activation(logits) confidences, predictions = torch.max(normalised_logits, 1) accuracies = predictions.eq(labels) ece = torch.zeros(1, device=logits.device) for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers): # type: ignore # Calculated 'confidence - accuracy' in each bin in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item()) prop_in_bin = in_bin.float().mean() if prop_in_bin.item() > 0: accuracy_in_bin = accuracies[in_bin].float().mean() avg_confidence_in_bin = confidences[in_bin].mean() ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin return ece