image_to_image.utils.model_io
Module to handle the input/output of PyTorch models. In this case that means loading and saving models.
Functions:
- save_checkpoint
- get_single_model
- get_model
- load_model_weights
- load_and_get_model
By Tobia Ippolito
"""
Module to handle the input/output of PyTorch models.
In this case that means loading and saving models.


Functions:
- save_checkpoint
- get_single_model
- get_model
- load_model_weights
- load_and_get_model

By Tobia Ippolito
"""
# ---------------------------
#        > Imports <
# ---------------------------
import os

import torch

import prime_printer as prime

from ..models.pix2pix import Pix2Pix
from ..models.resfcn import ResFCN
from ..models.residual_design_model import ResidualDesignModel
from ..models.transformer import PhysicFormer



# ---------------------------
#        > Saving <
# ---------------------------
def save_checkpoint(args, model, optimizer, scheduler, epoch, path='ckpt.pth'):
    """
    Saves a training checkpoint containing model, optimizer, and scheduler states.

    Next to the checkpoint a text file with the same base name ('.txt') is
    written, noting the epoch and the time of the save.

    Parameter:
    - args:
        Run configuration; stored unchanged under the 'args' key of the checkpoint.
    - model (nn.Module):
        Model to save.
    - optimizer:
        Optimizer or list/tuple of optimizers.
    - scheduler:
        Scheduler or list/tuple of schedulers.
    - epoch (int):
        Current epoch index.
    - path (str):
        File path to save checkpoint to ('.pth' extension added automatically).

    Returns:
    - None
    """
    if not path.endswith(".pth"):
        path += ".pth"

    # set content to save
    checkpoint_saving = {'epoch': epoch, 'model_state': model.state_dict()}

    # several optimizers get indexed keys, a single one gets the plain key
    if isinstance(optimizer, (list, tuple)):
        for idx, cur_optimizer in enumerate(optimizer):
            checkpoint_saving[f'optim_state_{idx}'] = cur_optimizer.state_dict()
    else:
        checkpoint_saving['optim_state'] = optimizer.state_dict()

    # same key scheme for the schedulers
    if isinstance(scheduler, (list, tuple)):
        for idx, cur_scheduler in enumerate(scheduler):
            checkpoint_saving[f'sched_state_{idx}'] = cur_scheduler.state_dict()
    else:
        checkpoint_saving['sched_state'] = scheduler.state_dict()

    checkpoint_saving['args'] = args

    # save checkpoint
    torch.save(checkpoint_saving, path)

    # save info txt
    # The timestamp is built outside of the f-string: a backslash inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    saved_at = prime.get_time(pattern="DAY.MONTH.YEAR HOUR:MINUTE O'Clock", time_zone='Europe/Berlin')
    root_path, name = os.path.split(path)
    info_name = ".".join(name.split(".")[:-1]) + ".txt"
    with open(os.path.join(root_path, info_name), "w") as f:
        f.write(f"Last Model saved in epoch: {epoch}, at: {saved_at}")



# ---------------------------
#        > Loading <
# ---------------------------
def get_single_model(model_name, args, criterion, device):
    """
    Builds one (non-composite) network and moves it to the given device.

    The name is matched case-insensitively. The '_2' variants read a second
    set of hyperparameters from args.

    Known names:
    - resfcn, resfcn_2 (ResFCN)
    - pix2pix, pix2pix_2 (Pix2Pix)
    - physicsformer, physicsformer_2 (PhysicFormer)

    Parameter:
    - model_name (str):
        Name of the model to initialize.
    - args:
        Object carrying the model hyperparameters as attributes.
    - criterion:
        Handed to Pix2Pix as its second loss; the other models do not receive it.
    - device:
        Target device, passed to the model's .to() call.

    Returns:
    - model (nn.Module):
        Instantiated PyTorch model on the given device.

    Raises:
    - ValueError:
        If the (lower-cased) name is none of the known names.
    """
    key = model_name.lower()

    if key in ("resfcn", "resfcn_2"):
        # attribute names follow '<key>_<field>', e.g. resfcn_2_in_channels
        network = ResFCN(
            input_channels=getattr(args, f"{key}_in_channels"),
            hidden_channels=getattr(args, f"{key}_hidden_channels"),
            output_channels=getattr(args, f"{key}_out_channels"),
            num_blocks=getattr(args, f"{key}_num_blocks"),
        )
        return network.to(device)

    if key in ("pix2pix", "pix2pix_2"):
        # attribute names follow '<key>_<field>', e.g. pix2pix_2_in_channels
        network = Pix2Pix(
            input_channels=getattr(args, f"{key}_in_channels"),
            output_channels=getattr(args, f"{key}_out_channels"),
            hidden_channels=getattr(args, f"{key}_hidden_channels"),
            second_loss=criterion,
            lambda_second=getattr(args, f"{key}_second_loss_lambda"),
        )
        return network.to(device)

    if key in ("physicsformer", "physicsformer_2"):
        # unlike above, the '_2' marker sits at the END of the attribute name,
        # e.g. physicsformer_in_channels_2
        tail = key[len("physicsformer"):]
        network = PhysicFormer(
            input_channels=getattr(args, f"physicsformer_in_channels{tail}"),
            output_channels=getattr(args, f"physicsformer_out_channels{tail}"),
            img_size=getattr(args, f"physicsformer_img_size{tail}"),
            patch_size=getattr(args, f"physicsformer_patch_size{tail}"),
            embedded_dim=getattr(args, f"physicsformer_embedded_dim{tail}"),
            num_blocks=getattr(args, f"physicsformer_num_blocks{tail}"),
            heads=getattr(args, f"physicsformer_heads{tail}"),
            mlp_dim=getattr(args, f"physicsformer_mlp_dim{tail}"),
            dropout=getattr(args, f"physicsformer_dropout{tail}"),
        )
        return network.to(device)

    raise ValueError(f"'{key}' is not a supported model.")



def get_model(args, device, criterion=None):
    """
    Returns a model object built from the given args. Composite models with
    sub-models are supported as well.
    This is the main function to get a model object.

    Parameter:
    - args:
        Parsed arguments object. args.model selects the model; for
        'residual_design_model' args.base_model and args.complex_model name
        the two sub-models.
    - device (torch.device):
        Device the model is moved to.
    - criterion (default=None):
        Loss function handed to the model builder. For 'residual_design_model'
        an indexable pair: item 0 for the base model, item 1 for the complex
        model. None is allowed for every model (e.g. for inference only).

    Returns:
    - torch.nn.Module:
        Model object without loaded weights.
    """
    if args.model.lower() != "residual_design_model":
        return get_single_model(model_name=args.model, args=args, criterion=criterion, device=device)

    # criterion defaults to None (inference); indexing it would raise a TypeError
    if criterion is None:
        base_criterion = complex_criterion = None
    else:
        base_criterion, complex_criterion = criterion[0], criterion[1]

    # NOTE(review): the whole args object used to be passed as combine_mode.
    # Presumably ResidualDesignModel expects args.combine_mode -- confirm.
    # Falls back to args when that attribute is absent, so configurations
    # without it behave as before.
    combine_mode = getattr(args, "combine_mode", args)

    base_model = get_single_model(model_name=args.base_model, args=args,
                                  criterion=base_criterion, device=device)
    # the complex model reads the second ('_2') hyperparameter set
    complex_model = get_single_model(model_name=args.complex_model + "_2", args=args,
                                     criterion=complex_criterion, device=device)

    return ResidualDesignModel(base_model=base_model,
                               complex_model=complex_model,
                               combine_mode=combine_mode).to(device)



def load_model_weights(model_params_path, map_location=None):
    """
    Loads the content of a pth checkpoint file.

    Parameter:
    - model_params_path (str):
        The path to the model parameters.
    - map_location (default=None):
        Passed through to torch.load to remap tensor storages, e.g. 'cpu' to
        open a checkpoint written on a GPU machine where no GPU is available.
        None keeps the default behavior of torch.load.

    Returns:
    - dict:
        Loaded checkpoint content.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects
    # (presumably needed for the stored args object). Only load trusted files.
    return torch.load(model_params_path, map_location=map_location, weights_only=False)



# python -c "import torch;print(type(torch.device('cuda' if torch.cuda.is_available() else 'cpu')))"
# <class 'torch.device'>
# python -c "import torch.nn as nn;print(isinstance(nn.MSELoss(), nn.modules.loss._Loss))"
def load_and_get_model(model_params_path, device, criterion=None):
    """
    Loads a checkpoint from a pth file, creates the matching model object,
    and returns it with the weights loaded and switched to eval mode.

    Parameter:
    - model_params_path (str):
        The path to the model parameters.
    - device (torch.device):
        Device the model is moved to.
    - criterion (default=None):
        Loss function handed to the model builder. Not needed for inference only.

    Returns:
    - torch.nn.Module:
        Model object with loaded weights, in eval mode.
    - args:
        The arguments object stored under the 'args' key of the checkpoint.
    """
    checkpoint = load_model_weights(model_params_path)
    stored_args = checkpoint["args"]

    network = get_model(args=stored_args, device=device, criterion=criterion)
    network.load_state_dict(checkpoint["model_state"])
    network.eval()

    return network, stored_args
def
save_checkpoint(args, model, optimizer, scheduler, epoch, path='ckpt.pth'):
def save_checkpoint(args, model, optimizer, scheduler, epoch, path='ckpt.pth'):
    """
    Saves a training checkpoint containing model, optimizer, and scheduler states.

    Next to the checkpoint a text file with the same base name ('.txt') is
    written, noting the epoch and the time of the save.

    Parameter:
    - args:
        Run configuration; stored unchanged under the 'args' key of the checkpoint.
    - model (nn.Module):
        Model to save.
    - optimizer:
        Optimizer or list/tuple of optimizers.
    - scheduler:
        Scheduler or list/tuple of schedulers.
    - epoch (int):
        Current epoch index.
    - path (str):
        File path to save checkpoint to ('.pth' extension added automatically).

    Returns:
    - None
    """
    if not path.endswith(".pth"):
        path += ".pth"

    # set content to save
    checkpoint_saving = {'epoch': epoch, 'model_state': model.state_dict()}

    # several optimizers get indexed keys, a single one gets the plain key
    if isinstance(optimizer, (list, tuple)):
        for idx, cur_optimizer in enumerate(optimizer):
            checkpoint_saving[f'optim_state_{idx}'] = cur_optimizer.state_dict()
    else:
        checkpoint_saving['optim_state'] = optimizer.state_dict()

    # same key scheme for the schedulers
    if isinstance(scheduler, (list, tuple)):
        for idx, cur_scheduler in enumerate(scheduler):
            checkpoint_saving[f'sched_state_{idx}'] = cur_scheduler.state_dict()
    else:
        checkpoint_saving['sched_state'] = scheduler.state_dict()

    checkpoint_saving['args'] = args

    # save checkpoint
    torch.save(checkpoint_saving, path)

    # save info txt
    # The timestamp is built outside of the f-string: a backslash inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    saved_at = prime.get_time(pattern="DAY.MONTH.YEAR HOUR:MINUTE O'Clock", time_zone='Europe/Berlin')
    root_path, name = os.path.split(path)
    info_name = ".".join(name.split(".")[:-1]) + ".txt"
    with open(os.path.join(root_path, info_name), "w") as f:
        f.write(f"Last Model saved in epoch: {epoch}, at: {saved_at}")
Saves a training checkpoint containing model, optimizer, and scheduler states.
Parameter:
- model (nn.Module): Model to save.
- optimizer: Optimizer or list/tuple of optimizers.
- scheduler: Scheduler or list/tuple of schedulers.
- epoch (int): Current epoch index.
- path (str): File path to save checkpoint to ('.pth' extension added automatically).
Returns:
- None
def
get_single_model(model_name, args, criterion, device):
def get_single_model(model_name, args, criterion, device):
    """
    Builds one (non-composite) network and moves it to the given device.

    The name is matched case-insensitively. The '_2' variants read a second
    set of hyperparameters from args.

    Known names:
    - resfcn, resfcn_2 (ResFCN)
    - pix2pix, pix2pix_2 (Pix2Pix)
    - physicsformer, physicsformer_2 (PhysicFormer)

    Parameter:
    - model_name (str):
        Name of the model to initialize.
    - args:
        Object carrying the model hyperparameters as attributes.
    - criterion:
        Handed to Pix2Pix as its second loss; the other models do not receive it.
    - device:
        Target device, passed to the model's .to() call.

    Returns:
    - model (nn.Module):
        Instantiated PyTorch model on the given device.

    Raises:
    - ValueError:
        If the (lower-cased) name is none of the known names.
    """
    key = model_name.lower()

    if key in ("resfcn", "resfcn_2"):
        # attribute names follow '<key>_<field>', e.g. resfcn_2_in_channels
        network = ResFCN(
            input_channels=getattr(args, f"{key}_in_channels"),
            hidden_channels=getattr(args, f"{key}_hidden_channels"),
            output_channels=getattr(args, f"{key}_out_channels"),
            num_blocks=getattr(args, f"{key}_num_blocks"),
        )
        return network.to(device)

    if key in ("pix2pix", "pix2pix_2"):
        # attribute names follow '<key>_<field>', e.g. pix2pix_2_in_channels
        network = Pix2Pix(
            input_channels=getattr(args, f"{key}_in_channels"),
            output_channels=getattr(args, f"{key}_out_channels"),
            hidden_channels=getattr(args, f"{key}_hidden_channels"),
            second_loss=criterion,
            lambda_second=getattr(args, f"{key}_second_loss_lambda"),
        )
        return network.to(device)

    if key in ("physicsformer", "physicsformer_2"):
        # unlike above, the '_2' marker sits at the END of the attribute name,
        # e.g. physicsformer_in_channels_2
        tail = key[len("physicsformer"):]
        network = PhysicFormer(
            input_channels=getattr(args, f"physicsformer_in_channels{tail}"),
            output_channels=getattr(args, f"physicsformer_out_channels{tail}"),
            img_size=getattr(args, f"physicsformer_img_size{tail}"),
            patch_size=getattr(args, f"physicsformer_patch_size{tail}"),
            embedded_dim=getattr(args, f"physicsformer_embedded_dim{tail}"),
            num_blocks=getattr(args, f"physicsformer_num_blocks{tail}"),
            heads=getattr(args, f"physicsformer_heads{tail}"),
            mlp_dim=getattr(args, f"physicsformer_mlp_dim{tail}"),
            dropout=getattr(args, f"physicsformer_dropout{tail}"),
        )
        return network.to(device)

    raise ValueError(f"'{key}' is not a supported model.")
Returns a single model instance based on provided arguments.
Supported models:
- ResFCN
- Pix2Pix
- ResidualDesignModel
- PhysicFormer
Parameter:
- model_name (str): Name of the model to initialize.
- args: Parsed command-line arguments.
- criterion: Criterion for Pix2Pix's second loss. Required during model initialization.
- device: Target device (GPU or CPU) on which to place the model.
Returns:
- model (nn.Module): Instantiated PyTorch model on the given device.
def
get_model(args, device, criterion=None):
def get_model(args, device, criterion=None):
    """
    Returns a model object built from the given args. Composite models with
    sub-models are supported as well.
    This is the main function to get a model object.

    Parameter:
    - args:
        Parsed arguments object. args.model selects the model; for
        'residual_design_model' args.base_model and args.complex_model name
        the two sub-models.
    - device (torch.device):
        Device the model is moved to.
    - criterion (default=None):
        Loss function handed to the model builder. For 'residual_design_model'
        an indexable pair: item 0 for the base model, item 1 for the complex
        model. None is allowed for every model (e.g. for inference only).

    Returns:
    - torch.nn.Module:
        Model object without loaded weights.
    """
    if args.model.lower() != "residual_design_model":
        return get_single_model(model_name=args.model, args=args, criterion=criterion, device=device)

    # criterion defaults to None (inference); indexing it would raise a TypeError
    if criterion is None:
        base_criterion = complex_criterion = None
    else:
        base_criterion, complex_criterion = criterion[0], criterion[1]

    # NOTE(review): the whole args object used to be passed as combine_mode.
    # Presumably ResidualDesignModel expects args.combine_mode -- confirm.
    # Falls back to args when that attribute is absent, so configurations
    # without it behave as before.
    combine_mode = getattr(args, "combine_mode", args)

    base_model = get_single_model(model_name=args.base_model, args=args,
                                  criterion=base_criterion, device=device)
    # the complex model reads the second ('_2') hyperparameter set
    complex_model = get_single_model(model_name=args.complex_model + "_2", args=args,
                                     criterion=complex_criterion, device=device)

    return ResidualDesignModel(base_model=base_model,
                               complex_model=complex_model,
                               combine_mode=combine_mode).to(device)
Returns a model object built from the given args. Composite models with sub-models can be created as well. This is the main function to get a model object.
- args (argparse.ArgumentParser): Arguments to get information about the model class/object.
- device (torch.device): Device the model should be moved to.
- criterion (torch.nn.modules.loss._Loss, default=None): Loss function. Some models save the loss internally. Not needed for inference only.
Returns:
- torch.nn.Module: Loaded Model object without weights.
def
load_model_weights(model_params_path):
def load_model_weights(model_params_path, map_location=None):
    """
    Loads the content of a pth checkpoint file.

    Parameter:
    - model_params_path (str):
        The path to the model parameters.
    - map_location (default=None):
        Passed through to torch.load to remap tensor storages, e.g. 'cpu' to
        open a checkpoint written on a GPU machine where no GPU is available.
        None keeps the default behavior of torch.load.

    Returns:
    - dict:
        Loaded checkpoint content.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects
    # (presumably needed for the stored args object). Only load trusted files.
    return torch.load(model_params_path, map_location=map_location, weights_only=False)
Loads model weights (model states) from a pth file.
Parameter:
- model_params_path (str): The path to the model parameters.
Returns:
- dict: Loaded weights
def
load_and_get_model(model_params_path, device, criterion=None):
def load_and_get_model(model_params_path, device, criterion=None):
    """
    Loads a checkpoint from a pth file, creates the matching model object,
    and returns it with the weights loaded and switched to eval mode.

    Parameter:
    - model_params_path (str):
        The path to the model parameters.
    - device (torch.device):
        Device the model is moved to.
    - criterion (default=None):
        Loss function handed to the model builder. Not needed for inference only.

    Returns:
    - torch.nn.Module:
        Model object with loaded weights, in eval mode.
    - args:
        The arguments object stored under the 'args' key of the checkpoint.
    """
    checkpoint = load_model_weights(model_params_path)
    stored_args = checkpoint["args"]

    # the stored arguments decide which architecture is rebuilt
    network = get_model(args=stored_args, device=device, criterion=criterion)
    network.load_state_dict(checkpoint["model_state"])
    network.eval()

    return network, stored_args
Loads model weights (model states) from a pth file, creates the matching model object, and returns it with the weights loaded.
Parameter:
- model_params_path (str): The path to the model parameters.
- device (torch.device): Device the model should be moved to.
- criterion (torch.nn.modules.loss._Loss, default=None): Loss function. Some models save the loss internally. Not needed for inference only.
Returns:
- torch.nn.Module: Loaded Model object with weights.
- argparse.ArgumentParser: Loaded arguments.