Source code for InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

from torchvision.datasets import CIFAR10, CIFAR100

from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex


class InnerEyeCIFAR10(InnerEyeDataClassBaseWithReturnIndex, CIFAR10):
    """
    Thin wrapper around the torchvision CIFAR10 dataset.

    Mixing in InnerEyeDataClassBaseWithReturnIndex lets __getitem__ optionally return the
    sample index in addition to the image and its label. The wrapper also exposes the
    number of classes through the num_classes property.
    """

    @property
    def num_classes(self) -> int:
        """Number of classes in the dataset, fixed at 10."""
        return 10
class InnerEyeCIFAR100(InnerEyeDataClassBaseWithReturnIndex, CIFAR100):
    """
    Thin wrapper around the torchvision CIFAR100 dataset.

    Mixing in InnerEyeDataClassBaseWithReturnIndex lets __getitem__ optionally return the
    sample index in addition to the image and its label. The wrapper also exposes the
    number of classes through the num_classes property.
    """

    @property
    def num_classes(self) -> int:
        """Number of classes in the dataset, fixed at 100."""
        return 100