Source code for InnerEye.ML.utils.sequence_utils

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Sequence

import numpy as np
import torch
from torch.nn.utils.rnn import PackedSequence, pack_sequence, pad_sequence

from InnerEye.Common.common_util import check_properties_are_not_none
from InnerEye.ML.utils.image_util import NumpyOrTorch


@dataclass(frozen=True)
class MaskedModelOutputAndLabelSequences:
    """
    Dataclass to encapsulate masked model outputs, labels and associated subject ids.

    On construction, the two packed sequences are checked for consistency: their data must have the same
    length, and their batch_sizes, sorted_indices and unsorted_indices must agree. If subject ids are
    provided, their number must equal the largest entry of the labels' batch_sizes.
    """
    # Masked model outputs, as a packed sequence.
    model_outputs: PackedSequence
    # Masked labels, as a packed sequence with the same layout as model_outputs.
    labels: PackedSequence
    # Optional subject ids that belong to the packed sequences; may be None.
    subject_ids: Optional[Sequence[str]]

    @staticmethod
    def _indices_match(first: Optional[torch.Tensor], second: Optional[torch.Tensor]) -> bool:
        """
        Compares two optional index tensors. Two missing (None) tensors are considered equal, a missing
        and a present tensor are considered different, and two present tensors are compared element-wise.
        """
        if first is None or second is None:
            return first is None and second is None
        return torch.equal(first, second)

    def __post_init__(self) -> None:
        """
        Validates that model outputs, labels and subject ids are consistent with each other.

        :raises ValueError: If the packed sequences differ in data length, batch_sizes, sorted_indices or
            unsorted_indices, or if the number of subject ids does not match the labels.
        """
        check_properties_are_not_none(self, ignore=["subject_ids"])

        if len(self.model_outputs.data) != len(self.labels.data):
            raise ValueError("model_outputs and labels must have the same length, "
                             f"found {len(self.model_outputs.data)} and {len(self.labels.data)}")

        if not torch.equal(self.model_outputs.batch_sizes, self.labels.batch_sizes):
            raise ValueError("batch_sizes for model_outputs and labels must be equal, "
                             f"found {self.model_outputs.batch_sizes} and {self.labels.batch_sizes}")

        # sorted_indices and unsorted_indices are None for sequences packed with enforce_sorted=True.
        # torch.equal does not accept None, hence the None-aware comparison.
        for field_name in ("sorted_indices", "unsorted_indices"):
            output_indices = getattr(self.model_outputs, field_name)
            label_indices = getattr(self.labels, field_name)
            if not self._indices_match(output_indices, label_indices):
                raise ValueError(f"{field_name} for model_outputs and labels must be equal, "
                                 f"found {output_indices} and {label_indices}")

        # The largest batch size is the number of sequences that were packed.
        _expected_subjects = self.labels.batch_sizes.max().item()
        if self.subject_ids is not None and len(self.subject_ids) != _expected_subjects:
            raise ValueError(f"expected {_expected_subjects} subject_ids but found {len(self.subject_ids)}")
def sequences_to_padded_tensor(sequences: List[torch.Tensor],
                               padding_value: float = 0.0) -> torch.Tensor:
    """
    Stacks a list of tensors of possibly different lengths into a single tensor, filling up the shorter
    ones with a padding value.

    :param sequences: The tensors that should be padded and stacked.
    :param padding_value: The value that is used to fill up shorter sequences. Defaults to 0.0
    :return: A tensor of shape B x *, where B is the number of tensors in the given list, and * are the
        maximum dimensions found across those tensors.
    """
    # batch_first=True puts the index of the sequence into the first dimension of the result.
    padded = pad_sequence(sequences, batch_first=True, padding_value=padding_value)
    return padded
def map_packed_sequence_data(x: PackedSequence, f: Callable[[torch.Tensor], torch.Tensor]) -> PackedSequence:
    """
    Applies a function to the data tensor of a packed sequence, and returns a new packed sequence that
    holds the transformed data together with the batch sizes and indices of the input.

    :param x: The packed sequence to transform.
    :param f: The function to apply to the data tensor. It must return a tensor of the same shape.
    :return: A packed sequence with the transformed data.
    :raises ValueError: If the function returns a tensor with a different shape than its input.
    """
    mapped = f(x.data)
    shape_before, shape_after = x.data.shape, mapped.shape
    # A shape-preserving function is required, otherwise batch sizes and indices would no longer fit.
    if shape_before != shape_after:
        raise ValueError("The provided function must be a map function, but it changed the original tensor's shape"
                         f" from {shape_before} to {shape_after}")
    return PackedSequence(
        data=mapped,
        batch_sizes=x.batch_sizes,
        sorted_indices=x.sorted_indices,
        unsorted_indices=x.unsorted_indices,
    )
def get_masked_model_outputs_and_labels(model_output: torch.Tensor,
                                        labels: NumpyOrTorch,
                                        subject_ids: Optional[Sequence[str]] = None) \
        -> Optional[MaskedModelOutputAndLabelSequences]:
    """
    Helper function to get masked model outputs, labels and their associated subject ids. Masking is performed
    by excluding the NaN model outputs and labels based on a bool mask created using the occurrences of NaN
    in the labels provided. Sequences for which all labels are NaN are dropped, together with their subject id.

    :param model_output: The model output tensor to mask.
    :param labels: The label tensor that is used to create the mask, and that is masked itself. The mask is
        computed via torch.isnan, which operates on torch tensors.
    :param subject_ids: The associated subject ids, one per row of the labels.
    :return: None if all labels are required to be masked, otherwise MaskedModelOutputAndLabelSequences. The
        subject ids in the result are in the same order as the packed sequences.
    """
    non_nan_idxs = ~torch.isnan(labels)
    _model_output_tensors: List[torch.Tensor] = []
    _label_tensors: List[torch.Tensor] = []
    # Row indices of the sequences that have at least one valid label.
    _kept_rows: List[int] = []

    # iterate over each of the sequences to create masked tensors
    for i, row_mask in enumerate(non_nan_idxs):
        masked_model_output, masked_labels = model_output[i, row_mask], labels[i, row_mask]
        # if all the elements of the sequence are masked, then drop the subject
        if masked_labels.numel() > 0:
            _model_output_tensors.append(masked_model_output)
            _label_tensors.append(masked_labels)
            _kept_rows.append(i)

    # since it is not possible to create a packed sequence with empty tensors,
    # make sure we have valid tensors to pack, otherwise return None.
    if not _label_tensors:
        return None

    labels_packed = pack_sequence(_label_tensors, enforce_sorted=False)

    _subject_ids: Optional[List[Any]] = None
    if subject_ids is not None:
        # Make sure the subject ids are in the same order as the packed sequences. Plain list indexing
        # always gives a list (also for a single subject) and leaves the ids as they were passed in.
        _subject_ids = [subject_ids[_kept_rows[j]] for j in labels_packed.sorted_indices.tolist()]

    return MaskedModelOutputAndLabelSequences(
        model_outputs=pack_sequence(_model_output_tensors, enforce_sorted=False),
        labels=labels_packed,
        subject_ids=_subject_ids
    )
def apply_sequence_model_loss(loss_fn: torch.nn.Module,
                              model_output: torch.Tensor,
                              labels: torch.Tensor) -> torch.Tensor:
    """
    Computes a loss for model outputs and labels that come from sequences of unequal length. Sequence
    elements that are missing are masked out before the loss function is called.

    :param loss_fn: The loss function that is evaluated on the sequence elements that are present.
    :param model_output: The model outputs.
    :param labels: The ground truth labels.
    :return: The value of the loss function.
    :raises ValueError: If masking does not leave any elements to compute the loss on.
    """
    # The mask is derived from the labels and applied to both labels and model outputs.
    masked = get_masked_model_outputs_and_labels(model_output=model_output, labels=labels)
    if masked is not None:
        return loss_fn(masked.model_outputs, masked.labels)
    raise ValueError("Invalid model_output and labels found")