Source code for InnerEye.ML.utils.lr_scheduler

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from __future__ import annotations

from typing import Dict, List

from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR, _LRScheduler
from torch.optim.optimizer import Optimizer

from InnerEye.ML.deep_learning_config import LRSchedulerType, LRWarmUpType, OptimizerParams


def get_current_learning_rates(optimizer: Optimizer) -> List[float]:
    """
    Collects the learning rate that is currently set in each parameter group of the optimizer.

    :param optimizer: The optimizer whose parameter groups are inspected.
    :return: One learning rate per parameter group, in the order of optimizer.param_groups.
    """
    learning_rates: List[float] = []
    for param_group in optimizer.param_groups:
        learning_rates.append(param_group['lr'])
    return learning_rates
class LinearWarmUp(_LRScheduler):
    """
    Implements linear warmup up to a given initial learning rate.

    While last_epoch < warmup_epochs, every parameter group of the optimizer gets the learning rate
    final_lr * (last_epoch + 1) / (warmup_epochs + 1). From then on, the learning rate is final_lr.
    """

    def __init__(self, optimizer: Optimizer, warmup_epochs: int, final_lr: float, last_epoch: int = -1):
        """
        :param optimizer: The optimizer whose learning rates are warmed up.
        :param warmup_epochs: Number of warmup epochs. With 0, the learning rate is final_lr right away.
        :param final_lr: The learning rate that is reached at the end of warmup.
        :param last_epoch: Index of the last epoch, passed on to _LRScheduler. -1 starts from scratch.
        :raises ValueError: If warmup_epochs is negative.
        """
        if warmup_epochs < 0:
            raise ValueError("The number of warmup epochs must be >= 0.")
        self.warmup_epochs = warmup_epochs
        self.final_lr = final_lr
        self.last_epoch = last_epoch
        # The base class constructor calls self.step() (and hence get_lr()), so the attributes used by
        # get_lr must be set before this call.
        super().__init__(optimizer, last_epoch)

    def warmup_multiplier(self) -> float:
        """
        Returns the fraction of final_lr that applies at the current epoch (self.last_epoch).

        :return: (last_epoch + 1) / (warmup_epochs + 1) during warmup, 1.0 when warmup is disabled or finished.
        """
        if self.warmup_epochs <= 0 or self.last_epoch >= self.warmup_epochs:
            return 1.0
        return (self.last_epoch + 1) / (self.warmup_epochs + 1)

    def get_lr(self) -> List[float]:  # type: ignore
        """
        Returns the warmed-up learning rate for every parameter group of the optimizer.

        :return: A list with one (identical) learning rate per entry in optimizer.param_groups.
        """
        learning_rate = self.final_lr * self.warmup_multiplier()
        # _LRScheduler.step zips these values with optimizer.param_groups. A single-element list would
        # therefore only update the first group and leave all other groups at their old learning rate.
        return [learning_rate] * len(self.optimizer.param_groups)
class PolynomialLR:
    """
    Polynomial learning rate decay, expressed as a multiplier of the initial learning rate so that get_lr can
    be used as the lr_lambda of torch's LambdaLR.

    With x = min_l_rate / l_rate, the multiplier at a given epoch is
    (1 - x) * (1 - epoch / epochs_after_warmup) ** gamma + x.
    It starts at 1 for epoch 0 and reaches x (the multiplier corresponding to min_l_rate) at
    epoch epochs_after_warmup.
    """

    def __init__(self, gamma: float, l_rate: float, min_l_rate: float, epochs_after_warmup: int) -> None:
        """
        :param gamma: Exponent of the polynomial decay.
        :param l_rate: The initial learning rate that the multiplier is relative to. Must be non-zero.
        :param min_l_rate: The learning rate reached at the end of the decay.
        :param epochs_after_warmup: Number of epochs over which the learning rate decays.
        """
        self.gamma = gamma
        self.l_rate = l_rate
        self.min_l_rate = min_l_rate
        self.epochs_after_warmup = epochs_after_warmup

    def get_lr(self, epoch: int) -> float:
        """
        Returns the learning rate multiplier for the given epoch.

        Epochs beyond epochs_after_warmup are clamped to the end of the schedule. The multiplier therefore
        stays at min_l_rate / l_rate, instead of rising again or becoming a complex number (a negative base
        raised to a fractional gamma). If epochs_after_warmup is not positive, the decay is treated as
        already finished.

        :param epoch: Epoch index.
        :return: Multiplier to apply to the initial learning rate.
        """
        min_ratio = self.min_l_rate / self.l_rate
        if self.epochs_after_warmup <= 0:
            # A zero-length decay phase: avoid dividing by zero and return the end-of-decay multiplier.
            return min_ratio
        progress = min(float(epoch), float(self.epochs_after_warmup)) / self.epochs_after_warmup
        return (1 - min_ratio) * ((1. - progress) ** self.gamma) + min_ratio
class SchedulerWithWarmUp(_LRScheduler):
    """
    LR scheduler that first ramps the learning rate up linearly for a number of warmup epochs, and then
    hands over to one of the standard torch schedulers, selected via the config.
    """

    def __init__(self, args: OptimizerParams, optimizer: Optimizer, num_epochs: int, last_epoch: int = -1):
        """
        :param args: Optimizer config that provides the warmup and scheduler settings.
        :param optimizer: The optimizer whose learning rate is controlled.
        :param num_epochs: Total number of epochs, including the warmup epochs.
        :param last_epoch: Index of the last epoch, passed on to the nested schedulers and to _LRScheduler.
        """
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.num_epochs = num_epochs
        if args.l_rate_warmup == LRWarmUpType.NoWarmUp:
            self.warmup_epochs = 0
        else:
            self.warmup_epochs = args.l_rate_warmup_epochs
        # Order matters: the post-warmup scheduler is built first. The warmup scheduler is built afterwards,
        # so that the learning rate left in the optimizer is the one dictated by the warmup schedule.
        self._scheduler = self.get_scheduler(args)
        self._warmup = LinearWarmUp(optimizer,
                                    warmup_epochs=self.warmup_epochs,
                                    final_lr=args.l_rate,
                                    last_epoch=last_epoch)
        self._last_lr = get_current_learning_rates(optimizer)
        self.min_l_rate = args.min_l_rate
        super().__init__(optimizer, last_epoch)

    def get_scheduler(self, args: OptimizerParams) -> _LRScheduler:
        """
        Builds the scheduler that takes over after warmup, as selected by args.l_rate_scheduler.

        :param args: Optimizer config with the scheduler type and its hyperparameters.
        :return: A torch LR scheduler attached to self.optimizer.
        :raises ValueError: If the scheduler type is not recognized.
        """
        epochs_after_warmup = self.num_epochs - self.warmup_epochs
        scheduler_type = args.l_rate_scheduler
        if scheduler_type == LRSchedulerType.Exponential:
            return ExponentialLR(optimizer=self.optimizer,
                                 gamma=args.l_rate_exponential_gamma,
                                 last_epoch=self.last_epoch)
        if scheduler_type == LRSchedulerType.Step:
            return StepLR(optimizer=self.optimizer,
                          step_size=args.l_rate_step_step_size,
                          gamma=args.l_rate_step_gamma,
                          last_epoch=self.last_epoch)
        if scheduler_type == LRSchedulerType.MultiStep:
            assert args.l_rate_multi_step_milestones is not None
            return MultiStepLR(optimizer=self.optimizer,
                               milestones=args.l_rate_multi_step_milestones,
                               gamma=args.l_rate_multi_step_gamma,
                               last_epoch=self.last_epoch)
        if scheduler_type == LRSchedulerType.Polynomial:
            decay = PolynomialLR(gamma=args.l_rate_polynomial_gamma,
                                 l_rate=args.l_rate,
                                 min_l_rate=args.min_l_rate,
                                 epochs_after_warmup=epochs_after_warmup)
            return LambdaLR(optimizer=self.optimizer, lr_lambda=decay.get_lr, last_epoch=self.last_epoch)
        if scheduler_type == LRSchedulerType.Cosine:
            return CosineAnnealingLR(optimizer=self.optimizer,
                                     T_max=epochs_after_warmup,
                                     eta_min=args.min_l_rate,
                                     last_epoch=self.last_epoch)
        raise ValueError("Unknown learning rate scheduler {}".format(scheduler_type))

    def state_dict(self) -> Dict:
        """
        Returns all attributes of this object as a dictionary.

        The nested schedulers are stored under the keys "_scheduler" and "_warmup", using their own
        state_dict() methods. The optimizer is not included.
        """
        excluded_keys = ("_scheduler", "optimizer", "_warmup")
        snapshot = {name: value for name, value in self.__dict__.items() if name not in excluded_keys}
        snapshot['_scheduler'] = self._scheduler.state_dict()
        snapshot['_warmup'] = self._warmup.state_dict()
        return snapshot

    def load_state_dict(self, state_dict: Dict) -> None:
        """
        Restores this object from a dictionary created by state_dict().

        Top-level entries are copied into the object's attributes. The entries "_scheduler" and "_warmup"
        are passed to the load_state_dict methods of the nested schedulers.
        """
        nested_keys = ("_scheduler", "_warmup")
        self.__dict__.update({name: value for name, value in state_dict.items() if name not in nested_keys})
        self._scheduler.load_state_dict(state_dict["_scheduler"])
        self._warmup.load_state_dict(state_dict["_warmup"])

    def step(self, epoch: int = None) -> None:
        """
        Advances the schedule by one epoch. The warmup scheduler is stepped while self.last_epoch is below
        the warmup length, and the post-warmup scheduler after that.

        :param epoch: Not supported; must be None.
        :raises ValueError: If an epoch argument is given.
        """
        if epoch is not None:
            raise ValueError("Calling scheduler.step with an epoch argument will be deprecated.")
        # _LRScheduler.__init__ calls step() once, while self.last_epoch is still -1. Both nested schedulers
        # have already been stepped inside their own constructors, so they must not be stepped again here.
        if self.last_epoch != -1:
            in_warmup = self.last_epoch < self._warmup.warmup_epochs
            active_scheduler = self._warmup if in_warmup else self._scheduler
            active_scheduler.step()
        self.last_epoch += 1
        self._last_lr = get_current_learning_rates(self.optimizer)

    def get_last_lr(self) -> List[float]:
        """Returns the learning rates of all parameter groups, as recorded after the most recent step."""
        return self._last_lr