image_to_image.utils.model_io

Module handling the input/output of PyTorch models, i.e. saving and loading them.

Functions:

  • save_checkpoint
  • get_single_model
  • get_model
  • load_model_weights
  • load_and_get_model

By Tobia Ippolito

  1"""
  2Module to handle the input/output of PyTorch models.
  3In this case that means loading and saving models.
  4
  5
  6Functions:
  7- save_checkpoint
  8- get_single_model
  9- get_model
 10- load_model_weights
 11- load_and_get_model
 12
 13By Tobia Ippolito
 14"""
 15# ---------------------------
 16#        > Imports <
 17# ---------------------------
 18import os
 19
 20import torch
 21
 22import prime_printer as prime
 23
 24from ..models.pix2pix import Pix2Pix
 25from ..models.resfcn import ResFCN
 26from ..models.residual_design_model import ResidualDesignModel
 27from ..models.transformer import PhysicFormer
 28
 29
 30
 31# ---------------------------
 32#         > Saving <
 33# ---------------------------
 34def save_checkpoint(args, model, optimizer, scheduler, epoch, path='ckpt.pth'):
 35    """
 36    Saves a training checkpoint containing model, optimizer, and scheduler states.
 37
 38    Parameter:
 39    - model (nn.Module):
 40        Model to save.
 41    - optimizer:
 42        Optimizer or list/tuple of optimizers.
 43    - scheduler:
 44        Scheduler or list/tuple of schedulers.
 45    - epoch (int):
 46        Current epoch index.
 47    - path (str):
 48        File path to save checkpoint to ('.pth' extension added automatically).
 49
 50    Returns:
 51    - None
 52    """
 53    if not path.endswith(".pth"):
 54        path += ".pth"
 55
 56    # set content to save
 57    checkpoint_saving = {'epoch': epoch, 'model_state': model.state_dict()}
 58
 59    if isinstance(optimizer, (list, tuple)):
 60        for idx, cur_optimizer in enumerate(optimizer):
 61            checkpoint_saving[f'optim_state_{idx}'] = cur_optimizer.state_dict()
 62    else:
 63        checkpoint_saving[f'optim_state'] = optimizer.state_dict()
 64
 65    if isinstance(scheduler, (list, tuple)):
 66        for idx, cur_scheduler in enumerate(scheduler):
 67            checkpoint_saving[f'sched_state_{idx}'] = cur_scheduler.state_dict()
 68    else:
 69        checkpoint_saving[f'sched_state'] = scheduler.state_dict()
 70
 71    checkpoint_saving['args'] = args
 72
 73    # save checkpoint
 74    torch.save(checkpoint_saving, path)
 75
 76    # save info txt
 77    root_path, name = os.path.split(path)
 78    info_name = ".".join(name.split(".")[:-1]) + ".txt"
 79    with open(os.path.join(root_path, info_name), "w") as f:
 80        f.write(f"Last Model saved in epoch: {epoch}, at: {prime.get_time(pattern='DAY.MONTH.YEAR HOUR:MINUTE O\'Clock', time_zone='Europe/Berlin')}")
 81
 82
 83
 84# ---------------------------
 85#        > Loading <
 86# ---------------------------
 87def get_single_model(model_name, args, criterion, device):
 88    """
 89    Returns a single model instance based on provided arguments.
 90
 91    Supported models:
 92    - ResFCN
 93    - Pix2Pix
 94    - ResidualDesignModel
 95    - PhysicFormer
 96
 97    Parameter:
 98    - model_name (str):
 99        Name of the model to initialize.
100    - args:
101        Parsed command-line arguments.
102    - criterion:
103        Criterion for Pix2Pixs second loss. Required during model initialization.
104    - device:
105        Target device (GPU or CPU) on which to place the model.
106
107    Returns:
108    - model (nn.Module): 
109        Instantiated PyTorch model on the given device.
110    """
111    model_name = model_name.lower()
112
113    if model_name== "resfcn":
114        model = ResFCN(input_channels=args.resfcn_in_channels, 
115                       hidden_channels=args.resfcn_hidden_channels, 
116                       output_channels=args.resfcn_out_channels,
117                         num_blocks=args.resfcn_num_blocks).to(device)
118    elif model_name == "resfcn_2":
119        model = ResFCN(input_channels=args.resfcn_2_in_channels, 
120                       hidden_channels=args.resfcn_2_hidden_channels, 
121                       output_channels=args.resfcn_2_out_channels, 
122                       num_blocks=args.resfcn_2_num_blocks).to(device)
123    elif model_name == "pix2pix":
124        model = Pix2Pix(input_channels=args.pix2pix_in_channels, 
125                        output_channels=args.pix2pix_out_channels, 
126                        hidden_channels=args.pix2pix_hidden_channels, 
127                        second_loss=criterion, 
128                        lambda_second=args.pix2pix_second_loss_lambda).to(device)
129    elif model_name == "pix2pix_2":
130        model = Pix2Pix(input_channels=args.pix2pix_2_in_channels, 
131                        output_channels=args.pix2pix_2_out_channels, 
132                        hidden_channels=args.pix2pix_2_hidden_channels, 
133                        second_loss=criterion, 
134                        lambda_second=args.pix2pix_2_second_loss_lambda).to(device)
135    elif model_name == "physicsformer":
136        model = PhysicFormer(input_channels=args.physicsformer_in_channels, 
137                             output_channels=args.physicsformer_out_channels, 
138                             img_size=args.physicsformer_img_size, 
139                             patch_size=args.physicsformer_patch_size, 
140                             embedded_dim=args.physicsformer_embedded_dim, 
141                             num_blocks=args.physicsformer_num_blocks,
142                             heads=args.physicsformer_heads, 
143                             mlp_dim=args.physicsformer_mlp_dim, 
144                             dropout=args.physicsformer_dropout).to(device)
145    elif model_name == "physicsformer_2":
146        model = PhysicFormer(input_channels=args.physicsformer_in_channels_2, 
147                             output_channels=args.physicsformer_out_channels_2, 
148                             img_size=args.physicsformer_img_size_2, 
149                             patch_size=args.physicsformer_patch_size_2, 
150                             embedded_dim=args.physicsformer_embedded_dim_2, 
151                             num_blocks=args.physicsformer_num_blocks_2,
152                             heads=args.physicsformer_heads_2, 
153                             mlp_dim=args.physicsformer_mlp_dim_2, 
154                             dropout=args.physicsformer_dropout_2).to(device)
155    else:
156        raise ValueError(f"'{model_name}' is not a supported model.")
157
158    return model
159
160
161
def get_model(args, device, criterion=None):
    """
    Returns a model object with the given args. Composite models made of
    sub-models (ResidualDesignModel) are supported as well.
    This is the main function to get a model object.

    Parameter:
    - args (argparse.Namespace):
        Parsed arguments describing the model. args.model selects the model; for
        'residual_design_model' also args.base_model, args.complex_model and
        args.combine_mode are read.
    - device (torch.device):
        Device to move the model to.
    - criterion (torch.nn.modules.loss._Loss, default=None):
        Loss function. Some models (Pix2Pix) store the loss internally.
        For 'residual_design_model' a pair (base-model criterion, complex-model
        criterion) is expected. May be None (e.g. for inference), in which case
        every sub-model receives None.

    Returns:
    - torch.nn.Module:
        Model object without loaded weights, placed on device.
    """
    if args.model.lower() == "residual_design_model":
        # criterion is optional (e.g. inference); only index it when it was given.
        base_criterion = criterion[0] if criterion is not None else None
        complex_criterion = criterion[1] if criterion is not None else None

        base_model = get_single_model(model_name=args.base_model, args=args,
                                      criterion=base_criterion, device=device)
        # The complex sub-model is configured from the second ('_2') hyperparameter set.
        complex_model = get_single_model(model_name=args.complex_model + "_2", args=args,
                                         criterion=complex_criterion, device=device)
        # Pass the selected combine mode, not the whole argument namespace.
        model = ResidualDesignModel(base_model=base_model,
                                    complex_model=complex_model,
                                    combine_mode=args.combine_mode)
    else:
        model = get_single_model(model_name=args.model, args=args,
                                 criterion=criterion, device=device)

    return model.to(device)
186
187
188
def load_model_weights(model_params_path, map_location=None):
    """
    Loads a checkpoint (model states plus stored metadata) from a pth file.

    Parameter:
    - model_params_path (str or os.PathLike):
        The path to the checkpoint / model parameters file.
    - map_location (default=None):
        Forwarded to torch.load. Pass e.g. 'cpu' or a torch.device to remap
        tensors that were saved on another device (such as a CUDA checkpoint
        loaded on a CPU-only machine). None keeps torch.load's default behavior.

    Returns:
    - dict:
        The full loaded checkpoint. For files written by save_checkpoint it holds
        'epoch', 'model_state', the optimizer/scheduler states and 'args'.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects (presumably
    # required because the checkpoint also stores the 'args' object).
    # Only load checkpoint files from trusted sources.
    return torch.load(model_params_path, map_location=map_location, weights_only=False)
202
203
204
205# python -c "import torch;print(type(torch.device('cuda' if torch.cuda.is_available() else 'cpu')))"
206# <class 'torch.device'>
207# python -c "import torch.nn as nn;print(isinstance(nn.MSELoss(), nn.modules.loss._Loss))"
def load_and_get_model(model_params_path, device, criterion=None):
    """
    Loads a checkpoint from a pth file, builds the matching model object from the
    arguments stored in it, loads the weights and returns the model in eval mode.

    Parameter:
    - model_params_path (str or os.PathLike):
        The path to the checkpoint / model parameters file.
    - device (torch.device):
        Device to move the model to. Also used as map_location while loading, so
        checkpoints saved on a GPU can be loaded on CPU-only machines.
    - criterion (torch.nn.modules.loss._Loss, default=None):
        Loss function. Some models save the loss internally. Not needed for inference only.

    Returns:
    - torch.nn.Module:
        Loaded model object with weights, in eval mode.
    - args:
        The argument object stored in the checkpoint under 'args'.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects (the checkpoint
    # stores the 'args' object). Only load trusted checkpoint files.
    checkpoint = torch.load(model_params_path, map_location=device, weights_only=False)
    args = checkpoint["args"]
    model = get_model(args=args, device=device, criterion=criterion)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()
    return model, args
def save_checkpoint(args, model, optimizer, scheduler, epoch, path='ckpt.pth'):
    """
    Saves a training checkpoint containing model, optimizer, and scheduler states.

    Besides the '.pth' checkpoint, a small '.txt' info file with the same base
    name is written next to it, recording the epoch and the time of saving.

    Parameter:
    - args:
        Parsed command-line arguments. Stored under the 'args' key so the model
        can be rebuilt from the checkpoint later (see load_and_get_model).
    - model (nn.Module):
        Model to save.
    - optimizer:
        Optimizer or list/tuple of optimizers. A list/tuple is stored under
        'optim_state_0', 'optim_state_1', ...; a single optimizer under 'optim_state'.
    - scheduler:
        Scheduler or list/tuple of schedulers. Stored analogously under
        'sched_state_<idx>' or 'sched_state'.
    - epoch (int):
        Current epoch index.
    - path (str or os.PathLike):
        File path to save checkpoint to ('.pth' extension added automatically).

    Returns:
    - None
    """
    # os.fspath also accepts pathlib.Path objects (plain strings pass through unchanged).
    path = os.fspath(path)
    if not path.endswith(".pth"):
        path += ".pth"

    # set content to save
    checkpoint_saving = {'epoch': epoch, 'model_state': model.state_dict()}

    if isinstance(optimizer, (list, tuple)):
        for idx, cur_optimizer in enumerate(optimizer):
            checkpoint_saving[f'optim_state_{idx}'] = cur_optimizer.state_dict()
    else:
        checkpoint_saving['optim_state'] = optimizer.state_dict()

    if isinstance(scheduler, (list, tuple)):
        for idx, cur_scheduler in enumerate(scheduler):
            checkpoint_saving[f'sched_state_{idx}'] = cur_scheduler.state_dict()
    else:
        checkpoint_saving['sched_state'] = scheduler.state_dict()

    checkpoint_saving['args'] = args

    # save checkpoint
    torch.save(checkpoint_saving, path)

    # save info txt next to the checkpoint: '<dir>/<name>.pth' -> '<dir>/<name>.txt'
    root_path, name = os.path.split(path)
    info_name = name[:-len(".pth")] + ".txt"
    # Computed outside the f-string: a backslash-escaped quote inside an f-string
    # replacement field is a SyntaxError before Python 3.12 (PEP 701).
    saved_at = prime.get_time(pattern="DAY.MONTH.YEAR HOUR:MINUTE O'Clock", time_zone='Europe/Berlin')
    with open(os.path.join(root_path, info_name), "w", encoding="utf-8") as f:
        f.write(f"Last Model saved in epoch: {epoch}, at: {saved_at}")

Saves a training checkpoint containing model, optimizer, and scheduler states.

Parameter:

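  • args: Parsed command-line arguments; stored in the checkpoint under 'args'.
  • model (nn.Module): Model to save.
  • optimizer: Optimizer or list/tuple of optimizers.
  • scheduler: Scheduler or list/tuple of schedulers.
  • epoch (int): Current epoch index.
  • path (str): File path to save the checkpoint to ('.pth' extension added automatically). A '.txt' info file with the same base name is written next to it.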

Returns:

  • None
def get_single_model(model_name, args, criterion, device):
    """
    Returns a single (non-composite) model instance based on provided arguments.

    Supported model names (case-insensitive):
    - 'resfcn' / 'resfcn_2':
        ResFCN, configured from args.resfcn_* / args.resfcn_2_*.
    - 'pix2pix' / 'pix2pix_2':
        Pix2Pix, configured from args.pix2pix_* / args.pix2pix_2_*.
    - 'physicsformer' / 'physicsformer_2':
        PhysicFormer, configured from args.physicsformer_* / args.physicsformer_*_2.

    The '_2' variants read a second, independent set of hyperparameters; get_model
    uses them for the complex sub-model of a ResidualDesignModel. ResidualDesignModel
    itself is not built here (it is assembled in get_model).

    Parameter:
    - model_name (str):
        Name of the model to initialize (see list above).
    - args:
        Parsed command-line arguments holding the model hyperparameters.
    - criterion:
        Criterion for Pix2Pix's second loss, passed in at initialization.
        Only used by the Pix2Pix variants; ignored for the other models.
    - device:
        Target device (GPU or CPU) on which to place the model.

    Returns:
    - model (nn.Module):
        Instantiated PyTorch model on the given device.

    Raises:
    - ValueError:
        If model_name is not one of the supported names.
    """
    model_name = model_name.lower()

    if model_name == "resfcn":
        model = ResFCN(input_channels=args.resfcn_in_channels,
                       hidden_channels=args.resfcn_hidden_channels,
                       output_channels=args.resfcn_out_channels,
                       num_blocks=args.resfcn_num_blocks)
    elif model_name == "resfcn_2":
        model = ResFCN(input_channels=args.resfcn_2_in_channels,
                       hidden_channels=args.resfcn_2_hidden_channels,
                       output_channels=args.resfcn_2_out_channels,
                       num_blocks=args.resfcn_2_num_blocks)
    elif model_name == "pix2pix":
        model = Pix2Pix(input_channels=args.pix2pix_in_channels,
                        output_channels=args.pix2pix_out_channels,
                        hidden_channels=args.pix2pix_hidden_channels,
                        second_loss=criterion,
                        lambda_second=args.pix2pix_second_loss_lambda)
    elif model_name == "pix2pix_2":
        model = Pix2Pix(input_channels=args.pix2pix_2_in_channels,
                        output_channels=args.pix2pix_2_out_channels,
                        hidden_channels=args.pix2pix_2_hidden_channels,
                        second_loss=criterion,
                        lambda_second=args.pix2pix_2_second_loss_lambda)
    elif model_name == "physicsformer":
        model = PhysicFormer(input_channels=args.physicsformer_in_channels,
                             output_channels=args.physicsformer_out_channels,
                             img_size=args.physicsformer_img_size,
                             patch_size=args.physicsformer_patch_size,
                             embedded_dim=args.physicsformer_embedded_dim,
                             num_blocks=args.physicsformer_num_blocks,
                             heads=args.physicsformer_heads,
                             mlp_dim=args.physicsformer_mlp_dim,
                             dropout=args.physicsformer_dropout)
    elif model_name == "physicsformer_2":
        # NOTE: unlike the other '_2' variants, these argument names carry the
        # '_2' as a suffix; they must match the names defined by the arg parser.
        model = PhysicFormer(input_channels=args.physicsformer_in_channels_2,
                             output_channels=args.physicsformer_out_channels_2,
                             img_size=args.physicsformer_img_size_2,
                             patch_size=args.physicsformer_patch_size_2,
                             embedded_dim=args.physicsformer_embedded_dim_2,
                             num_blocks=args.physicsformer_num_blocks_2,
                             heads=args.physicsformer_heads_2,
                             mlp_dim=args.physicsformer_mlp_dim_2,
                             dropout=args.physicsformer_dropout_2)
    else:
        raise ValueError(f"'{model_name}' is not a supported model.")

    # Single place where the freshly built model is moved to the target device.
    return model.to(device)

Returns a single model instance based on provided arguments.

Supported models:

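  • ResFCN ('resfcn', 'resfcn_2')
  • Pix2Pix ('pix2pix', 'pix2pix_2')
  • PhysicFormer ('physicsformer', 'physicsformer_2')

ResidualDesignModel is not built here; get_model assembles it from two single models.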

Parameter:

  • model_name (str): Name of the model to initialize.
  • args: Parsed command-line arguments.
  • criterion: Criterion for Pix2Pix's second loss, passed to the model at initialization (only used by the Pix2Pix variants).
  • device: Target device (GPU or CPU) on which to place the model.

Returns:

  • model (nn.Module): Instantiated PyTorch model on the given device.
def get_model(args, device, criterion=None):
    """
    Returns a model object with the given args. Composite models made of
    sub-models (ResidualDesignModel) are supported as well.
    This is the main function to get a model object.

    Parameter:
    - args (argparse.Namespace):
        Parsed arguments describing the model. args.model selects the model; for
        'residual_design_model' also args.base_model, args.complex_model and
        args.combine_mode are read.
    - device (torch.device):
        Device to move the model to.
    - criterion (torch.nn.modules.loss._Loss, default=None):
        Loss function. Some models (Pix2Pix) store the loss internally.
        For 'residual_design_model' a pair (base-model criterion, complex-model
        criterion) is expected. May be None (e.g. for inference), in which case
        every sub-model receives None.

    Returns:
    - torch.nn.Module:
        Model object without loaded weights, placed on device.
    """
    if args.model.lower() == "residual_design_model":
        # criterion is optional (e.g. inference); only index it when it was given.
        base_criterion = criterion[0] if criterion is not None else None
        complex_criterion = criterion[1] if criterion is not None else None

        base_model = get_single_model(model_name=args.base_model, args=args,
                                      criterion=base_criterion, device=device)
        # The complex sub-model is configured from the second ('_2') hyperparameter set.
        complex_model = get_single_model(model_name=args.complex_model + "_2", args=args,
                                         criterion=complex_criterion, device=device)
        # Pass the selected combine mode, not the whole argument namespace.
        model = ResidualDesignModel(base_model=base_model,
                                    complex_model=complex_model,
                                    combine_mode=args.combine_mode)
    else:
        model = get_single_model(model_name=args.model, args=args,
                                 criterion=criterion, device=device)

    return model.to(device)

Builds and returns a model object from the given args, including composite models made of sub-models (ResidualDesignModel). This is the main function for creating a model object.

Parameter:

  • args (argparse.Namespace): Parsed arguments describing the model class and its configuration.
  • device (torch.device): Device to move the model to.
  • criterion (torch.nn.modules.loss._Loss, default=None): Loss function. Some models store the loss internally. For 'residual_design_model' a pair (base-model criterion, complex-model criterion) is used. Not needed for inference of single models.

Returns:

  • torch.nn.Module: Loaded Model object without weights.
def load_model_weights(model_params_path, map_location=None):
    """
    Loads a checkpoint (model states plus stored metadata) from a pth file.

    Parameter:
    - model_params_path (str or os.PathLike):
        The path to the checkpoint / model parameters file.
    - map_location (default=None):
        Forwarded to torch.load. Pass e.g. 'cpu' or a torch.device to remap
        tensors that were saved on another device (such as a CUDA checkpoint
        loaded on a CPU-only machine). None keeps torch.load's default behavior.

    Returns:
    - dict:
        The full loaded checkpoint. For files written by save_checkpoint it holds
        'epoch', 'model_state', the optimizer/scheduler states and 'args'.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects (presumably
    # required because the checkpoint also stores the 'args' object).
    # Only load checkpoint files from trusted sources.
    return torch.load(model_params_path, map_location=map_location, weights_only=False)

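Loads a checkpoint (model states plus stored metadata) from a .pth file. The file is read with torch.load(weights_only=False), so only load trusted files.

Parameter:

  • model_params_path (str): The path to the model parameters.

Returns:

  • dict: The loaded checkpoint dictionary; for checkpoints written by save_checkpoint it contains 'epoch', 'model_state', the optimizer/scheduler states and 'args'.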
def load_and_get_model(model_params_path, device, criterion=None):
    """
    Loads a checkpoint from a pth file, builds the matching model object from the
    arguments stored in it, loads the weights and returns the model in eval mode.

    Parameter:
    - model_params_path (str or os.PathLike):
        The path to the checkpoint / model parameters file.
    - device (torch.device):
        Device to move the model to. Also used as map_location while loading, so
        checkpoints saved on a GPU can be loaded on CPU-only machines.
    - criterion (torch.nn.modules.loss._Loss, default=None):
        Loss function. Some models save the loss internally. Not needed for inference only.

    Returns:
    - torch.nn.Module:
        Loaded model object with weights, in eval mode.
    - args:
        The argument object stored in the checkpoint under 'args'.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects (the checkpoint
    # stores the 'args' object). Only load trusted checkpoint files.
    checkpoint = torch.load(model_params_path, map_location=device, weights_only=False)
    args = checkpoint["args"]
    model = get_model(args=args, device=device, criterion=criterion)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()
    return model, args

Loads model weights (model states) from a pth file, creates the matching model object from the arguments stored in the checkpoint, and returns that model with the weights loaded (in eval mode).

Parameter:

  • model_params_path (str): The path to the model parameters.
  • device (torch.device): Device to move the model to.
  • criterion (torch.nn.modules.loss._Loss, default=None): Loss function. Some models save the loss internally. Not needed for inference only.

Returns:

  • torch.nn.Module: Loaded Model object with weights.
  • args: The parsed arguments stored in the checkpoint under 'args'.