Source code for InnerEye.ML.utils.temperature_scaling

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

from typing import Callable, List, Tuple

import torch
from torch.optim import LBFGS  # type: ignore

from InnerEye.Common.type_annotations import T
from InnerEye.ML.deep_learning_config import TemperatureScalingConfig
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule, E


[docs]class ModelWithTemperature(DeviceAwareModule): """ Torch module to wrap a model with temperature scaling. """ def __init__(self, model: DeviceAwareModule, temperature_scaling_config: TemperatureScalingConfig): super().__init__() self.model = model self.conv_in_3d = model.conv_in_3d self.temperature_scaling_config = temperature_scaling_config # assign this parameter to the first model device otherwise use PyTorch default. _model_devices = model.get_devices() _device = _model_devices[0] if _model_devices else None self.temperature = torch.nn.Parameter(torch.ones(1, device=_device), requires_grad=False)
[docs] def forward(self, *x: torch.Tensor) -> torch.Tensor: # type: ignore logits = self.model(*x) return self.temperature_scale(logits)
[docs] def get_last_encoder_layer_names(self) -> List[str]: _model: DeviceAwareModule = self.model return _model.get_last_encoder_layer_names()
[docs] def get_input_tensors(self, item: T) -> List[E]: _model: DeviceAwareModule = self.model return _model.get_input_tensors(item)
[docs] def temperature_scale(self, logits: torch.Tensor) -> torch.Tensor: """ Perform temperature scaling on logits """ # Expand temperature to match the size of logits temperature = self.temperature.expand(logits.shape) return logits / temperature
[docs] def set_temperature(self, logits: torch.Tensor, labels: torch.Tensor, criterion_fn: Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]], use_gpu: bool) -> float: """ Tune the temperature of the model using the provided logits and labels. :param logits: Logits to use to learn the temperature parameter :param labels: Labels to use to learn the temperature parameter :param criterion_fn: A criterion function s.t: (logits, labels) => (loss, ECE) :param use_gpu: If True then GPU will be used otherwise CPU will be used. :return: Optimal temperature value """ if use_gpu: logits = logits.cuda() labels = labels.cuda() # Calculate loss values before scaling before_temperature_loss, before_temperature_ece = criterion_fn(logits, labels) print('Before temperature scaling - LOSS: {:.3f} ECE: {:.3f}' .format(before_temperature_loss.item(), before_temperature_ece.item())) # Next: optimize the temperature w.r.t. the provided criterion function self.temperature.requires_grad = True optimizer = LBFGS([self.temperature], lr=self.temperature_scaling_config.lr, max_iter=self.temperature_scaling_config.max_iter) def eval_criterion() -> torch.Tensor: # zero the gradients for the next optimization step optimizer.zero_grad() loss, ece = criterion_fn(self.temperature_scale(logits), labels) loss.backward() return loss optimizer.step(eval_criterion) # type: ignore after_temperature_loss, after_temperature_ece = criterion_fn(self.temperature_scale(logits), labels) print('Optimal temperature: {:.3f}'.format(self.temperature.item())) print('After temperature scaling - LOSS: {:.3f} ECE: {:.3f}' .format(after_temperature_loss.item(), after_temperature_ece.item())) self.temperature.requires_grad = False return self.temperature.item()