# Source code for InnerEye.ML.utils.features_util
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, List
import torch
from InnerEye.Common.common_util import check_properties_are_not_none
from InnerEye.ML.dataset.scalar_dataset import ScalarDataSource
@dataclass(frozen=True)
class FeatureStatistics:
    """
    Class to store statistics (mean and standard deviation) of a set of features in a given dataset.
    Allows to perform feature standardization for this set of features.
    """
    mean: torch.Tensor  # This tensor will have the same shape as the non-image features in the dataset.
    std: torch.Tensor  # This tensor will have the same shape as the non-image features in the dataset.

    def __post_init__(self) -> None:
        # Reject objects where either statistic is missing.
        check_properties_are_not_none(self)

    @staticmethod
    def from_data_sources(sources: List[ScalarDataSource]) -> FeatureStatistics:
        """
        For the provided data sources, compute the mean and std of each numerical non-image feature across all
        entries. Non-finite values (e.g. infinities introduced by padding) are ignored in the computation.

        :param sources: list of data sources
        :return: a FeatureStatistics object storing mean and standard deviation for each non-imaging feature of
            the dataset.
        :raises ValueError: if `sources` is empty, or if the numerical non-image features of the sources do not
            all have the same shape.
        """
        if len(sources) == 0:
            raise ValueError("sources must have a length greater than 0")
        # NOTE(review): presumably widened to Any so the attribute access below type-checks; confirm.
        data_sources: List[Any] = sources
        # One entry per source, so this list is non-empty because `sources` is non-empty.
        numerical_non_image_features = [x.numerical_non_image_features for x in data_sources]
        unique_shapes = {f.shape for f in numerical_non_image_features}
        if len(unique_shapes) != 1:
            raise ValueError(
                f"All non-image features must have the same size, but got these sizes: {unique_shapes}")
        # If the input features contain infinite values (e.g. from padding)
        # we need to ignore them for the computation of the normalization statistics.
        all_stacked = torch.stack(numerical_non_image_features, dim=0)
        return FeatureStatistics.compute_masked_statistics(input=all_stacked,
                                                           mask=torch.isfinite(all_stacked))

    @staticmethod
    def compute_masked_statistics(input: torch.Tensor, mask: torch.Tensor,
                                  apply_bias_correction: bool = True) -> FeatureStatistics:
        """
        If the input features contain invalid values (e.g. from padding) they should be ignored in the
        computation of the standardization statistics. This function allows to provide a boolean mask (of the same
        shape as the input) to indicate which values should be taken into account for the computation of the
        statistics. All values for which mask == True will be used for computation, the others will be ignored.
        The statistics are computed for each feature, i.e. each column of the input
        (shape [batch_size, n_numerical_features]).

        A feature with no valid observation gets NaN mean and std. With bias correction, a feature with exactly
        one valid observation gets NaN std.

        :param input: input including all values, of dimension [batch_size, n_numerical_features]
        :param mask: boolean tensor of the same shape as input
        :param apply_bias_correction: if True applies Bessel's correction to the standard deviation estimate
        :return: FeatureStatistics (mean and std) computed on the masked values.
        """
        n_obs_per_feature = mask.sum(dim=0).float()
        # torch.where (not multiplication by the mask) so excluded inf/NaN entries cannot leak into the sums.
        masked_values = torch.where(mask, input, torch.zeros_like(input))
        mean = masked_values.sum(dim=0) / n_obs_per_feature
        # Two-pass variance: summing squared deviations from the mean avoids the catastrophic cancellation of
        # E[x^2] - E[x]^2 when a feature's magnitude is large compared to its spread. The result is
        # non-negative (or NaN), so no clamping is needed before the square root.
        deviations = input - mean
        deviations = torch.where(mask, deviations, torch.zeros_like(deviations))
        variance = torch.pow(deviations, 2).sum(dim=0) / n_obs_per_feature
        if apply_bias_correction:
            # Applies Bessel's bias correction to the std estimate (as in PyTorch's default behavior).
            # For a single observation this is 0 * inf = NaN, i.e. the std is undefined.
            variance = variance * (n_obs_per_feature / (n_obs_per_feature - 1))
        std = torch.sqrt(variance)
        return FeatureStatistics(mean=mean, std=std)

    def standardize(self, sources: List[ScalarDataSource]) -> List[ScalarDataSource]:
        """
        For the provided data sources, apply standardization to the non-image features in each source. This will
        standardize them to mean 0, variance 1 across all sequences.
        All features that have zero standard deviation (constant features) or a NaN standard deviation are left
        untouched.

        :param sources: list of datasources.
        :return: new list of data sources where all non-imaging features are standardized.
        """
        # Features that cannot be standardized keep their original values. The mask is the same for every
        # source, so it is computed once.
        keep_original = (self.std == 0.0) | torch.isnan(self.std)

        def apply_source(source: ScalarDataSource) -> ScalarDataSource:
            features = source.numerical_non_image_features
            new_features = (features - self.mean) / self.std
            new_features[keep_original] = features[keep_original]
            return source.clone_with_overrides(numerical_non_image_features=new_features)

        return [apply_source(source) for source in sources]