image_to_image.scheduler.warm_up

Module to define a WarmUp Scheduler which can hand over to an optional after-scheduler.

Class:

  • WarmUpScheduler

By Tobia Ippolito

  1"""
  2Module to define a WarmUp Scheduler which can have a adter scheduler.
  3
  4Class:
  5- WarmUpScheduler
  6
  7By Tobia Ippolito
  8"""
  9# ---------------------------
 10#        > Imports <
 11# ---------------------------
 12import torch
 13
 14
 15
 16# ---------------------------
 17#       > Scheduler <
 18# ---------------------------
 19class WarmUpScheduler(object):
 20    """
 21    Implements a learning rate scheduler with an initial warm-up phase.
 22
 23    After the warm-up phase, an optional 'after scheduler' can take over
 24    to continue adjusting the learning rate according to another schedule.
 25    """
 26    def __init__(self, start_lr, end_lr, optimizer, scheduler=None, step_duration=1000):
 27        """
 28        Init WarmUp Scheduler.
 29
 30        Parameter:
 31        - start_lr (float): 
 32            Initial learning rate at the start of warm-up.
 33        - end_lr (float): 
 34            Final learning rate at the end of warm-up.
 35        - optimizer (torch.optim.Optimizer): 
 36            Optimizer whose learning rate will be updated.
 37        - scheduler (torch.optim.lr_scheduler._LRScheduler, optional): 
 38            Scheduler to apply after warm-up.
 39        - step_duration (int): 
 40            Number of steps over which to increase the learning rate.
 41        """
 42        self.start_lr = start_lr
 43        self.end_lr = end_lr
 44        self.current_lr = start_lr
 45        self.step_duration = step_duration
 46        self.optimizer = optimizer
 47        self.scheduler = scheduler
 48        self.current_step = 0
 49
 50        self.lrs = torch.linspace(start_lr, end_lr, step_duration)
 51
 52        # set initial LR
 53        if self.optimizer:
 54            for param_group in self.optimizer.param_groups:
 55                param_group['lr'] = start_lr
 56        else:
 57            print("[WARNING] No optomizer was given to the WarmUp Scheduler!")
 58
 59
 60    def step(self):
 61        """
 62        Performs a single step in the scheduler. 
 63        During warm-up, linearly increases the learning rate. 
 64        After warm-up, delegates to the optional after-scheduler.
 65        """
 66        if self.current_step < self.step_duration:
 67            self.current_lr = float(self.lrs[self.current_step])
 68            if self.optimizer:
 69                for param_group in self.optimizer.param_groups:
 70                    param_group['lr'] = self.current_lr
 71        else:
 72            if self.scheduler is not None:
 73                self.scheduler.step()
 74
 75        self.current_step += 1
 76
 77
 78    def get_last_lr(self):
 79        """
 80        Returns the most recently applied learning rate as a list.
 81        """
 82        if self.current_step < self.step_duration:
 83            return [float(self.lrs[self.current_step-1])]
 84        elif self.scheduler:
 85            return self.scheduler.get_last_lr()
 86        else:
 87            return [self.end_lr]
 88
 89
 90    def state_dict(self):
 91        """
 92        Returns a dictionary with the current step and after-scheduler state.
 93        """
 94        return {
 95            "current_step": self.current_step,
 96            "after_scheduler": self.scheduler.state_dict() if self.scheduler else None
 97        }
 98
 99
100    def load_state_dict(self, state):
101        """
102        Loads the scheduler state from a given dictionary.
103
104        Parameter:
105        - state (dict): 
106            Dictionary containing 'current_step' and optional 'after_scheduler' state.
107        """
108        self.current_step = state["current_step"]
109        if self.scheduler and state["after_scheduler"]:
110            self.scheduler.load_state_dict(state["after_scheduler"])
111    
112
113 
114
115    
class WarmUpScheduler:
 20class WarmUpScheduler(object):
 21    """
 22    Implements a learning rate scheduler with an initial warm-up phase.
 23
 24    After the warm-up phase, an optional 'after scheduler' can take over
 25    to continue adjusting the learning rate according to another schedule.
 26    """
 27    def __init__(self, start_lr, end_lr, optimizer, scheduler=None, step_duration=1000):
 28        """
 29        Init WarmUp Scheduler.
 30
 31        Parameter:
 32        - start_lr (float): 
 33            Initial learning rate at the start of warm-up.
 34        - end_lr (float): 
 35            Final learning rate at the end of warm-up.
 36        - optimizer (torch.optim.Optimizer): 
 37            Optimizer whose learning rate will be updated.
 38        - scheduler (torch.optim.lr_scheduler._LRScheduler, optional): 
 39            Scheduler to apply after warm-up.
 40        - step_duration (int): 
 41            Number of steps over which to increase the learning rate.
 42        """
 43        self.start_lr = start_lr
 44        self.end_lr = end_lr
 45        self.current_lr = start_lr
 46        self.step_duration = step_duration
 47        self.optimizer = optimizer
 48        self.scheduler = scheduler
 49        self.current_step = 0
 50
 51        self.lrs = torch.linspace(start_lr, end_lr, step_duration)
 52
 53        # set initial LR
 54        if self.optimizer:
 55            for param_group in self.optimizer.param_groups:
 56                param_group['lr'] = start_lr
 57        else:
 58            print("[WARNING] No optomizer was given to the WarmUp Scheduler!")
 59
 60
 61    def step(self):
 62        """
 63        Performs a single step in the scheduler. 
 64        During warm-up, linearly increases the learning rate. 
 65        After warm-up, delegates to the optional after-scheduler.
 66        """
 67        if self.current_step < self.step_duration:
 68            self.current_lr = float(self.lrs[self.current_step])
 69            if self.optimizer:
 70                for param_group in self.optimizer.param_groups:
 71                    param_group['lr'] = self.current_lr
 72        else:
 73            if self.scheduler is not None:
 74                self.scheduler.step()
 75
 76        self.current_step += 1
 77
 78
 79    def get_last_lr(self):
 80        """
 81        Returns the most recently applied learning rate as a list.
 82        """
 83        if self.current_step < self.step_duration:
 84            return [float(self.lrs[self.current_step-1])]
 85        elif self.scheduler:
 86            return self.scheduler.get_last_lr()
 87        else:
 88            return [self.end_lr]
 89
 90
 91    def state_dict(self):
 92        """
 93        Returns a dictionary with the current step and after-scheduler state.
 94        """
 95        return {
 96            "current_step": self.current_step,
 97            "after_scheduler": self.scheduler.state_dict() if self.scheduler else None
 98        }
 99
100
101    def load_state_dict(self, state):
102        """
103        Loads the scheduler state from a given dictionary.
104
105        Parameter:
106        - state (dict): 
107            Dictionary containing 'current_step' and optional 'after_scheduler' state.
108        """
109        self.current_step = state["current_step"]
110        if self.scheduler and state["after_scheduler"]:
111            self.scheduler.load_state_dict(state["after_scheduler"])

Implements a learning rate scheduler with an initial warm-up phase.

After the warm-up phase, an optional 'after scheduler' can take over to continue adjusting the learning rate according to another schedule.

WarmUpScheduler(start_lr, end_lr, optimizer, scheduler=None, step_duration=1000)
27    def __init__(self, start_lr, end_lr, optimizer, scheduler=None, step_duration=1000):
28        """
29        Init WarmUp Scheduler.
30
31        Parameter:
32        - start_lr (float): 
33            Initial learning rate at the start of warm-up.
34        - end_lr (float): 
35            Final learning rate at the end of warm-up.
36        - optimizer (torch.optim.Optimizer): 
37            Optimizer whose learning rate will be updated.
38        - scheduler (torch.optim.lr_scheduler._LRScheduler, optional): 
39            Scheduler to apply after warm-up.
40        - step_duration (int): 
41            Number of steps over which to increase the learning rate.
42        """
43        self.start_lr = start_lr
44        self.end_lr = end_lr
45        self.current_lr = start_lr
46        self.step_duration = step_duration
47        self.optimizer = optimizer
48        self.scheduler = scheduler
49        self.current_step = 0
50
51        self.lrs = torch.linspace(start_lr, end_lr, step_duration)
52
53        # set initial LR
54        if self.optimizer:
55            for param_group in self.optimizer.param_groups:
56                param_group['lr'] = start_lr
57        else:
58            print("[WARNING] No optomizer was given to the WarmUp Scheduler!")

Init WarmUp Scheduler.

Parameter:

  • start_lr (float): Initial learning rate at the start of warm-up.
  • end_lr (float): Final learning rate at the end of warm-up.
  • optimizer (torch.optim.Optimizer): Optimizer whose learning rate will be updated.
  • scheduler (torch.optim.lr_scheduler._LRScheduler, optional): Scheduler to apply after warm-up.
  • step_duration (int): Number of steps over which to increase the learning rate.
start_lr
end_lr
current_lr
step_duration
optimizer
scheduler
current_step
lrs
def step(self):
61    def step(self):
62        """
63        Performs a single step in the scheduler. 
64        During warm-up, linearly increases the learning rate. 
65        After warm-up, delegates to the optional after-scheduler.
66        """
67        if self.current_step < self.step_duration:
68            self.current_lr = float(self.lrs[self.current_step])
69            if self.optimizer:
70                for param_group in self.optimizer.param_groups:
71                    param_group['lr'] = self.current_lr
72        else:
73            if self.scheduler is not None:
74                self.scheduler.step()
75
76        self.current_step += 1

Performs a single step in the scheduler. During warm-up, linearly increases the learning rate. After warm-up, delegates to the optional after-scheduler.

def get_last_lr(self):
79    def get_last_lr(self):
80        """
81        Returns the most recently applied learning rate as a list.
82        """
83        if self.current_step < self.step_duration:
84            return [float(self.lrs[self.current_step-1])]
85        elif self.scheduler:
86            return self.scheduler.get_last_lr()
87        else:
88            return [self.end_lr]

Returns the most recently applied learning rate as a list.

def state_dict(self):
91    def state_dict(self):
92        """
93        Returns a dictionary with the current step and after-scheduler state.
94        """
95        return {
96            "current_step": self.current_step,
97            "after_scheduler": self.scheduler.state_dict() if self.scheduler else None
98        }

Returns a dictionary with the current step and after-scheduler state.

def load_state_dict(self, state):
101    def load_state_dict(self, state):
102        """
103        Loads the scheduler state from a given dictionary.
104
105        Parameter:
106        - state (dict): 
107            Dictionary containing 'current_step' and optional 'after_scheduler' state.
108        """
109        self.current_step = state["current_step"]
110        if self.scheduler and state["after_scheduler"]:
111            self.scheduler.load_state_dict(state["after_scheduler"])

Loads the scheduler state from a given dictionary.

Parameter:

  • state (dict): Dictionary containing 'current_step' and optional 'after_scheduler' state.