# Source code for InnerEye.ML.configs.segmentation.BasicModel2EpochsMoreData
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pandas as pd
from InnerEye.ML.configs.segmentation.BasicModel2Epochs import BasicModel2Epochs
from InnerEye.ML.utils.split_dataset import DatasetSplits
class BasicModel2EpochsMoreData(BasicModel2Epochs):
    """
    A clone of the basic PR build model, that has more training data, to avoid PyTorch throwing failures
    because each rank does not have enough data to train on.
    """

    def __init__(self) -> None:
        # No extra configuration: everything is inherited from BasicModel2Epochs.
        super().__init__()

    def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
        """
        Create the train/test/validation splits from a fixed assignment of subject IDs:
        '0'-'3' for training, '4'-'7' for testing, and '8'-'9' for validation.

        :param dataset_df: The dataset to split. It is passed unchanged to
            DatasetSplits.from_subject_ids.
            NOTE(review): presumably it must contain all ten subject IDs listed below - confirm
            against DatasetSplits.from_subject_ids.
        :return: The DatasetSplits object built by DatasetSplits.from_subject_ids.
        """
        # Subject IDs are given as strings, exactly as the splitting helper receives them.
        return DatasetSplits.from_subject_ids(
            df=dataset_df,
            train_ids=['0', '1', '2', '3'],
            test_ids=['4', '5', '6', '7'],
            val_ids=['8', '9']
        )