Source code for InnerEye.ML.visualizers.activation_maps

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import os
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch

from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.dataset.sample import CroppedSample
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling
from InnerEye.ML.visualizers import model_hooks


def vis_activation_map(activation_map: np.ndarray) -> np.ndarray:
    """
    Normalizes the activation map and maps it to RGB range for visualization.

    The input array is not modified; the result is always a newly allocated array.

    :param activation_map: array of activation values. Must be non-empty, because
        ``np.min`` raises a ``ValueError`` on an empty array.
    :return: a new array holding the input shifted up by the absolute value of its minimum
        and rescaled so that its maximum becomes 255. If the shifted map is zero everywhere
        (there is nothing to rescale), an all-zero array of the same shape is returned.
    """
    min_val = np.min(activation_map)
    # Shift out of place: callers pass views into a larger activation tensor, and an
    # in-place "+=" would silently modify that tensor.
    shifted = activation_map + abs(min_val)
    peak = np.max(shifted)
    if peak == 0:
        # Dividing by a zero peak would turn the whole map into NaNs.
        return np.zeros(shifted.shape, dtype=float)
    # scale to RGB
    return (shifted / peak) * 255.0
def visualize_2d_activation_map(activation_map: np.ndarray, args: ModelConfigBase, slice_index: int = 0) -> None:
    """
    Saves all feature channels of a 2D activation map as png files.

    The files are written to the folder "activation_maps" inside ``args.outputs_folder``, one
    file per feature channel, named ``slice_<slice_index>_feature_<feat>_Activation_Map.png``.

    :param activation_map: array whose first axis enumerates the feature channels; each
        ``activation_map[feat]`` is rendered as one image.
    :param args: model configuration; only ``args.outputs_folder`` is read here.
    :param slice_index: index that is embedded in the file names, to tell apart images that
        come from different slices of a 3D activation map.
    :return:
    """
    destination_directory = str(args.outputs_folder / "activation_maps")
    # makedirs with exist_ok avoids the check-then-create race and also works when the
    # outputs folder itself has not been created yet.
    os.makedirs(destination_directory, exist_ok=True)

    # Draw into a dedicated figure rather than the implicit global one: images no longer pile
    # up on the same axes across feature channels, and the figure is released afterwards.
    fig, ax = plt.subplots()
    try:
        for feat in range(activation_map.shape[0]):
            ax.clear()
            ax.imshow(vis_activation_map(activation_map[feat]))
            file_name = f"slice_{slice_index}_feature_{feat}_Activation_Map.png"
            fig.savefig(os.path.join(destination_directory, file_name))
    finally:
        plt.close(fig)
def visualize_3d_activation_map(activation_map: np.ndarray, args: ModelConfigBase,
                                slices_to_visualize: Optional[List[int]] = None) -> None:
    """
    Saves all feature channels of a 3D activation map as png files, one set of files per
    selected slice.

    :param activation_map: 4-dimensional array; axis 0 is forwarded as the feature axis and
        axis 1 is the axis along which slices are picked.
    :param args: model configuration, passed through unchanged to the 2D visualization.
    :param slices_to_visualize: indices along axis 1 to write out. If None, two indices are
        drawn at random (they may coincide).
    :return:
    """
    if slices_to_visualize is None:
        # No explicit selection given: fall back to two randomly drawn slice indices.
        num_slices = activation_map.shape[1]
        chosen_slices = np.random.randint(0, num_slices, 2).tolist()
    else:
        chosen_slices = slices_to_visualize

    for slice_position in chosen_slices:
        feature_planes = activation_map[:, slice_position, :, :]
        visualize_2d_activation_map(feature_planes, args, slice_index=slice_position)
def extract_activation_maps(args: ModelConfigBase) -> None:
    """
    Extracts and saves activation maps of a specific layer of a trained network.

    The model is created from the configuration and its weights are loaded from the checkpoint
    returned by ``args.get_path_to_checkpoint()``. The first batch of the validation data loader
    is then run through a hook-based feature extractor for the layers named in
    ``args.activation_map_layers``. The activation map of the first image in that batch is
    written out as png files.

    :param args: model configuration. ``args.use_gpu`` decides whether both the model and the
        input are moved to the GPU.
    :raises FileNotFoundError: if the checkpoint file does not exist.
    :raises NotImplementedError: if the extracted activation map is neither 3- nor 4-dimensional.
    :return:
    """
    model = create_model_with_temperature_scaling(args)
    if args.use_gpu:
        model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()

    checkpoint_path = args.get_path_to_checkpoint()
    if not checkpoint_path.is_file():
        raise FileNotFoundError("Could not find checkpoint")
    # When running on CPU, remap tensors that were saved from a GPU so that loading does not
    # require CUDA. On GPU the default behaviour of torch.load is kept.
    map_location = None if args.use_gpu else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=map_location)  # type: ignore
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    val_loader = args.create_data_loaders()[ModelExecutionMode.VAL]
    feature_extractor = model_hooks.HookBasedFeatureExtractor(model, layer_name=args.activation_map_layers)

    for sample in val_loader:
        sample = CroppedSample.from_dict(sample=sample)
        input_image = sample.image
        # Only move the input to the GPU when the model lives there; an unconditional
        # .cuda() fails on CPU-only runs.
        if args.use_gpu:
            input_image = input_image.cuda()
        input_image = input_image.float()

        # Inference only: the outputs are converted to numpy right away, so no autograd
        # graph is needed.
        with torch.no_grad():
            feature_extractor(input_image)

        # access first image of batch of feature maps
        activation_map = feature_extractor.outputs[0][0].cpu().numpy()

        if activation_map.ndim == 4:
            visualize_3d_activation_map(activation_map, args)
        elif activation_map.ndim == 3:
            visualize_2d_activation_map(activation_map, args)
        else:
            raise NotImplementedError(f"cannot visualize activation map of shape {activation_map.shape}")

        # Only visualize the first validation example
        break