Source code for InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Tuple, Union
import torch
OptionalIndexInputAndLabel = Union[Tuple[torch.Tensor, int], Tuple[int, torch.Tensor, int]]
[docs]class InnerEyeDataClassBaseWithReturnIndex:
"""
Class to be use with double inheritance with a VisionDataset.
Overloads the __getitem__ function so that we can optionally also return
the index within the dataset.
"""
def __init__(self, root: str, return_index: bool, **kwargs: Any) -> None:
self.return_index = return_index
super().__init__(root=root, **kwargs) # type: ignore
def __getitem__(self, index: int) -> Any:
item = super().__getitem__(index) # type: ignore
if self.return_index:
return (index, *item)
else:
return item
@property
def num_classes(self) -> int:
raise NotImplementedError