# Source code for InnerEye.ML.common
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations
import abc
from datetime import datetime
from enum import Enum, unique
from pathlib import Path
from typing import Any, Dict, List
DATASET_CSV_FILE_NAME = "dataset.csv"
CHECKPOINT_SUFFIX = ".ckpt"

# File name stem for the legacy "recovery" checkpoint behaviour, which kept the most recent N checkpoints.
LEGACY_RECOVERY_CHECKPOINT_FILE_NAME = "recovery"

# File name stem for the new recovery checkpoint behaviour: one fixed checkpoint, written every N epochs.
# Lightning does not overwrite files in place, so it alternates between the two files "autosave.ckpt" and
# "autosave-v1.ckpt". Both possible names are listed in AUTOSAVE_CHECKPOINT_CANDIDATES.
AUTOSAVE_CHECKPOINT_FILE_NAME = "autosave"
AUTOSAVE_CHECKPOINT_CANDIDATES = [
    f"{AUTOSAVE_CHECKPOINT_FILE_NAME}{version_tag}{CHECKPOINT_SUFFIX}"
    for version_tag in ("", "-v1")
]

# This name must match a file name defined in pytorch_lightning.ModelCheckpoint. It is duplicated here
# because we don't want to import pytorch_lightning in this module.
LAST_CHECKPOINT_FILE_NAME = "last"
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = f"{LAST_CHECKPOINT_FILE_NAME}{CHECKPOINT_SUFFIX}"

# Names of folders and files that are shared across the codebase.
FINAL_MODEL_FOLDER = "final_model"
FINAL_ENSEMBLE_MODEL_FOLDER = "final_ensemble_model"
CHECKPOINT_FOLDER = "checkpoints"
VISUALIZATION_FOLDER = "visualizations"
EXTRA_RUN_SUBFOLDER = "extra_run_id"
ARGS_TXT = "args.txt"
@unique
class ModelExecutionMode(Enum):
    """
    Model execution mode: training, testing, or validation.

    Each member's value is the capitalized display string for that mode. The `unique` decorator
    guarantees that no two members share the same value.
    """
    TRAIN = "Train"
    TEST = "Test"
    VAL = "Val"
# Maps each model execution mode to the name of the dataset CSV file associated with that mode.
STORED_CSV_FILE_NAMES: Dict[ModelExecutionMode, str] = {
    ModelExecutionMode.TRAIN: "train_dataset.csv",
    ModelExecutionMode.TEST: "test_dataset.csv",
    ModelExecutionMode.VAL: "val_dataset.csv",
}
class OneHotEncoderBase(abc.ABC):
    """
    Abstract class for a one hot encoder object.

    Concrete encoders must implement all three abstract methods; the class cannot be instantiated
    until they are overridden.
    """

    @abc.abstractmethod
    def encode(self, x: Dict[str, List[str]]) -> Any:
        """
        Encode dict mapping features to values and returns encoded vector.

        :param x: A dictionary that maps a feature name to the list of values for that feature.
        :return: The encoded vector. Its concrete type is chosen by the implementing subclass.
        """
        raise NotImplementedError("encode must be implemented by sub classes")

    @abc.abstractmethod
    def get_supported_dataset_column_names(self) -> List[str]:
        """
        Gets the names of the columns that this encoder supports.

        :return: The list of supported dataset column names.
        """
        # The message names this method itself (it previously referred to a non-existent "get_columns").
        raise NotImplementedError("get_supported_dataset_column_names must be implemented by sub classes")

    @abc.abstractmethod
    def get_feature_length(self, column: str) -> int:
        """
        Gets the expected feature lengths for one hot encoded features using this encoder.

        :param column: The name of the column to get the feature length for.
        :return: The length of the one hot encoded feature for the given column.
        """
        raise NotImplementedError("get_feature_length must be implemented by sub classes")
def create_unique_timestamp_id() -> str:
    """
    Creates a unique string using the current time in UTC, up to seconds precision, with characters that
    are suitable for use in filenames. For example, on 31 Dec 2019 at 11:59:59pm UTC, the result would be
    2019-12-31T235959Z.

    Note that the result has seconds precision only: two calls within the same second return the same string.

    :return: The current UTC time, formatted as YYYY-MM-DDTHHMMSSZ.
    """
    # Imported locally so that this function does not rely on a change to the module-level imports.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. A timezone-aware "now" in UTC has the same
    # date and time fields, and hence produces the same formatted string.
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H%M%SZ")
def get_best_checkpoint_path(path: Path) -> Path:
    """
    Given a path to a checkpoint folder, returns the path of the best checkpoint file in that folder.
    The file name is LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, that is, the "last" checkpoint is used as
    the best checkpoint. The function only builds the path; it does not check that the file exists.

    :param path: The path to the checkpoint folder.
    :return: The full path of the checkpoint file inside the given folder.
    """
    return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX