image_to_image.model_interactions.train
Module to train, validate, and evaluate image-to-image models. Supports multiple models, losses, optimizers, schedulers, mixed-precision training, checkpointing, and experiment tracking with MLflow and TensorBoard.
The train function handles full experiment orchestration including:
- Argument parsing and device setup.
- Dataset and dataloader initialization.
- Model, optimizer, loss function, and scheduler setup.
- Mixed precision (AMP) and warm-up handling.
- MLflow and TensorBoard logging.
- Periodic validation and checkpoint saving.
Functions:
- get_data
- get_loss
- get_optimizer
- get_scheduler
- get_params
- backward_model
- train_one_epoch
- evaluate
- train
By Tobia Ippolito
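For a quick start, a minimal launch sketch (the import path is taken from this page's title; with args=None, train() parses the command line itself via parse_args()):

    from image_to_image.model_interactions.train import train

    train()  # reads all configuration from the command line

The module also ends in an `if __name__ == "__main__":` guard, so running the file directly starts the same entry point.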
Source code:

    """
    Module to train, validate, and evaluate image-to-image models.
    Supports multiple models, losses, optimizers, schedulers, mixed-precision training,
    checkpointing, and experiment tracking with MLflow and TensorBoard.

    The train function handles full experiment orchestration including:
    - Argument parsing and device setup.
    - Dataset and dataloader initialization.
    - Model, optimizer, loss function, and scheduler setup.
    - Mixed precision (AMP) and warm-up handling.
    - MLflow and TensorBoard logging.
    - Periodic validation and checkpoint saving.

    Functions:
    - get_data
    - get_loss
    - get_optimizer
    - get_scheduler
    - get_params
    - backward_model
    - train_one_epoch
    - evaluate
    - train

    By Tobia Ippolito
    """
    # ---------------------------
    # > Imports <
    # ---------------------------
    import os
    import shutil
    import time
    import copy

    import matplotlib.pyplot as plt

    from tqdm import tqdm

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader
    from torch.amp import autocast, GradScaler
    import torchvision

    # Experiment tracking
    import mlflow
    import mlflow.pytorch
    from torch.utils.tensorboard import SummaryWriter

    import prime_printer as prime

    from ..utils.argument_parsing import parse_args
    from ..utils.model_io import get_model, save_checkpoint

    from ..data.physgen import PhysGenDataset
    from ..data.residual_physgen import PhysGenResidualDataset, to_device

    from ..models.resfcn import ResFCN
    from ..models.pix2pix import Pix2Pix
    from ..models.residual_design_model import ResidualDesignModel
    from ..models.transformer import PhysicFormer

    from ..losses.weighted_combined_loss import WeightedCombinedLoss
    from ..scheduler.warm_up import WarmUpScheduler
    from ..amp.dummy_scaler import DummyScaler


    # ---------------------------
    # > Train Helpers <
    # ---------------------------
    def get_data(args):
        """
        Builds the training and validation datasets and their dataloaders.

        The residual_design_model trains on PhysGenResidualDataset; all other
        models use PhysGenDataset with the configured input/output types.
        """
        if args.model.lower() == "residual_design_model":
            train_dataset = PhysGenResidualDataset(variation=args.data_variation, mode="train",
                                                   fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                                   reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)

            val_dataset = PhysGenDataset(variation=args.data_variation, mode="validation", input_type="osm", output_type="standard",
                                         fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                         reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
        else:
            train_dataset = PhysGenDataset(variation=args.data_variation, mode="train", input_type=args.input_type, output_type=args.output_type,
                                           fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                           reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)

            val_dataset = PhysGenDataset(variation=args.data_variation, mode="validation", input_type=args.input_type, output_type=args.output_type,
                                         fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                         reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

        return train_dataset, val_dataset, train_loader, val_loader


    def get_loss(loss_name, args):
        """
        Returns a loss function instance based on the provided loss name.

        Supported losses:
        - L1 / L1_2
        - CrossEntropy / CrossEntropy_2
        - WeightedCombined / WeightedCombined_2

        Parameter:
        - loss_name (str):
            Name of the loss function.
        - args:
            Parsed command-line arguments with configured loss weights.

        Returns:
        - criterion (nn.Module): Instantiated loss function.
        """
        loss_name = loss_name.lower()

        if loss_name in ("l1", "l1_2"):
            criterion = nn.L1Loss()
        elif loss_name in ("crossentropy", "crossentropy_2"):
            criterion = nn.CrossEntropyLoss()
        elif loss_name == "weighted_combined":
            criterion = WeightedCombinedLoss(
                silog_lambda=args.wc_loss_silog_lambda,
                weight_silog=args.wc_loss_weight_silog,
                weight_grad=args.wc_loss_weight_grad,
                weight_ssim=args.wc_loss_weight_ssim,
                weight_edge_aware=args.wc_loss_weight_edge_aware,
                weight_l1=args.wc_loss_weight_l1,
                weight_var=args.wc_loss_weight_var,
                weight_range=args.wc_loss_weight_range,
                weight_blur=args.wc_loss_weight_blur
            )
        elif loss_name == "weighted_combined_2":
            criterion = WeightedCombinedLoss(
                silog_lambda=args.wc_loss_silog_lambda_2,
                weight_silog=args.wc_loss_weight_silog_2,
                weight_grad=args.wc_loss_weight_grad_2,
                weight_ssim=args.wc_loss_weight_ssim_2,
                weight_edge_aware=args.wc_loss_weight_edge_aware_2,
                weight_l1=args.wc_loss_weight_l1_2,
                weight_var=args.wc_loss_weight_var_2,
                weight_range=args.wc_loss_weight_range_2,
                weight_blur=args.wc_loss_weight_blur_2
            )
        else:
            raise ValueError(f"'{loss_name}' is not a supported loss.")

        return criterion


    def get_optimizer(optimizer_name, model, lr, args):
        """
        Returns an optimizer for the given model.

        Supported optimizers:
        - Adam
        - AdamW

        Parameter:
        - optimizer_name (str):
            Name of the optimizer.
        - model (nn.Module):
            Model whose parameters should be optimized.
        - lr (float):
            Learning rate.
        - args:
            Parsed command-line arguments with optimizer configuration.

        Returns:
        - optimizer (torch.optim.Optimizer): Instantiated optimizer.
        """
        optimizer_name = optimizer_name.lower()

        weight_decay_rate = args.weight_decay_rate if args.weight_decay else 0

        # Use the passed name (not args.optimizer), so a second optimizer
        # (args.optimizer_2) is resolved correctly.
        if optimizer_name == "adam":
            optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=weight_decay_rate)
        elif optimizer_name == "adamw":
            optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=weight_decay_rate)
        else:
            raise ValueError(f"'{optimizer_name}' is not a supported optimizer.")

        return optimizer


    def get_scheduler(scheduler_name, optimizer, args):
        """
        Returns a learning rate scheduler for the given optimizer.

        Supported schedulers:
        - StepLR
        - CosineAnnealingLR

        Parameter:
        - scheduler_name (str):
            Name of the scheduler.
        - optimizer:
            Optimizer whose learning rate will be managed.
        - args:
            Parsed command-line arguments containing scheduler configuration.

        Returns:
        - scheduler (torch.optim.lr_scheduler): Instantiated scheduler.
        """
        scheduler_name = scheduler_name.lower()

        if scheduler_name == "step":
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
        elif scheduler_name == "cosine":
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
        else:
            raise ValueError(f"'{scheduler_name}' is not a supported scheduler.")

        return scheduler


    def get_params(args, device, n_model_params, current_save_name, checkpoint_save_dir):
        """
        Collects the full run configuration into a flat dict for experiment logging.
        """
        return {
            # General
            "mode": args.mode,
            "device": str(device),

            # Training
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "loss_function": args.loss,
            "optimizer": args.optimizer,
            "weight_decay": args.weight_decay,
            "weight_decay_rate": args.weight_decay_rate,
            "gradient_clipping": args.gradient_clipping,
            "gradient_clipping_threshold": args.gradient_clipping_threshold,
            "scheduler": args.scheduler,
            "use_warm_up": args.use_warm_up,
            "warm_up_start_lr": args.warm_up_start_lr,
            "warm_up_step_duration": args.warm_up_step_duration,
            "use_amp": args.activate_amp,
            "amp_scaler": args.amp_scaler,
            "save_only_best_model": args.save_only_best_model,

            # Loss
            "wc_loss_silog_lambda": args.wc_loss_silog_lambda,
            "wc_loss_weight_silog": args.wc_loss_weight_silog,
            "wc_loss_weight_grad": args.wc_loss_weight_grad,
            "wc_loss_weight_ssim": args.wc_loss_weight_ssim,
            "wc_loss_weight_edge_aware": args.wc_loss_weight_edge_aware,
            "wc_loss_weight_l1": args.wc_loss_weight_l1,
            "wc_loss_weight_var": args.wc_loss_weight_var,
            "wc_loss_weight_range": args.wc_loss_weight_range,
            "wc_loss_weight_blur": args.wc_loss_weight_blur,

            # Model
            "model": args.model,
            "n_model_params": n_model_params,
            "resfcn_in_channels": args.resfcn_in_channels,
            "resfcn_hidden_channels": args.resfcn_hidden_channels,
            "resfcn_out_channels": args.resfcn_out_channels,
            "resfcn_num_blocks": args.resfcn_num_blocks,

            "pix2pix_in_channels": args.pix2pix_in_channels,
            "pix2pix_hidden_channels": args.pix2pix_hidden_channels,
            "pix2pix_out_channels": args.pix2pix_out_channels,
            "pix2pix_second_loss_lambda": args.pix2pix_second_loss_lambda,

            "physicsformer_in_channels": args.physicsformer_in_channels,
            "physicsformer_out_channels": args.physicsformer_out_channels,
            "physicsformer_img_size": args.physicsformer_img_size,
            "physicsformer_patch_size": args.physicsformer_patch_size,
            "physicsformer_embedded_dim": args.physicsformer_embedded_dim,
            "physicsformer_num_blocks": args.physicsformer_num_blocks,
            "physicsformer_heads": args.physicsformer_heads,
            "physicsformer_mlp_dim": args.physicsformer_mlp_dim,
            "physicsformer_dropout": args.physicsformer_dropout,

            # Data
            "data_variation": args.data_variation,
            "input_type": args.input_type,
            "output_type": args.output_type,
            "fake_rgb_output": args.fake_rgb_output,
            "make_14_dividable_size": args.make_14_dividable_size,

            # Experiment tracking
            "experiment_name": args.experiment_name,
            "run_name": current_save_name,
            "tensorboard_path": args.tensorboard_path,
            "save_path": args.save_path,
            "checkpoint_save_dir": checkpoint_save_dir,
            "cmap": args.cmap,

            # >> Residual Model <<
            "base_model": args.base_model,
            "complex_model": args.complex_model,
            "combine_mode": args.combine_mode,

            # ---- Loss (2nd branch)
            "loss_2": args.loss_2,
            "wc_loss_silog_lambda_2": args.wc_loss_silog_lambda_2,
            "wc_loss_weight_silog_2": args.wc_loss_weight_silog_2,
            "wc_loss_weight_grad_2": args.wc_loss_weight_grad_2,
            "wc_loss_weight_ssim_2": args.wc_loss_weight_ssim_2,
            "wc_loss_weight_edge_aware_2": args.wc_loss_weight_edge_aware_2,
            "wc_loss_weight_l1_2": args.wc_loss_weight_l1_2,
            "wc_loss_weight_var_2": args.wc_loss_weight_var_2,
            "wc_loss_weight_range_2": args.wc_loss_weight_range_2,
            "wc_loss_weight_blur_2": args.wc_loss_weight_blur_2,

            # ---- ResFCN Model 2
            "resfcn_2_in_channels": args.resfcn_2_in_channels,
            "resfcn_2_hidden_channels": args.resfcn_2_hidden_channels,
            "resfcn_2_out_channels": args.resfcn_2_out_channels,
            "resfcn_2_num_blocks": args.resfcn_2_num_blocks,

            # ---- Pix2Pix Model 2
            "pix2pix_2_in_channels": args.pix2pix_2_in_channels,
            "pix2pix_2_hidden_channels": args.pix2pix_2_hidden_channels,
            "pix2pix_2_out_channels": args.pix2pix_2_out_channels,
            "pix2pix_2_second_loss_lambda": args.pix2pix_2_second_loss_lambda,

            # ---- PhysicsFormer Model 2
            "physicsformer_in_channels_2": args.physicsformer_in_channels_2,
            "physicsformer_out_channels_2": args.physicsformer_out_channels_2,
            "physicsformer_img_size_2": args.physicsformer_img_size_2,
            "physicsformer_patch_size_2": args.physicsformer_patch_size_2,
            "physicsformer_embedded_dim_2": args.physicsformer_embedded_dim_2,
            "physicsformer_num_blocks_2": args.physicsformer_num_blocks_2,
            "physicsformer_heads_2": args.physicsformer_heads_2,
            "physicsformer_mlp_dim_2": args.physicsformer_mlp_dim_2,
            "physicsformer_dropout_2": args.physicsformer_dropout_2,
        }


    def backward_model(model, x, y, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold=None):
        """
        Performs a backward pass, optimizer step, and mixed-precision handling.

        Parameter:
        - model (nn.Module):
            Model to train.
        - x (torch.Tensor):
            Input tensor.
        - y (torch.Tensor):
            Target tensor.
        - optimizer:
            Optimizer or tuple of optimizers (for Pix2Pix/ResidualDesignModel).
        - criterion:
            Loss function.
        - device (torch.device):
            Device to perform computation on.
        - epoch (int):
            Current training epoch.
        - amp_scaler (GradScaler):
            Gradient scaler for mixed-precision training.
        - gradient_clipping_threshold (float, optional):
            Maximum allowed gradient norm for clipping.

        Returns:
        - loss (torch.Tensor): Computed loss value for the batch.
        """
        if isinstance(model, Pix2Pix):
            # step the discriminator only on even epochs (or when no epoch is given)
            if epoch is None or epoch % 2 == 0:
                model.discriminator_step(x, y, optimizer[1], amp_scaler, device, gradient_clipping_threshold)
            loss, _, _ = model.generator_step(x, y, optimizer[0], amp_scaler, device, gradient_clipping_threshold)
        elif amp_scaler:
            # reset gradients
            optimizer.zero_grad()

            with autocast(device_type=device.type):
                y_predict = model(x)
                loss = criterion(y_predict, y)
            amp_scaler.scale(loss).backward()
            if gradient_clipping_threshold:
                # Unscale first!
                amp_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping_threshold)
            amp_scaler.step(optimizer)
            amp_scaler.update()
        else:
            # reset gradients
            optimizer.zero_grad()

            y_predict = model(x)
            loss = criterion(y_predict, y)
            loss.backward()
            if gradient_clipping_threshold:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping_threshold)
            optimizer.step()
        return loss


    def train_one_epoch(model, loader, optimizer, criterion, device, epoch=None, amp_scaler=None, gradient_clipping_threshold=None):
        """
        Runs one full epoch of training and returns the average loss.

        Parameter:
        - model (nn.Module):
            Model to train.
        - loader (DataLoader):
            Data loader containing training batches.
        - optimizer:
            Optimizer or tuple of optimizers.
        - criterion:
            Loss function or tuple of losses (for multi-stage models).
        - device (torch.device):
            Device to use for training.
        - epoch (int, optional):
            Current epoch index.
        - amp_scaler (GradScaler, optional):
            Mixed-precision scaler.
        - gradient_clipping_threshold (float, optional):
            Max allowed gradient norm.

        Returns:
        - avg_loss (float): Average training loss for the epoch.
        """
        # change to train mode -> calc gradients
        model.train()

        total_loss = 0.0
        # for x, y in tqdm(loader, desc=f"Epoch {epoch:03}", leave=True, ascii=True, mininterval=5):
        for x, y in loader:
            if isinstance(model, ResidualDesignModel):
                base_input, complex_input = x
                base_target, complex_target, target_ = y

                # Baseline
                base_input = base_input.to(device)
                base_target = base_target.to(device)
                base_loss = backward_model(model=model.base_model, x=base_input, y=base_target, optimizer=optimizer[0], criterion=criterion[0], device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
                # del base_input
                # torch.cuda.empty_cache()

                # Complex
                complex_input = complex_input.to(device)
                complex_target = complex_target.to(device)
                complex_loss = backward_model(model=model.complex_model, x=complex_input, y=complex_target, optimizer=optimizer[1], criterion=criterion[1], device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
                # del complex_input
                # torch.cuda.empty_cache()

                # Fusion
                target_ = target_.to(device)
                combine_loss = model.combine_net.backward(base_target, complex_target, target_)

                model.backward(base_target, complex_target, target_)

                model.last_base_loss = base_loss
                model.last_complex_loss = complex_loss
                model.last_combined_loss = combine_loss
                loss = base_loss + complex_loss + combine_loss
            else:
                x, y = x.to(device), y.to(device)
                loss = backward_model(model=model, x=x, y=y, optimizer=optimizer, criterion=criterion, device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
            total_loss += loss.item()
        return total_loss / len(loader)


    @torch.no_grad()
    def evaluate(model, loader, criterion, device, writer=None, epoch=None, save_path=None, cmap="gray", use_tqdm=True):
        """
        Evaluates the model on the validation set and logs results to TensorBoard or MLflow.

        Parameter:
        - model (nn.Module):
            Model to evaluate.
        - loader (DataLoader):
            Validation data loader.
        - criterion:
            Loss function used for evaluation.
        - device (torch.device):
            Device to perform inference on.
        - writer (SummaryWriter, optional):
            TensorBoard writer for visualization.
        - epoch (int, optional):
            Current epoch index for logging.
        - save_path (str, optional):
            Directory to save sample images.
        - cmap (str):
            Colormap for saved images.
        - use_tqdm (bool):
            Whether to use a tqdm progress bar.

        Returns:
        - avg_loss (float): Average validation loss.
        """
        model.eval()

        total_loss = 0.0
        is_first_round = True

        if use_tqdm:
            validation_iter = tqdm(loader, desc="Validation", ascii=True, mininterval=3)
        else:
            validation_iter = loader

        for x, y in validation_iter:
            x, y = x.to(device), y.to(device)
            y_predict = model(x)
            total_loss += criterion(y_predict, y).item()

            if is_first_round:
                if writer:
                    # Convert to grid
                    img_grid_input = torchvision.utils.make_grid(x[:4].cpu(), normalize=True, scale_each=True)
                    max_val = max(y_predict.max().item(), 1e-8)
                    if y_predict.ndim == 3:  # e.g., [B, H, W]
                        img_grid_pred = torchvision.utils.make_grid(y_predict[:4].unsqueeze(1).float().cpu() / max_val)
                    else:  # [B, 1, H, W]
                        img_grid_pred = torchvision.utils.make_grid(y_predict[:4].float().cpu() / max_val)

                    max_val = max(y.max().item(), 1e-8)
                    if y.ndim == 3:
                        img_grid_gt = torchvision.utils.make_grid(y[:4].unsqueeze(1).float().cpu() / max_val)
                    else:
                        img_grid_gt = torchvision.utils.make_grid(y[:4].float().cpu() / max_val)

                    # Log to TensorBoard
                    writer.add_image("Input", img_grid_input, epoch)
                    writer.add_image("Prediction", img_grid_pred, epoch)
                    writer.add_image("GroundTruth", img_grid_gt, epoch)

                if save_path:
                    # os.makedirs(save_path, exist_ok=True)
                    prediction_path = os.path.join(save_path, f"{epoch}_prediction.png")
                    plt.imsave(prediction_path,
                               y_predict[0].detach().cpu().numpy().squeeze(), cmap=cmap)
                    mlflow.log_artifact(prediction_path, artifact_path="images")

                    input_path = os.path.join(save_path, f"{epoch}_input.png")
                    plt.imsave(input_path,
                               x[0][0].detach().cpu().numpy().squeeze(), cmap=cmap)
                    mlflow.log_artifact(input_path, artifact_path="images")

                    ground_truth_path = os.path.join(save_path, f"{epoch}_ground_truth.png")
                    plt.imsave(ground_truth_path,
                               y[0].detach().cpu().numpy().squeeze(), cmap=cmap)
                    mlflow.log_artifact(ground_truth_path, artifact_path="images")

                    # alternative direct save:
                    # import numpy as np
                    # import io
                    # from PIL import Image

                    # img = (y[0].detach().cpu().squeeze().numpy() * 255).astype(np.uint8)
                    # img_pil = Image.fromarray(img)

                    # buf = io.BytesIO()
                    # img_pil.save(buf, format='PNG')
                    # buf.seek(0)

                    # mlflow.log_image(image=img_pil, artifact_file=f"images/{epoch}_ground_truth.png")

                is_first_round = False

        return total_loss / len(loader)


    # ---------------------------
    # > Train Main <
    # ---------------------------
    def train(args=None):
        """
        Main training loop for image-to-image tasks.

        Workflow:
        1. Initializes the training and validation datasets based on model type.
        2. Constructs the model and its loss functions.
        3. Configures optimizers, learning rate schedulers, and optional warm-up phases.
        4. Enables mixed precision (AMP) if selected.
        5. Sets up MLflow experiment tracking and TensorBoard visualization.
        6. Executes the epoch loop:
            - Trains the model for one epoch (`train_one_epoch()`).
            - Optionally evaluates on the validation set.
            - Logs metrics and learning rates.
            - Updates the scheduler.
            - Saves checkpoints (best or periodic).
        7. Logs the trained model and experiment results to MLflow upon completion.

        Parameters:
        - args : argparse.Namespace, optional
            Parsed command-line arguments containing all training configurations.
            If None, the function will automatically call `parse_args()` to obtain them.

        Returns:
        - None: The function performs training and logging in place without returning values.

        Notes:
        - Automatically handles model-specific configurations (e.g., Pix2Pix discriminator, ResidualDesignModel branches).
        - Uses `prime.get_time()` to generate time-stamped run names.
        - Supports gradient clipping and various learning rate schedulers.

        Logging:
        - **MLflow**: Stores metrics, hyperparameters, checkpoints, and the final model.
        - **TensorBoard**: Logs training/validation losses, learning rates, and sub-loss components.
        """
        print("\n---> Welcome to Image-to-Image Training <---")

        print("\nChecking your Hardware:")
        print(prime.get_hardware())

        # Parse arguments
        if args is None:
            args = parse_args()

        CURRENT_SAVE_NAME = prime.get_time(pattern="YEAR-MONTH-DAY_HOUR_MINUTE_", time_zone="Europe/Berlin") + args.run_name

        device = torch.device(args.device if torch.cuda.is_available() else "cpu")

        # Dataset loading
        train_dataset, val_dataset, train_loader, val_loader = get_data(args)

        # Loss
        criterion = get_loss(loss_name=args.loss, args=args)

        if args.model.lower() == "residual_design_model":
            criterion = [
                criterion,
                get_loss(loss_name=args.loss_2 + "_2", args=args)
            ]

        # Model Loading
        model = get_model(args, device, criterion=criterion)

        # get parameter amount
        n_model_params = sum(cur_model_param.numel() for cur_model_param in model.parameters())

        INPUT_CHANNELS = model.get_input_channels()

        # Optimizer
        if args.model.lower() == "pix2pix":
            optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.generator, lr=args.lr, args=args),
                         get_optimizer(optimizer_name=args.optimizer_2, model=model.discriminator, lr=args.lr, args=args)]
        elif args.model.lower() == "residual_design_model":
            optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model, lr=args.lr, args=args),
                         get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model, lr=args.lr, args=args)]

            if args.base_model.lower() == "pix2pix":
                optimizer[0] = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model.generator, lr=args.lr, args=args),
                                get_optimizer(optimizer_name=args.optimizer, model=model.base_model.discriminator, lr=args.lr, args=args)]
            if args.complex_model.lower() == "pix2pix":
                optimizer[1] = [get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.generator, lr=args.lr, args=args),
                                get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.discriminator, lr=args.lr, args=args)]
        else:
            optimizer = get_optimizer(optimizer_name=args.optimizer, model=model, lr=args.lr, args=args)

        # Scheduler
        if args.model.lower() == "residual_design_model":
            if args.base_model.lower() == "pix2pix":
                scheduler_1 = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][0], args=args),
                               get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][1], args=args)]
            else:
                scheduler_1 = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args)

            if args.complex_model.lower() == "pix2pix":
                scheduler_2 = [get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][0], args=args),
                               get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][1], args=args)]
            else:
                scheduler_2 = get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)

            scheduler = [scheduler_1, scheduler_2]
        elif args.model.lower() == "pix2pix":
            scheduler = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args),
                         get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)]
        else:
            scheduler = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer, args=args)

        # Warm-Up Scheduler
        if args.use_warm_up:
            if isinstance(scheduler, (tuple, list)):
                new_scheduler = []
                for cur_scheduler in scheduler:
                    new_scheduler += [WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr, optimizer=cur_scheduler.optimizer, scheduler=cur_scheduler, step_duration=args.warm_up_step_duration)]

                scheduler = new_scheduler
            else:
                if scheduler is not None:
                    cur_optimizer = scheduler.optimizer
                else:
                    cur_optimizer = None
                scheduler = WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr, optimizer=cur_optimizer, scheduler=scheduler, step_duration=args.warm_up_step_duration)

        # AMP Scaler
        if not args.activate_amp:
            amp_scaler = None
        elif args.amp_scaler.lower() == "none":
            amp_scaler = DummyScaler()
        elif args.amp_scaler.lower() == "grad":
            amp_scaler = GradScaler()
        else:
            raise ValueError(f"'{args.amp_scaler}' is not a supported scaler.")

        # setup checkpoint saving (start from a clean directory)
        checkpoint_save_dir = os.path.join(args.checkpoint_save_dir, args.experiment_name, CURRENT_SAVE_NAME)
        shutil.rmtree(checkpoint_save_dir, ignore_errors=True)
        os.makedirs(checkpoint_save_dir, exist_ok=True)

        # setup intermediate image saving (start from a clean directory)
        save_path = os.path.join(args.save_path, args.experiment_name, CURRENT_SAVE_NAME)
        shutil.rmtree(save_path, ignore_errors=True)
        os.makedirs(save_path, exist_ok=True)

        # setup gradient clipping
        if args.gradient_clipping:
            gradient_clipping_threshold = args.gradient_clipping_threshold
        else:
            gradient_clipping_threshold = None

        mlflow.set_experiment(args.experiment_name)
        # same as:
        # mlflow.create_experiment(args.experiment_name)
        # mlflow.get_experiment_by_name(experiment_name)

        # Start MLflow run
        with mlflow.start_run(run_name=CURRENT_SAVE_NAME):

            # TensorBoard writer (start from a clean directory)
            tensorboard_path = os.path.join(args.tensorboard_path, args.experiment_name, CURRENT_SAVE_NAME)
            shutil.rmtree(tensorboard_path, ignore_errors=True)
            os.makedirs(tensorboard_path, exist_ok=True)
            writer = SummaryWriter(log_dir=tensorboard_path)

            # Log hyperparameters
            params = get_params(args=args,
                                device=device,
                                n_model_params=n_model_params,
                                current_save_name=CURRENT_SAVE_NAME,
                                checkpoint_save_dir=checkpoint_save_dir)

            # mlflow.log_params(vars(args))
            mlflow.log_params(params)

            params_text = "\n".join([f"{k}: {v}" for k, v in params.items()])
            writer.add_text("hyperparameters", params_text, 0)

            print(f"Train dataset size: {len(train_dataset)} | Validation dataset size: {len(val_dataset)}")
            mlflow.log_metrics({
                "train_dataset_size": len(train_dataset),
                "val_dataset_size": len(val_dataset)
            })

            # log model architecture
            # -> create dummy input data for that
            if isinstance(INPUT_CHANNELS, (list, tuple)):
                dummy_inference_data = [torch.ones(size=(1, channels, 256, 256)).to(device) for channels in INPUT_CHANNELS]
            else:
                dummy_inference_data = torch.ones(size=(1, INPUT_CHANNELS, 256, 256)).to(device)
            writer.add_graph(model, dummy_inference_data)

            # Run Training
            last_best_loss = float("inf")
            try:
                epoch_iter = tqdm(range(1, args.epochs + 1), desc="Epochs", ascii=True)
                for epoch in epoch_iter:
                    start_time = time.time()
                    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold)
                    duration = time.time() - start_time
                    if epoch % args.validation_interval == 0:
                        val_loss = evaluate(model, val_loader, criterion, device, writer=writer, epoch=epoch, save_path=save_path, cmap=args.cmap, use_tqdm=False)
                    else:
                        val_loss = float("inf")

                    val_str = f"{val_loss:.4f}" if epoch % args.validation_interval == 0 else "N/A"
                    epoch_iter.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=val_str, time_needed=f"{duration:.2f}s")
                    # tqdm.write(f" -> Train Loss: {train_loss:.4f} | Val Loss: {val_str} | Time: {duration:.2f}")

                    # Hint: TensorBoard and MLflow do not like spaces in tags!

                    # Log to TensorBoard
                    writer.add_scalar("Time/epoch_duration", duration, epoch)
                    writer.add_scalar("Loss/train", train_loss, epoch)
                    writer.add_scalar("Loss/val", val_loss, epoch)
                    if isinstance(scheduler, list):
                        if isinstance(scheduler[0], list):
                            writer.add_scalar("LR/generator", scheduler[0][0].get_last_lr()[0], epoch)
                            writer.add_scalar("LR/discriminator", scheduler[0][1].get_last_lr()[0], epoch)
                        else:
                            name = "generator" if args.model.lower() == "pix2pix" else "base_model"
                            writer.add_scalar(f"LR/{name}", scheduler[0].get_last_lr()[0], epoch)

                        if isinstance(scheduler[1], list):
                            writer.add_scalar("LR/generator", scheduler[1][0].get_last_lr()[0], epoch)
                            writer.add_scalar("LR/discriminator", scheduler[1][1].get_last_lr()[0], epoch)
                        else:
                            name = "discriminator" if args.model.lower() == "pix2pix" else "complex_model"
                            writer.add_scalar(f"LR/{name}", scheduler[1].get_last_lr()[0], epoch)
                    else:
                        writer.add_scalar("LR", scheduler.get_last_lr()[0], epoch)

                    # Log to MLflow (flatten possibly nested scheduler lists)
                    if isinstance(scheduler, (list, tuple)):
                        metrics = {
                            "train_loss": train_loss,
                            "val_loss": val_loss,
                        }
                        idx = 0
                        for cur_scheduler in scheduler:
                            sub_schedulers = cur_scheduler if isinstance(cur_scheduler, (list, tuple)) else [cur_scheduler]
                            for sub_scheduler in sub_schedulers:
                                metrics[f"lr_{idx}"] = sub_scheduler.get_last_lr()[0]
                                idx += 1
                        mlflow.log_metrics(metrics, step=epoch)
                    else:
                        mlflow.log_metrics({
                            "train_loss": train_loss,
                            "val_loss": val_loss,
                            "lr": scheduler.get_last_lr()[0]
                        }, step=epoch)

                    # add sub losses / loss components
                    if args.model.lower() in ["pix2pix", "residual_design_model"]:
                        losses = model.get_dict()
                        for name, value in losses.items():
                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
                        mlflow.log_metrics(losses, step=epoch)

                    if isinstance(criterion, (list, tuple)):
                        if args.loss in ["weighted_combined"]:
                            losses = criterion[0].get_dict()
                            for name, value in losses.items():
                                writer.add_scalar(f"LossComponents/{name}", value, epoch)
                            mlflow.log_metrics(losses, step=epoch)

                        if args.loss_2 in ["weighted_combined"] and args.model.lower() in ["residual_design_model"]:
                            losses = criterion[1].get_dict()
                            for name, value in losses.items():
                                writer.add_scalar(f"LossComponents/{name}", value, epoch)
                            mlflow.log_metrics(losses, step=epoch)
                    else:
                        if args.loss in ["weighted_combined"]:
                            losses = criterion.get_dict()
                            for name, value in losses.items():
                                writer.add_scalar(f"LossComponents/{name}", value, epoch)
                            mlflow.log_metrics(losses, step=epoch)

                    # Step scheduler (handles single schedulers as well as (nested) lists)
                    if isinstance(scheduler, (list, tuple)):
                        for cur_scheduler in scheduler:
                            if isinstance(cur_scheduler, (list, tuple)):
                                for sub_scheduler in cur_scheduler:
                                    sub_scheduler.step()
                            else:
                                cur_scheduler.step()
                    else:
                        scheduler.step()

                    # Save Checkpoint
                    if args.save_only_best_model:
                        if val_loss < last_best_loss or (last_best_loss == float("inf") and epoch == args.epochs):
                            last_best_loss = val_loss
                            checkpoint_path = os.path.join(checkpoint_save_dir, "best_checkpoint.pth")
                            save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)

                            # Log model checkpoint path
                            mlflow.log_artifact(checkpoint_path)
                    elif epoch % args.checkpoint_interval == 0 or epoch == args.epochs:
                        checkpoint_path = os.path.join(checkpoint_save_dir, f"epoch_{epoch}.pth")
                        save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)

                        # Log model checkpoint path
                        mlflow.log_artifact(checkpoint_path)

                # Log final model
                try:
                    mlflow.pytorch.log_model(model.cpu(), name="model", input_example=dummy_inference_data.cpu().numpy())
                except Exception as e:
                    print(e)
                    mlflow.pytorch.log_model(model, name="model")
                mlflow.end_run()
            finally:
                writer.close()

        print("Training completed.")


    if __name__ == "__main__":
        train()
get_data(args)
Builds the training and validation datasets and their dataloaders. The residual_design_model trains on PhysGenResidualDataset (with a standard osm PhysGenDataset for validation); all other models use PhysGenDataset with the configured input and output types for both splits. Returns (train_dataset, val_dataset, train_loader, val_loader); only the training loader shuffles.
get_loss(loss_name, args)
Returns a loss function instance based on the provided loss name.
Supported losses:
- L1 / L1_2
- CrossEntropy / CrossEntropy_2
- WeightedCombined / WeightedCombined_2
Parameter:
- loss_name (str): Name of the loss function.
- args: Parsed command-line arguments with configured loss weights.
Returns:
- criterion (nn.Module): Instantiated loss function.
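A minimal usage sketch, assuming the surrounding module is importable; the "l1" branch reads nothing from args, so an empty namespace suffices:

    import torch
    from argparse import Namespace

    criterion = get_loss("L1", args=Namespace())  # names are lower-cased internally
    pred, target = torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64)
    print(criterion(pred, target))  # nn.L1Loss: mean absolute error over all pixels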
get_optimizer(optimizer_name, model, lr, args)
Returns an optimizer for the given model.
Supported optimizers:
- Adam
- AdamW
Parameter:
- optimizer_name (str): Name of the optimizer.
- model (nn.Module): Model whose parameters should be optimized.
- lr (float): Learning rate.
- args: Parsed command-line arguments with optimizer configuration.
Returns:
- optimizer (torch.optim.Optimizer): Instantiated optimizer.
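A usage sketch with a toy model; only the weight-decay fields of args are read, and both Adam variants are built with the GAN-friendly betas=(0.5, 0.999):

    import torch
    from torch import nn
    from argparse import Namespace

    args = Namespace(optimizer="adamw", weight_decay=True, weight_decay_rate=1e-4)
    model = nn.Conv2d(3, 8, kernel_size=3)
    optimizer = get_optimizer("AdamW", model=model, lr=2e-4, args=args)
    print(optimizer.defaults["betas"])  # (0.5, 0.999)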
get_scheduler(scheduler_name, optimizer, args)
Returns a learning rate scheduler for the given optimizer.
Supported schedulers:
- StepLR
- CosineAnnealingLR
Parameter:
- scheduler_name (str): Name of the scheduler.
- optimizer: Optimizer whose learning rate will be managed.
- args: Parsed command-line arguments containing scheduler configuration.
Returns:
- scheduler (torch.optim.lr_scheduler): Instantiated scheduler.
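A sketch for the "cosine" branch, which only reads args.epochs (used as T_max); "step" ignores args entirely and decays the learning rate by 0.9 every 10 steps:

    import torch
    from torch import nn, optim
    from argparse import Namespace

    optimizer = optim.Adam(nn.Linear(4, 4).parameters(), lr=1e-3)
    scheduler = get_scheduler("cosine", optimizer=optimizer, args=Namespace(epochs=100))
    optimizer.step()                 # step the optimizer before the scheduler
    scheduler.step()
    print(scheduler.get_last_lr())   # lr after one cosine annealing step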
get_params(args, device, n_model_params, current_save_name, checkpoint_save_dir)
Collects the full run configuration (general, training, loss, model, data, experiment-tracking, and residual-branch settings) into a flat dict, used for MLflow parameter logging and the TensorBoard hyperparameter text.
backward_model(model, x, y, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold=None)
Performs a backward pass, optimizer step, and mixed-precision handling.
Parameter:
- model (nn.Module): Model to train.
- x (torch.Tensor): Input tensor.
- y (torch.Tensor): Target tensor.
- optimizer: Optimizer or tuple of optimizers (for Pix2Pix/ResidualDesignModel).
- criterion: Loss function.
- device (torch.device): Device to perform computation on.
- epoch (int): Current training epoch.
- amp_scaler (GradScaler): Gradient scaler for mixed-precision training.
- gradient_clipping_threshold (float, optional): Maximum allowed gradient norm for clipping.
Returns:
- loss (torch.Tensor): Computed loss value for the batch.
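The AMP branch follows the standard scale, unscale, clip, step, update order; unscaling before clipping matters because the gradients are otherwise still multiplied by the loss scale. A standalone sketch of that pattern in plain PyTorch (not the project models):

    import torch
    from torch import nn, optim
    from torch.amp import autocast, GradScaler

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(8, 8).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler()  # silently disables itself if CUDA is unavailable
    x, y = torch.rand(4, 8, device=device), torch.rand(4, 8, device=device)

    optimizer.zero_grad()
    with autocast(device_type=device.type):
        loss = nn.functional.l1_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                       # unscale before clipping!
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()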
train_one_epoch(model, loader, optimizer, criterion, device, epoch=None, amp_scaler=None, gradient_clipping_threshold=None)
Runs one full epoch of training and returns the average loss.
Parameter:
- model (nn.Module): Model to train.
- loader (DataLoader): Data loader containing training batches.
- optimizer: Optimizer or tuple of optimizers.
- criterion: Loss function or tuple of losses (for multi-stage models).
- device (torch.device): Device to use for training.
- epoch (int, optional): Current epoch index.
- amp_scaler (GradScaler, optional): Mixed-precision scaler.
- gradient_clipping_threshold (float, optional): Max allowed gradient norm.
Returns:
- avg_loss (float): Average training loss for the epoch.
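A toy driver sketch, assuming a plain single-model setup (no Pix2Pix or ResidualDesignModel branching, no AMP):

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cpu")
    model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
    data = TensorDataset(torch.rand(8, 1, 32, 32), torch.rand(8, 1, 32, 32))
    loader = DataLoader(data, batch_size=4)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    avg_loss = train_one_epoch(model, loader, optimizer, nn.L1Loss(), device, epoch=1)
    print(f"average train loss: {avg_loss:.4f}")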
evaluate(model, loader, criterion, device, writer=None, epoch=None, save_path=None, cmap="gray", use_tqdm=True)
Evaluates the model on the validation set and logs results to TensorBoard or MLflow.
Parameter:
- model (nn.Module): Model to evaluate.
- loader (DataLoader): Validation data loader.
- criterion: Loss function used for evaluation.
- device (torch.device): Device to perform inference on.
- writer (SummaryWriter, optional): TensorBoard writer for visualization.
- epoch (int, optional): Current epoch index for logging.
- save_path (str, optional): Directory to save sample images.
- cmap (str): Colormap for saved images.
- use_tqdm (bool): Whether to use tqdm progress bar.
Returns:
- avg_loss (float): Average validation loss.
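Because writer and save_path default to None, evaluate() can also be called without any TensorBoard or MLflow side effects; a minimal sketch:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cpu")
    model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
    data = TensorDataset(torch.rand(4, 1, 32, 32), torch.rand(4, 1, 32, 32))
    val_loader = DataLoader(data, batch_size=2)

    val_loss = evaluate(model, val_loader, nn.L1Loss(), device, use_tqdm=False)
    print(f"average val loss: {val_loss:.4f}")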
def train(args=None):
    """
    Main training loop for image-to-image tasks.

    Workflow:
    1. Initializes the training and validation datasets based on model type.
    2. Constructs the model and its loss functions.
    3. Configures optimizers, learning rate schedulers, and optional warm-up phases.
    4. Enables mixed precision (AMP) if selected.
    5. Sets up MLflow experiment tracking and TensorBoard visualization.
    6. Executes the epoch loop:
        - Trains the model for one epoch (`train_one_epoch()`).
        - Optionally evaluates on the validation set.
        - Logs metrics and learning rates.
        - Updates the scheduler.
        - Saves checkpoints (best or periodic).
    7. Logs the trained model and experiment results to MLflow upon completion.

    Parameters:
    - args : argparse.Namespace, optional
        Parsed command-line arguments containing all training configurations.
        If None, the function will automatically call `parse_args()` to obtain them.

    Returns:
    - None: The function performs training and logging in-place without returning values.

    Notes:
    - Automatically handles model-specific configurations (e.g., Pix2Pix discriminator, ResidualDesignModel branches).
    - Uses `prime.get_time()` to generate time-stamped run names.
    - Supports gradient clipping and various learning rate schedulers.

    Logging:
    - **MLflow**: Stores metrics, hyperparameters, checkpoints, and final model.
    - **TensorBoard**: Logs training/validation losses, learning rates, and sub-loss components.
    """
    print("\n---> Welcome to Image-to-Image Training <---")

    print("\nChecking your Hardware:")
    print(prime.get_hardware())

    # Parse arguments
    if args is None:
        args = parse_args()

    CURRENT_SAVE_NAME = prime.get_time(pattern="YEAR-MONTH-DAY_HOUR_MINUTE_", time_zone="Europe/Berlin") + args.run_name

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Dataset loading
    train_dataset, val_dataset, train_loader, val_loader = get_data(args)

    # Loss
    criterion = get_loss(loss_name=args.loss, args=args)

    if args.model.lower() == "residual_design_model":
        criterion = [
            criterion,
            get_loss(loss_name=args.loss_2 + "_2", args=args)
        ]

    # Model loading
    model = get_model(args, device, criterion=criterion)

    # Count model parameters
    n_model_params = sum(cur_model_param.numel() for cur_model_param in model.parameters())

    INPUT_CHANNELS = model.get_input_channels()

    # Optimizer (GAN-style models get one optimizer per sub-network)
    if args.model.lower() == "pix2pix":
        optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.generator, lr=args.lr, args=args),
                     get_optimizer(optimizer_name=args.optimizer_2, model=model.discriminator, lr=args.lr, args=args)]
    elif args.model.lower() == "residual_design_model":
        optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model, lr=args.lr, args=args),
                     get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model, lr=args.lr, args=args)]

        if args.base_model.lower() == "pix2pix":
            optimizer[0] = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model.generator, lr=args.lr, args=args),
                            get_optimizer(optimizer_name=args.optimizer, model=model.base_model.discriminator, lr=args.lr, args=args)]
        if args.complex_model.lower() == "pix2pix":
            optimizer[1] = [get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.generator, lr=args.lr, args=args),
                            get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.discriminator, lr=args.lr, args=args)]
    else:
        optimizer = get_optimizer(optimizer_name=args.optimizer, model=model, lr=args.lr, args=args)

    # Scheduler (mirrors the optimizer layout)
    if args.model.lower() == "residual_design_model":
        if args.base_model.lower() == "pix2pix":
            scheduler_1 = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][0], args=args),
                           get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][1], args=args)]
        else:
            scheduler_1 = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args)

        if args.complex_model.lower() == "pix2pix":
            scheduler_2 = [get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][0], args=args),
                           get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][1], args=args)]
        else:
            scheduler_2 = get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)

        scheduler = [scheduler_1, scheduler_2]
    elif args.model.lower() == "pix2pix":
        scheduler = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args),
                     get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)]
    else:
        scheduler = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer, args=args)

    # Warm-Up Scheduler
    if args.use_warm_up:
        if isinstance(scheduler, (tuple, list)):
            new_scheduler = []
            for cur_scheduler in scheduler:
                # Nested [generator, discriminator] pairs are wrapped element-wise
                if isinstance(cur_scheduler, (tuple, list)):
                    new_scheduler += [[WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr,
                                                       optimizer=s.optimizer, scheduler=s,
                                                       step_duration=args.warm_up_step_duration)
                                       for s in cur_scheduler]]
                else:
                    new_scheduler += [WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr,
                                                      optimizer=cur_scheduler.optimizer, scheduler=cur_scheduler,
                                                      step_duration=args.warm_up_step_duration)]

            scheduler = new_scheduler
        else:
            cur_optimizer = scheduler.optimizer if scheduler is not None else None
            scheduler = WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr,
                                        optimizer=cur_optimizer, scheduler=scheduler,
                                        step_duration=args.warm_up_step_duration)

    # AMP Scaler
    if not args.activate_amp:
        amp_scaler = None
    elif args.amp_scaler.lower() == "none":
        amp_scaler = DummyScaler()
    elif args.amp_scaler.lower() == "grad":
        amp_scaler = GradScaler()
    else:
        raise ValueError(f"'{args.amp_scaler}' is not a supported scaler.")

    # Setup checkpoint saving (recreate the directory so every run starts clean)
    checkpoint_save_dir = os.path.join(args.checkpoint_save_dir, args.experiment_name, CURRENT_SAVE_NAME)
    os.makedirs(checkpoint_save_dir, exist_ok=True)
    shutil.rmtree(checkpoint_save_dir)
    os.makedirs(checkpoint_save_dir, exist_ok=True)

    # Setup intermediate image saving (same clean-directory pattern)
    save_path = os.path.join(args.save_path, args.experiment_name, CURRENT_SAVE_NAME)
    os.makedirs(save_path, exist_ok=True)
    shutil.rmtree(save_path)
    os.makedirs(save_path, exist_ok=True)

    # Setup gradient clipping
    if args.gradient_clipping:
        gradient_clipping_threshold = args.gradient_clipping_threshold
    else:
        gradient_clipping_threshold = None

    mlflow.set_experiment(args.experiment_name)
    # same as:
    # mlflow.create_experiment(args.experiment_name)
    # mlflow.get_experiment_by_name(args.experiment_name)

    # Start MLflow run
    with mlflow.start_run(run_name=CURRENT_SAVE_NAME):

        # TensorBoard writer
        tensorboard_path = os.path.join(args.tensorboard_path, args.experiment_name, CURRENT_SAVE_NAME)
        os.makedirs(tensorboard_path, exist_ok=True)
        shutil.rmtree(tensorboard_path)
        os.makedirs(tensorboard_path, exist_ok=True)
        writer = SummaryWriter(log_dir=tensorboard_path)

        # Log hyperparameters
        params = get_params(args=args,
                            device=device,
                            n_model_params=n_model_params,
                            current_save_name=CURRENT_SAVE_NAME,
                            checkpoint_save_dir=checkpoint_save_dir)

        # mlflow.log_params(vars(args))
        mlflow.log_params(params)

        params_text = "\n".join([f"{k}: {v}" for k, v in params.items()])
        writer.add_text("hyperparameters", params_text, 0)

        print(f"Train dataset size: {len(train_dataset)} | Validation dataset size: {len(val_dataset)}")
        mlflow.log_metrics({
            "train_dataset_size": len(train_dataset),
            "val_dataset_size": len(val_dataset)
        })

        # Log model architecture
        # -> create dummy input data for that
        if isinstance(INPUT_CHANNELS, (list, tuple)):
            dummy_inference_data = [torch.ones(size=(1, channels, 256, 256)).to(device) for channels in INPUT_CHANNELS]
        else:
            dummy_inference_data = torch.ones(size=(1, INPUT_CHANNELS, 256, 256)).to(device)
        writer.add_graph(model, dummy_inference_data)

        # Run Training
        last_best_loss = float("inf")
        try:
            epoch_iter = tqdm(range(1, args.epochs + 1), desc="Epochs", ascii=True)
            for epoch in epoch_iter:
                start_time = time.time()
                train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold)
                duration = time.time() - start_time
                if epoch % args.validation_interval == 0:
                    val_loss = evaluate(model, val_loader, criterion, device, writer=writer, epoch=epoch, save_path=save_path, cmap=args.cmap, use_tqdm=False)
                else:
                    val_loss = float("inf")

                val_str = f"{val_loss:.4f}" if epoch % args.validation_interval == 0 else "N/A"
                epoch_iter.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=val_str, time_needed=f"{duration:.2f}s")

                # Hint: TensorBoard and MLflow do not like spaces in tags!

                # Log to TensorBoard
                writer.add_scalar("Time/epoch_duration", duration, epoch)
                writer.add_scalar("Loss/train", train_loss, epoch)
                writer.add_scalar("Loss/val", val_loss, epoch)
                if isinstance(scheduler, list):
                    if isinstance(scheduler[0], list):
                        writer.add_scalar("LR/generator", scheduler[0][0].get_last_lr()[0], epoch)
                        writer.add_scalar("LR/discriminator", scheduler[0][1].get_last_lr()[0], epoch)
                    else:
                        name = "generator" if args.model.lower() == "pix2pix" else "base_model"
                        writer.add_scalar(f"LR/{name}", scheduler[0].get_last_lr()[0], epoch)

                    if isinstance(scheduler[1], list):
                        writer.add_scalar("LR/generator", scheduler[1][0].get_last_lr()[0], epoch)
                        writer.add_scalar("LR/discriminator", scheduler[1][1].get_last_lr()[0], epoch)
                    else:
                        name = "discriminator" if args.model.lower() == "pix2pix" else "complex_model"
                        writer.add_scalar(f"LR/{name}", scheduler[1].get_last_lr()[0], epoch)
                else:
                    writer.add_scalar("LR", scheduler.get_last_lr()[0], epoch)

                # Log to MLflow
                if isinstance(scheduler, (list, tuple)):
                    metrics = {
                        "train_loss": train_loss,
                        "val_loss": val_loss,
                    }
                    for idx, cur_scheduler in enumerate(scheduler):
                        # Nested [generator, discriminator] pairs: log the generator LR
                        if isinstance(cur_scheduler, (list, tuple)):
                            cur_scheduler = cur_scheduler[0]
                        metrics[f"lr_{idx}"] = cur_scheduler.get_last_lr()[0]
                    mlflow.log_metrics(metrics, step=epoch)
                else:
                    mlflow.log_metrics({
                        "train_loss": train_loss,
                        "val_loss": val_loss,
                        "lr": scheduler.get_last_lr()[0]
                    }, step=epoch)

                # Add sub-losses / loss components
                if args.model.lower() in ["pix2pix", "residual_design_model"]:
                    losses = model.get_dict()
                    for name, value in losses.items():
                        writer.add_scalar(f"LossComponents/{name}", value, epoch)
                    mlflow.log_metrics(losses, step=epoch)

                if isinstance(criterion, (list, tuple)):
                    if args.loss in ["weighted_combined"]:
                        losses = criterion[0].get_dict()
                        for name, value in losses.items():
                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
                        mlflow.log_metrics(losses, step=epoch)

                    if args.loss_2 in ["weighted_combined"] and args.model.lower() in ["residual_design_model"]:
                        losses = criterion[1].get_dict()
                        for name, value in losses.items():
                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
                        mlflow.log_metrics(losses, step=epoch)
                else:
                    if args.loss in ["weighted_combined"]:
                        losses = criterion.get_dict()
                        for name, value in losses.items():
                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
                        mlflow.log_metrics(losses, step=epoch)

                # Step scheduler(s); residual design model runs may hold
                # nested [generator, discriminator] scheduler pairs
                if isinstance(scheduler, (list, tuple)):
                    for cur_scheduler in scheduler:
                        if isinstance(cur_scheduler, (list, tuple)):
                            for nested_scheduler in cur_scheduler:
                                nested_scheduler.step()
                        else:
                            cur_scheduler.step()
                else:
                    scheduler.step()

                # Save Checkpoint
                if args.save_only_best_model:
                    if val_loss < last_best_loss or (last_best_loss == float("inf") and epoch == args.epochs):
                        last_best_loss = val_loss
                        checkpoint_path = os.path.join(checkpoint_save_dir, "best_checkpoint.pth")
                        save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)

                        # Log model checkpoint as artifact
                        mlflow.log_artifact(checkpoint_path)
                elif epoch % args.checkpoint_interval == 0 or epoch == args.epochs:
                    checkpoint_path = os.path.join(checkpoint_save_dir, f"epoch_{epoch}.pth")
                    save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)

                    # Log model checkpoint as artifact
                    mlflow.log_artifact(checkpoint_path)

            # Log final model (list-valued dummy inputs have no .cpu()/.numpy(),
            # so multi-input models fall back to logging without an input example)
            try:
                mlflow.pytorch.log_model(model.cpu(), name="model",
                                         input_example=dummy_inference_data.cpu().numpy())
            except Exception as e:
                print(e)
                mlflow.pytorch.log_model(model, name="model")
        finally:
            writer.close()

    print("Training completed.")
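The AMP block above selects between torch.amp.GradScaler and the project's DummyScaler, so the rest of the loop can call the same scaler interface either way. The real DummyScaler lives in ..amp.dummy_scaler and is not shown on this page; a minimal sketch of what such a no-op scaler could look like, assuming only the three GradScaler methods a typical AMP loop calls (scale, step, update):

class NoOpScaler:
    """Hypothetical stand-in mirroring the GradScaler calls a typical AMP
    loop makes, while disabling loss scaling entirely. The project's actual
    DummyScaler may differ in details."""

    def scale(self, loss):
        # GradScaler would multiply by its scale factor; pass through instead
        return loss

    def step(self, optimizer):
        # GradScaler skips steps on inf/NaN gradients; here we always step
        optimizer.step()

    def update(self):
        # GradScaler re-tunes its scale factor here; nothing to adjust
        pass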
train(args=None)

Main training loop for image-to-image tasks.

Workflow:
1. Initializes the training and validation datasets based on model type.
2. Constructs the model and its loss functions.
3. Configures optimizers, learning rate schedulers, and optional warm-up phases.
4. Enables mixed precision (AMP) if selected.
5. Sets up MLflow experiment tracking and TensorBoard visualization.
6. Executes the epoch loop:
   - Trains the model for one epoch (train_one_epoch()).
   - Optionally evaluates on the validation set.
   - Logs metrics and learning rates.
   - Updates the scheduler.
   - Saves checkpoints (best or periodic).
7. Logs the trained model and experiment results to MLflow upon completion.

Parameters:
- args : argparse.Namespace, optional
  Parsed command-line arguments containing all training configurations.
  If None, the function will automatically call parse_args() to obtain them.

Returns:
- None: The function performs training and logging in-place without returning values.

Notes:
- Automatically handles model-specific configurations (e.g., Pix2Pix discriminator, ResidualDesignModel branches).
- Uses prime.get_time() to generate time-stamped run names.
- Supports gradient clipping and various learning rate schedulers.

Logging:
- MLflow: Stores metrics, hyperparameters, checkpoints, and final model.
- TensorBoard: Logs training/validation losses, learning rates, and sub-loss components.
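Putting it together, the trainer can be launched with the CLI defaults or with a programmatically adjusted namespace. A minimal sketch, assuming the package is importable as image_to_image; the attribute names (run_name, epochs, validation_interval) are taken from their usage in the source above, and the values shown are illustrative:

from image_to_image.model_interactions.train import train
from image_to_image.utils.argument_parsing import parse_args

# Start from the parsed CLI defaults, then override a few fields in code
args = parse_args()
args.run_name = "resfcn_baseline"   # hypothetical run name
args.epochs = 50
args.validation_interval = 5

train(args)   # or train() to let it call parse_args() itself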