Source code for InnerEye.ML.utils.device_aware_module

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from typing import Generic, List, TypeVar

import torch

from InnerEye.Common.type_annotations import T
from InnerEye.ML.utils.ml_util import is_gpu_available

E = TypeVar("E")


[docs]class DeviceAwareModule(torch.nn.Module, Generic[T, E]): """ Wrapper around base pytorch module class that can provide information about its devices """ def __init__(self) -> None: super().__init__() self.conv_in_3d = False
[docs] def get_devices(self) -> List[torch.device]: """ :return: a list of device ids on which this module is deployed. """ return list({x.device for x in self.parameters()})
[docs] def get_number_trainable_parameters(self) -> int: """ :return: the number of trainable parameters in the module. """ return sum(p.numel() for p in self.parameters() if p.requires_grad)
[docs] def is_model_on_gpu(self) -> bool: """ Checks if the model is cuda activated or not :return: True if the model is running on the GPU. """ try: cuda_activated = next(self.parameters()).is_cuda except StopIteration: # The model has no parameters cuda_activated = False return True if (cuda_activated and is_gpu_available()) else False
[docs] def get_input_tensors(self, item: T) -> List[E]: """ Extract the input tensors from a data sample as required by the forward pass of the module. :param item: a data sample :return: the correct input tensors for the forward pass """ raise NotImplementedError("get_input_tensor has to be" "implemented by sub classes.")
[docs] def get_last_encoder_layer_names(self) -> List[str]: """ Return the name of the last encoder layers for GradCam. Default is an empty list. """ return []