Source code for InnerEye.Common.generic_parsing

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from __future__ import annotations

import argparse
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import param
from param.parameterized import Parameter

from InnerEye.Common.common_util import is_private_field_name
from InnerEye.Common.type_annotations import T

# param would otherwise append a description of every parameter of a class to that class's docstring,
# which clutters the documentation that sphinx generates. Switch both behaviours off.
param.parameterized.docstring_signature = param.parameterized.docstring_describe_params = False


class ListOrDictParam(param.Parameter):
    """
    Wrapper class to allow either a List or Dict inside of a Parameterized object.
    """

    def _validate(self, val: Any) -> None:
        """
        Accept a list or a dict (or None, when this parameter allows None), then defer to the base validation.

        :param val: The value being assigned to the parameter.
        :raises ValueError: If val is neither a list nor a dict (and is not an allowed None).
        """
        none_is_permitted = self.allow_None and val is None
        if not none_is_permitted and not isinstance(val, (list, dict)):
            raise ValueError(f"{val} must be an instance of List or Dict, found {type(val)}")
        super()._validate(val)
class StringOrStringList(param.Parameter):
    """
    Wrapper class to allow either a string or a list of strings. Internally represented always as a list.
    """

    def _validate(self, val: Any) -> None:
        """
        Check that the value is a string or a list whose elements are all strings.
        None is accepted only when the parameter allows None, consistent with ListOrDictParam.

        :param val: The value being assigned to the parameter.
        :raises ValueError: If val has any other form.
        """
        # Honour allow_None instead of always rejecting None.
        if self.allow_None and val is None:
            return
        if isinstance(val, str):
            return
        if isinstance(val, list) and all(isinstance(v, str) for v in val):
            return
        raise ValueError(f"{val} must be a string or a list of strings")

    def set_hook(self, obj: Any, val: Any) -> Any:
        """
        Modifies the value before calling the setter. Here, we are converting all strings to lists of strings.

        :param obj: The object the parameter is being set on (unused).
        :param val: The value being assigned.
        :return: [val] if val is a string, otherwise val unchanged.
        """
        if isinstance(val, str):
            return [val]
        return val
class PathOrPathList(param.Parameter):
    """
    Wrapper class to allow either a Path or a list of Paths. Internally represented always as a list.
    """

    def _validate(self, val: Any) -> None:
        """
        Check that the value is a Path or a list whose elements are all Paths.
        None is accepted only when the parameter allows None, consistent with ListOrDictParam.

        :param val: The value being assigned to the parameter.
        :raises ValueError: If val has any other form.
        """
        # Honour allow_None instead of always rejecting None.
        if self.allow_None and val is None:
            return
        if isinstance(val, Path):
            return
        if isinstance(val, list) and all(isinstance(v, Path) for v in val):
            return
        raise ValueError(f"{val} must be a Path object or a list of paths")

    def set_hook(self, obj: Any, val: Any) -> Any:
        """
        Modifies the value before calling the setter. Here, we are converting simple path to lists of path.

        :param obj: The object the parameter is being set on (unused).
        :param val: The value being assigned.
        :return: [val] if val is a Path, otherwise val unchanged.
        """
        if isinstance(val, Path):
            return [val]
        return val
class IntTuple(param.NumericTuple):
    """
    Parameter class that must always have integer values
    """

    def _validate(self, val: Any) -> None:
        """
        Run the NumericTuple validation first, then require every element of the tuple to be an int.

        :param val: The tuple being assigned, or None.
        :raises ValueError: For the first element (by index) that is not an int.
        """
        super()._validate(val)
        if val is None:
            return
        first_bad = next(((idx, elem) for idx, elem in enumerate(val) if not isinstance(elem, int)), None)
        if first_bad is not None:
            idx, elem = first_bad
            raise ValueError("{}: tuple element at index {} with value {} in {} is not an integer"
                             .format(self.name, idx, elem, val))
class GenericConfig(param.Parameterized):
    """
    Base class for all configuration classes provides helper functionality to create argparser.
    """

    def __init__(self, should_validate: bool = True, throw_if_unknown_param: bool = False, **params: Any):
        """
        Instantiates the config class, ignoring parameters that are not overridable.

        :param should_validate: If True, the validate() method is called directly after init.
        :param throw_if_unknown_param: If True, raise an error if the provided "params" contains any key that does
            not correspond to an attribute of the class. If False, such keys are silently dropped.
        :param params: Parameters to set.
        :raises ValueError: If a key in params names a parameter that is not overridable (readonly, constant,
            private or callable), or, when throw_if_unknown_param is True, a key that names no parameter.
        """
        # Fetch the parameter definitions once instead of once per supplied key.
        all_params = self.params()
        legal_params = self.get_overridable_parameters()
        # Keys that name an existing parameter which must not be set from outside.
        illegal = [k for k in params if k in all_params and k not in legal_params]
        if illegal:
            raise ValueError(f"The following parameters cannot be overriden as they are either "
                             f"readonly, constant, or private members : {illegal}")
        if throw_if_unknown_param:
            # check if parameters not defined by the config class are passed in
            unknown = [k for k in params if k not in all_params]
            if unknown:
                raise ValueError(f"The following parameters do not exist: {unknown}")
        # set known arguments; anything that is not overridable (including unknown keys) is dropped here
        super().__init__(**{k: v for k, v in params.items() if k in legal_params})
        if should_validate:
            self.validate()

    def validate(self) -> None:
        """
        Validation method called directly after init to be overridden by children if required
        """
        pass

    def add_and_validate(self, kwargs: Dict[str, Any], validate: bool = True) -> None:
        """
        Add further parameters and, if validate is True, validate. We first try set_param, but that fails
        when the parameter has a setter.

        :param kwargs: Mapping from attribute name to the value to set.
        :param validate: If True, call validate() after all values have been set.
        """
        for key, value in kwargs.items():
            try:
                self.set_param(key, value)
            except ValueError:
                # Fall back to plain attribute assignment when set_param rejects the key.
                setattr(self, key, value)
        if validate:
            self.validate()

    @classmethod
    def create_argparser(cls: Type[GenericConfig]) -> argparse.ArgumentParser:
        """
        Creates an ArgumentParser with all fields of the given argparser that are overridable.

        :return: ArgumentParser
        """
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        cls.add_args(parser)
        return parser

    @classmethod
    def add_args(cls: Type[GenericConfig], parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Adds all overridable fields of the current class to the given argparser.
        Fields that are marked as readonly, constant or private are ignored.

        :param parser: Parser to add properties to.
        :return: The same parser, with one argument added per overridable parameter.
        """

        def parse_bool(x: str) -> bool:
            """
            Parse a string as a bool. Supported values are case insensitive and one of:
            'on', 't', 'true', 'y', 'yes', '1' for True
            'off', 'f', 'false', 'n', 'no', '0' for False.

            :param x: string to test.
            :return: Bool value if string valid, otherwise a ValueError is raised.
            """
            sx = str(x).lower()
            if sx in ('on', 't', 'true', 'y', 'yes', '1'):
                return True
            if sx in ('off', 'f', 'false', 'n', 'no', '0'):
                return False
            raise ValueError(f"Invalid value {x}, please supply one of True, true, false or False.")

        def _get_basic_type(_p: param.Parameter) -> Union[type, Callable]:
            """
            Given a parameter, get its basic Python type, e.g.: param.Boolean -> bool.
            Throw exception if it is not supported.

            :param _p: parameter to get type and nargs for.
            :return: Type
            """
            if isinstance(_p, param.Boolean):
                p_type: Callable = parse_bool
            elif isinstance(_p, param.Integer):
                # An empty string on the command line keeps the declared default.
                p_type = lambda x: _p.default if x == "" else int(x)
            elif isinstance(_p, param.Number):
                p_type = lambda x: _p.default if x == "" else float(x)
            elif isinstance(_p, param.String):
                p_type = str
            elif isinstance(_p, param.List):
                # class_ may be None when a param.List is declared without an item class; calling None would
                # make every value fail to parse, so keep the comma-separated items as strings in that case.
                item_type = _p.class_ if _p.class_ is not None else str
                p_type = lambda x: [item_type(item) for item in x.split(',')]
            elif isinstance(_p, param.NumericTuple):
                float_or_int = lambda y: int(y) if isinstance(_p, IntTuple) else float(y)
                p_type = lambda x: tuple([float_or_int(item) for item in x.split(',')])
            elif isinstance(_p, param.ClassSelector):
                p_type = _p.class_
            elif isinstance(_p, ListOrDictParam):
                def list_or_dict(x: str) -> Union[Dict, List]:
                    import json
                    # JSON syntax for explicit lists/dicts, otherwise a comma-separated list of strings.
                    if x.startswith("{") or x.startswith('['):
                        res = json.loads(x)
                    else:
                        res = [str(item) for item in x.split(',')]
                    if isinstance(res, Dict):
                        return res
                    elif isinstance(res, List):
                        return res
                    else:
                        raise ValueError(f"Parameter of type {_p} should resolve to List or Dict")

                p_type = list_or_dict
            else:
                raise TypeError("Parameter of type: {} is not supported".format(_p))

            return p_type

        def add_boolean_argument(parser: argparse.ArgumentParser, k: str, p: Parameter) -> None:
            """
            Add a boolean argument.
            If the parameter default is False then allow --flag (to set it True) and --flag=Bool as usual.
            If the parameter default is True then allow --no-flag (to set it to False) and --flag=Bool as usual.

            :param parser: parser to add a boolean argument to.
            :param k: argument name.
            :param p: boolean parameter.
            """
            if not p.default:
                # If the parameter default is False then use nargs="?" (argparse.OPTIONAL).
                # This means that the argument is optional.
                # If it is not supplied, i.e. in the --flag mode, use the "const" value, i.e. True.
                # Otherwise, i.e. in the --flag=value mode, try to parse the argument as a bool.
                parser.add_argument("--" + k, help=p.doc, type=parse_bool, default=False,
                                    nargs=argparse.OPTIONAL, const=True)
            else:
                # If the parameter default is True then create an exclusive group of arguments.
                # Either --flag=value as usual
                # Or --no-flag to store False in the parameter k.
                group = parser.add_mutually_exclusive_group(required=False)
                group.add_argument("--" + k, help=p.doc, type=parse_bool)
                group.add_argument('--no-' + k, dest=k, action='store_false')
                parser.set_defaults(**{k: p.default})

        for k, p in cls.get_overridable_parameters().items():
            # param.Booleans need to be handled separately, they are more complicated because they have
            # an optional argument.
            if isinstance(p, param.Boolean):
                add_boolean_argument(parser, k, p)
            else:
                parser.add_argument("--" + k, help=p.doc, type=_get_basic_type(p), default=p.default)

        return parser

    @classmethod
    def parse_args(cls: Type[T], args: Optional[List[str]] = None) -> T:
        """
        Creates an argparser based on the params class and parses stdin args (or the args provided)

        :param args: Command line arguments to parse; None means sys.argv (argparse default).
        :return: An instance of cls built from the parsed arguments.
        """
        return cls(**vars(cls.create_argparser().parse_args(args)))  # type: ignore

    @classmethod
    def get_overridable_parameters(cls: Type[GenericConfig]) -> Dict[str, param.Parameter]:
        """
        Get properties that are not constant, readonly or private (eg: prefixed with an underscore).

        :return: A dictionary of parameter names and their definitions.
        """
        return dict((k, v) for k, v in cls.params().items() if cls.reason_not_overridable(v) is None)

    @staticmethod
    def reason_not_overridable(value: param.Parameter) -> Optional[str]:
        """
        :param value: a parameter value
        :return: None if the parameter is overridable; otherwise a one-word string explaining why not.
        """
        if value.readonly:
            return "readonly"
        elif value.constant:
            return "constant"
        elif is_private_field_name(value.name):
            return "private"
        elif isinstance(value, param.Callable):
            return "callable"
        return None

    def apply_overrides(self, values: Optional[Dict[str, Any]], should_validate: bool = True,
                        keys_to_ignore: Optional[Set[str]] = None) -> Dict[str, Any]:
        """
        Applies the provided `values` overrides to the config.
        Only properties that are marked as overridable are actually overwritten.

        :param values: A dictionary mapping from field name to value. May be None, in which case nothing is applied.
        :param should_validate: If true, run the .validate() method after applying overrides.
        :param keys_to_ignore: keys to ignore in reporting failed overrides. If None, do not report.
        :return: A dictionary with all the fields that were modified.
        """

        def _apply(_overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
            applied: Dict[str, Any] = {}
            if _overrides is not None:
                overridable_parameters = self.get_overridable_parameters().keys()
                for k, v in _overrides.items():
                    if k in overridable_parameters:
                        applied[k] = v
                        setattr(self, k, v)

            return applied

        actual_overrides = _apply(values)
        # report_on_overrides iterates over values, so only report when overrides were actually supplied;
        # previously values=None with a non-None keys_to_ignore crashed on None.items().
        if keys_to_ignore is not None and values is not None:
            self.report_on_overrides(values, keys_to_ignore)
        if should_validate:
            self.validate()
        return actual_overrides

    def report_on_overrides(self, values: Dict[str, Any], keys_to_ignore: Set[str]) -> None:
        """
        Logs a warning for every parameter whose value is not as given in "values", other than those
        in keys_to_ignore.

        :param values: override dictionary, parameter names to values
        :param keys_to_ignore: set of dictionary keys not to report on
        :return: None
        """
        all_params = self.params()
        for key, desired in values.items():
            # If this isn't an AzureConfig instance, we don't want to warn on keys intended for it.
            if key in keys_to_ignore:
                continue
            actual = getattr(self, key, None)
            if actual == desired:
                continue
            if key not in all_params:
                reason = "parameter is undefined"
            else:
                val = all_params[key]
                reason = self.reason_not_overridable(val)  # type: ignore
                if reason is None:
                    reason = "for UNKNOWN REASONS"
                else:
                    reason = f"parameter is {reason}"
            # We could raise an error here instead - to be discussed.
            logging.warning(f"Override {key}={desired} failed: {reason} in class {self.__class__.name}")
def create_from_matching_params(from_object: param.Parameterized, cls_: Type[T]) -> T:
    """
    Creates an object of the given target class, and then copies all attributes from the `from_object` to
    the newly created object, if there is a matching attribute. The target class must be a subclass of
    param.Parameterized.

    :param from_object: The object to read attributes from.
    :param cls_: The class to instantiate (with no arguments) for the newly created object.
    :return: An instance of cls_
    :raises ValueError: If the instance created from cls_ is not a param.Parameterized.
    """
    c = cls_()
    if not isinstance(c, param.Parameterized):
        raise ValueError(f"The created object must be a subclass of param.Parameterized, but got {type(c)}")
    # Sentinel so that attributes whose value is legitimately None are still copied.
    missing = object()
    for param_name, p in c.params().items():
        # Constant and readonly parameters cannot be assigned after construction.
        if p.constant or p.readonly:
            continue
        value = getattr(from_object, param_name, missing)
        # Only copy when from_object actually has a matching attribute (as documented); previously a
        # missing attribute raised AttributeError.
        if value is not missing:
            setattr(c, param_name, value)
    return c