Source code for InnerEye.ML.utils.config_loader
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import importlib
import inspect
import logging
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, List, Optional
import param
from importlib._bootstrap import ModuleSpec
from InnerEye.Common.common_util import path_to_namespace
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.ML.deep_learning_config import DeepLearningConfig
[docs]class ModelConfigLoader(GenericConfig):
"""
Helper class to manage model config loading
"""
model_configs_namespace: Optional[str] = param.String(default=None,
doc="Non-default namespace to search for model configs")
def __init__(self, **params: Any):
super().__init__(**params)
default_module = self.get_default_search_module()
self.module_search_specs: List[ModuleSpec] = [importlib.util.find_spec(default_module)]
if self.model_configs_namespace and self.model_configs_namespace != default_module:
# The later member of this list will take priority if a model name occurs in both, because
# dict.update is used to combine the dictionaries of models.
custom_spec = importlib.util.find_spec(self.model_configs_namespace)
if custom_spec is None:
raise ValueError(f"Search namespace {self.model_configs_namespace} was not found.")
self.module_search_specs.append(custom_spec)
[docs] @staticmethod
def get_default_search_module() -> str:
from InnerEye.ML import configs
return configs.__name__
[docs] def create_model_config_from_name(self, model_name: str) -> DeepLearningConfig:
"""
Returns a model configuration for a model of the given name. This can be either a segmentation or
classification configuration for an InnerEye built-in model, or a LightningContainer.
To avoid having to import torch here, there are no references to LightningContainer.
Searching for a class member called <model_name> in the search modules provided recursively.
:param model_name: Name of the model for which to get the configs for.
"""
if not model_name:
raise ValueError("Unable to load a model configuration because the model name is missing.")
configs: Dict[str, DeepLearningConfig] = {}
def _get_model_config(module_spec: ModuleSpec) -> Optional[DeepLearningConfig]:
"""
Given a module specification check to see if it has a class property with
the <model_name> provided, and instantiate that config class with the
provided <config_overrides>. Otherwise, return None.
:param module_spec:
:return: Instantiated model config if it was found.
"""
# noinspection PyBroadException
try:
logging.debug(f"Importing {module_spec.name}")
target_module = importlib.import_module(module_spec.name)
# The "if" clause checks that obj is a class, of the desired name, that is
# defined in this module rather than being imported into it (and hence potentially
# being found twice).
_class = next(obj for name, obj in inspect.getmembers(target_module)
if inspect.isclass(obj)
and name == model_name
and inspect.getmodule(obj) == target_module)
logging.info(f"Found class {_class} in file {module_spec.origin}")
# ignore the exception which will occur if the provided module cannot be loaded
# or the loaded module does not have the required class as a member
except Exception as e:
exception_text = str(e)
if exception_text != "":
logging.warning(f"(from attempt to import module {module_spec.name}): {exception_text}")
return None
model_config: DeepLearningConfig = _class()
return model_config
def _search_recursively_and_store(module_search_spec: ModuleSpec) -> None:
"""
Given a root namespace eg: A.B.C searches recursively in all child namespaces
for class property with the <model_name> provided. If found, this is
instantiated with the provided overrides, and added to the configs dictionary.
"""
root_namespace = module_search_spec.name
namespaces_to_search: List[str] = []
if module_search_spec.submodule_search_locations:
# There is little documentation about ModuleSpec, and in particular how the submodule search locations
# are structured. From the examples I saw, the _path field usually has two entries that only differ by
# case and/or directory separator. For ambiguous paths, there may be more search locations.
logging.debug(f"Searching through {len(module_search_spec.submodule_search_locations)} folders that "
f"match namespace {module_search_spec.name}: "
f"{module_search_spec.submodule_search_locations}")
for root in module_search_spec.submodule_search_locations:
for n in Path(root).rglob("*"):
if n.is_file() and "__pycache__" not in str(n):
sub_namespace = path_to_namespace(n, root=root)
namespaces_to_search.append(root_namespace + "." + sub_namespace)
elif module_search_spec.origin:
# The module search spec already points to a python file: Search only that.
namespaces_to_search.append(module_search_spec.name)
else:
raise ValueError(f"Unable to process module spec: {module_search_spec}")
for n in namespaces_to_search: # type: ignore
_module_spec = None
# noinspection PyBroadException
try:
_module_spec = find_spec(n) # type: ignore
except Exception:
pass
if _module_spec:
config = _get_model_config(_module_spec)
if config:
configs[n] = config # type: ignore
for search_spec in self.module_search_specs:
_search_recursively_and_store(search_spec)
if len(configs) == 0:
raise ValueError(
f"Model name {model_name} was not found in search namespaces: "
f"{[s.name for s in self.module_search_specs]}.")
elif len(configs) > 1:
raise ValueError(
f"Multiple instances of model name {model_name} were found in namespaces: {configs.keys()}.")
else:
return list(configs.values())[0]