Source code for InnerEye.ML.visualizers.model_hooks
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List, Optional
import torch
from torch.nn import Module
from InnerEye.ML.utils.temperature_scaling import ModelWithTemperature
class HookBasedFeatureExtractor(Module):
    """
    Wraps a model and records what flows through one of its layers.

    Every call to ``forward`` runs the wrapped model once with a temporary forward hook attached
    to the target layer. The hook appends copies of that layer's input and output to
    ``self.inputs`` and ``self.outputs``; the lists are never cleared by this class.
    """

    def __init__(self, model: Module, layer_name: Optional[List[str]]):
        """
        :param model: pytorch model
        :param layer_name: name of a direct submodule of model, or nested module names given as a
            list, e.g. layer_name = [_2D_encoder] or layer_name = [_2D_encoder, conv_layers, 0].
            Each entry must be a key of the parent's ``_modules`` mapping.
            If None, the hook is attached to the network itself.
        :raises ValueError: if layer_name does not identify a submodule of the network.
        """
        super().__init__()
        # One entry is appended to each list per invocation of the hooked layer.
        self.inputs: List[Any] = []
        self.outputs: List[Any] = []
        self.layer_name = layer_name
        # `model` is what gets executed in forward(); `net` is where layer names are looked up.
        self.model = model
        self.net: Module
        if isinstance(model, torch.nn.DataParallel):
            self.net = model.module  # type: ignore
        else:
            self.net = model
        if isinstance(self.net, ModelWithTemperature):
            self.net = self.net.model
        # Fail early on a bad layer name rather than at the first forward pass.
        if layer_name is not None:
            self._verify_layer_name(self.net, layer_name)

    @staticmethod
    def _find_submodule(model: Module, layer_name: List[str]) -> Module:
        """
        Walks down the module hierarchy following the given names.

        :param model: the module to start from
        :param layer_name: hierarchical list of layer names to index within model
        :return: the submodule reached after following all names (model itself for an empty list)
        :raises ValueError: if any name is not a submodule at its level
        """
        submodule = model
        for el in layer_name:
            children = submodule._modules  # type: ignore
            if el not in children:
                raise ValueError(f"invalid layer name: {el}")
            submodule = children[el]  # type: ignore
        return submodule

    def _verify_layer_name(self, model: Module, layer_name: List[str]) -> None:
        """
        Traverses the model and verifies that the layer name is valid

        :param model: the model
        :param layer_name: hierarchical list of layer names to index within model
        :raises ValueError: if any element of layer_name is not found at its level
        """
        self._find_submodule(model, layer_name)

    @staticmethod
    def _copy_tensors(value: Any) -> Any:
        """
        Returns a copy of the data held by value: a list of clones if value is a tuple,
        otherwise a single clone.
        """
        if isinstance(value, tuple):
            return [item.data.clone() for item in value]
        return value.data.clone()

    def forward_hook_fn(self, module: Module, input: Any, output: Any) -> None:
        """
        Forward hook callback that stores copies of the hooked layer's input and output.

        Tuples are stored as lists of cloned elements, anything else as a single clone.

        :param module: the module the hook is attached to (unused)
        :param input: the input passed to the module
        :param output: the output produced by the module
        """
        self.inputs.append(self._copy_tensors(input))
        self.outputs.append(self._copy_tensors(output))

    # noinspection PyTypeChecker
    def forward(self, input):  # type: ignore
        """
        Runs the wrapped model on input while the hook is attached to the target layer.

        The model's own output is discarded; results are available in ``self.inputs`` and
        ``self.outputs`` afterwards.

        :param input: the input to pass to the wrapped model
        """
        if self.layer_name is not None:
            target_layer = self._find_submodule(self.net, self.layer_name)
        else:
            target_layer = self.net
        hook = target_layer.register_forward_hook(self.forward_hook_fn)
        try:
            self.model(input)
        finally:
            # Always detach the hook, even if the model raises, so that it does not stay
            # registered on the layer and keep recording on later calls.
            hook.remove()