image_to_image.scheduler.warm_up
Module to define a WarmUp Scheduler which can have an after scheduler.
Class:
- WarmUpScheduler
By Tobia Ippolito
1""" 2Module to define a WarmUp Scheduler which can have a adter scheduler. 3 4Class: 5- WarmUpScheduler 6 7By Tobia Ippolito 8""" 9# --------------------------- 10# > Imports < 11# --------------------------- 12import torch 13 14 15 16# --------------------------- 17# > Scheduler < 18# --------------------------- 19class WarmUpScheduler(object): 20 """ 21 Implements a learning rate scheduler with an initial warm-up phase. 22 23 After the warm-up phase, an optional 'after scheduler' can take over 24 to continue adjusting the learning rate according to another schedule. 25 """ 26 def __init__(self, start_lr, end_lr, optimizer, scheduler=None, step_duration=1000): 27 """ 28 Init WarmUp Scheduler. 29 30 Parameter: 31 - start_lr (float): 32 Initial learning rate at the start of warm-up. 33 - end_lr (float): 34 Final learning rate at the end of warm-up. 35 - optimizer (torch.optim.Optimizer): 36 Optimizer whose learning rate will be updated. 37 - scheduler (torch.optim.lr_scheduler._LRScheduler, optional): 38 Scheduler to apply after warm-up. 39 - step_duration (int): 40 Number of steps over which to increase the learning rate. 41 """ 42 self.start_lr = start_lr 43 self.end_lr = end_lr 44 self.current_lr = start_lr 45 self.step_duration = step_duration 46 self.optimizer = optimizer 47 self.scheduler = scheduler 48 self.current_step = 0 49 50 self.lrs = torch.linspace(start_lr, end_lr, step_duration) 51 52 # set initial LR 53 if self.optimizer: 54 for param_group in self.optimizer.param_groups: 55 param_group['lr'] = start_lr 56 else: 57 print("[WARNING] No optomizer was given to the WarmUp Scheduler!") 58 59 60 def step(self): 61 """ 62 Performs a single step in the scheduler. 63 During warm-up, linearly increases the learning rate. 64 After warm-up, delegates to the optional after-scheduler. 65 """ 66 if self.current_step < self.step_duration: 67 self.current_lr = float(self.lrs[self.current_step]) 68 if self.optimizer: 69 for param_group in self.optimizer.param_groups: 70 param_group['lr'] = self.current_lr 71 else: 72 if self.scheduler is not None: 73 self.scheduler.step() 74 75 self.current_step += 1 76 77 78 def get_last_lr(self): 79 """ 80 Returns the most recently applied learning rate as a list. 81 """ 82 if self.current_step < self.step_duration: 83 return [float(self.lrs[self.current_step-1])] 84 elif self.scheduler: 85 return self.scheduler.get_last_lr() 86 else: 87 return [self.end_lr] 88 89 90 def state_dict(self): 91 """ 92 Returns a dictionary with the current step and after-scheduler state. 93 """ 94 return { 95 "current_step": self.current_step, 96 "after_scheduler": self.scheduler.state_dict() if self.scheduler else None 97 } 98 99 100 def load_state_dict(self, state): 101 """ 102 Loads the scheduler state from a given dictionary. 103 104 Parameter: 105 - state (dict): 106 Dictionary containing 'current_step' and optional 'after_scheduler' state. 107 """ 108 self.current_step = state["current_step"] 109 if self.scheduler and state["after_scheduler"]: 110 self.scheduler.load_state_dict(state["after_scheduler"]) 111 112 113 114 115
class WarmUpScheduler:
Implements a learning rate scheduler with an initial warm-up phase.
After the warm-up phase, an optional 'after scheduler' can take over to continue adjusting the learning rate according to another schedule.
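Since the warm-up values are produced with torch.linspace, warm-up step i (for i = 0, ..., step_duration - 1) applies

    lr_i = start_lr + i * (end_lr - start_lr) / (step_duration - 1)

so the first step uses exactly start_lr, the last warm-up step reaches exactly end_lr, and every step() after that is forwarded to the after scheduler, if one was given.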
WarmUpScheduler(start_lr, end_lr, optimizer, scheduler=None, step_duration=1000)
Initializes the WarmUp Scheduler; a short construction sketch follows the parameter list.
Parameter:
- start_lr (float): Initial learning rate at the start of warm-up.
- end_lr (float): Final learning rate at the end of warm-up.
- optimizer (torch.optim.Optimizer): Optimizer whose learning rate will be updated.
- scheduler (torch.optim.lr_scheduler._LRScheduler, optional): Scheduler to apply after warm-up.
- step_duration (int): Number of steps over which to increase the learning rate.
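A minimal construction sketch, assuming an SGD optimizer and a CosineAnnealingLR after scheduler; the model, learning rates, and step counts are placeholder choices, not values prescribed by the module:

import torch
from image_to_image.scheduler.warm_up import WarmUpScheduler

model = torch.nn.Linear(16, 16)                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# optional after scheduler; it takes over once the warm-up is finished
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10_000)

scheduler = WarmUpScheduler(
    start_lr=1e-6,          # learning rate applied at construction time
    end_lr=1e-2,            # learning rate reached at the end of warm-up
    optimizer=optimizer,
    scheduler=cosine,       # pass None for warm-up only
    step_duration=500,      # number of warm-up steps
)

print(optimizer.param_groups[0]["lr"])   # 1e-06: start_lr is applied immediately

The warm-up scheduler writes the learning rate directly into the optimizer's param groups, so no further wiring is needed.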
def step(self):
Performs a single step in the scheduler. During warm-up, linearly increases the learning rate. After warm-up, delegates to the optional after-scheduler.
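A sketch of where step() sits in a training loop, continuing the hypothetical setup from the construction example; the scheduler is stepped once per optimizer update:

for batch_idx in range(2_000):                    # dummy loop; real code iterates over a DataLoader
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).sum()        # placeholder forward/backward pass
    loss.backward()
    optimizer.step()

    # first 500 calls: linear warm-up; afterwards: cosine.step() is called instead
    scheduler.step()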
def get_last_lr(self):
Returns the most recently applied learning rate as a list.
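Because the return value mirrors the list format of PyTorch's built-in schedulers, logging code can treat both the same way; a small sketch that could sit inside the training loop above (the log interval is an arbitrary choice):

    if batch_idx % 100 == 0:
        print(f"step {batch_idx}: lr = {scheduler.get_last_lr()[0]:.2e}")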
def state_dict(self):
Returns a dictionary with the current step and after-scheduler state.
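Because the returned dictionary contains only plain Python data plus the after scheduler's own state dict, it can be stored in a checkpoint next to the model and optimizer; the file name and checkpoint layout below are arbitrary:

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),   # {"current_step": ..., "after_scheduler": ...}
}
torch.save(checkpoint, "checkpoint.pt")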
def load_state_dict(self, state):
Loads the scheduler state from a given dictionary.
Parameter:
- state (dict): Dictionary containing 'current_step' and optional 'after_scheduler' state.
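To resume training, rebuild the scheduler with the same constructor arguments and restore its state from the checkpoint written above; the optimizer's own load_state_dict brings back the currently applied learning rate, while the warm-up scheduler only restores its step counter and the after scheduler:

checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])     # also restores the currently applied lr
scheduler.load_state_dict(checkpoint["scheduler"])     # restores current_step and the after scheduler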