image_to_image.model_interactions.train

Module to train, validate, and evaluate image-to-image models. Supports multiple models, losses, optimizers, schedulers, mixed-precision training, checkpointing, and experiment tracking with MLflow and TensorBoard.

The train function handles full experiment orchestration including:

  • Argument parsing and device setup.
  • Dataset and dataloader initialization.
  • Model, optimizer, loss function, and scheduler setup.
  • Mixed precision (AMP) and warm-up handling.
  • MLflow and TensorBoard logging.
  • Periodic validation and checkpoint saving.

Functions:

  • get_loss
  • get_optimizer
  • get_scheduler
  • backward_model
  • train_one_epoch
  • evaluate
  • train

By Tobia Ippolito
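
A minimal usage sketch (illustrative only, not part of the module; it assumes the package is importable and that parse_args can supply a complete configuration when no namespace is passed in):

  from image_to_image.model_interactions.train import train

  train()  # args=None, so the function calls parse_args() itself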

  1"""
  2Module to train, validate, and evaluate image-to-image models. 
  3Supports multiple models, losses, optimizers, schedulers, mixed-precision training,
  4checkpointing, and experiment tracking with MLflow and TensorBoard.
  5
  6The train function handles full experiment orchestration including:
  7- Argument parsing and device setup.
  8- Dataset and dataloader initialization.
  9- Model, optimizer, loss function, and scheduler setup.
 10- Mixed precision (AMP) and warm-up handling.
 11- MLflow and TensorBoard logging.
 12- Periodic validation and checkpoint saving.
 13
 14Functions:
 15- get_loss
 16- get_optimizer
 17- get_scheduler
 18- backward_model
 19- train_one_epoch
 20- evaluate
 21- train
 22
 23By Tobia Ippolito
 24"""
# ---------------------------
#        > Imports <
# ---------------------------
import os
import shutil
import time
import copy

import matplotlib.pyplot as plt

from tqdm import tqdm

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
import torchvision

# Experiment tracking
import mlflow
import mlflow.pytorch
from torch.utils.tensorboard import SummaryWriter

import prime_printer as prime


from ..utils.argument_parsing import parse_args
from ..utils.model_io import get_model, save_checkpoint

from ..data.physgen import PhysGenDataset
from ..data.residual_physgen import PhysGenResidualDataset, to_device

from ..models.resfcn import ResFCN
from ..models.pix2pix import Pix2Pix
from ..models.residual_design_model import ResidualDesignModel
from ..models.transformer import PhysicFormer

from ..losses.weighted_combined_loss import WeightedCombinedLoss
from ..scheduler.warm_up import WarmUpScheduler
from ..amp.dummy_scaler import DummyScaler



# ---------------------------
#      > Train Helpers <
# ---------------------------
def get_data(args):
    # The residual design model trains on paired (base, complex) samples,
    # but is validated on the standard OSM -> standard mapping.
    if args.model.lower() == "residual_design_model":
        train_dataset = PhysGenResidualDataset(variation=args.data_variation, mode="train",
                                               fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                               reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)

        val_dataset = PhysGenDataset(variation=args.data_variation, mode="validation", input_type="osm", output_type="standard",
                                    fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                    reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
    else:
        train_dataset = PhysGenDataset(variation=args.data_variation, mode="train", input_type=args.input_type, output_type=args.output_type,
                                    fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                    reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)

        val_dataset = PhysGenDataset(variation=args.data_variation, mode="validation", input_type=args.input_type, output_type=args.output_type,
                                    fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
                                    reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    return train_dataset, val_dataset, train_loader, val_loader


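# Illustrative sketch (not part of the original module): get_data bundles
# dataset and loader construction; `args` is assumed to be a parsed
# argparse.Namespace from parse_args().
#
#   train_dataset, val_dataset, train_loader, val_loader = get_data(args)
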
def get_loss(loss_name, args):
    """
    Returns a loss function instance based on the provided loss name.

    Supported losses:
    - L1 / L1_2
    - CrossEntropy / CrossEntropy_2
    - WeightedCombined / WeightedCombined_2

    Parameters:
    - loss_name (str):
        Name of the loss function.
    - args:
        Parsed command-line arguments with configured loss weights.

    Returns:
    - criterion (nn.Module): Instantiated loss function.
    """
    loss_name = loss_name.lower()

    if loss_name in ("l1", "l1_2"):
        criterion = nn.L1Loss()
    elif loss_name in ("crossentropy", "crossentropy_2"):
        criterion = nn.CrossEntropyLoss()
    elif loss_name == "weighted_combined":
        criterion = WeightedCombinedLoss(
                        silog_lambda=args.wc_loss_silog_lambda,
                        weight_silog=args.wc_loss_weight_silog,
                        weight_grad=args.wc_loss_weight_grad,
                        weight_ssim=args.wc_loss_weight_ssim,
                        weight_edge_aware=args.wc_loss_weight_edge_aware,
                        weight_l1=args.wc_loss_weight_l1,
                        weight_var=args.wc_loss_weight_var,
                        weight_range=args.wc_loss_weight_range,
                        weight_blur=args.wc_loss_weight_blur
                    )
    elif loss_name == "weighted_combined_2":
        criterion = WeightedCombinedLoss(
                        silog_lambda=args.wc_loss_silog_lambda_2,
                        weight_silog=args.wc_loss_weight_silog_2,
                        weight_grad=args.wc_loss_weight_grad_2,
                        weight_ssim=args.wc_loss_weight_ssim_2,
                        weight_edge_aware=args.wc_loss_weight_edge_aware_2,
                        weight_l1=args.wc_loss_weight_l1_2,
                        weight_var=args.wc_loss_weight_var_2,
                        weight_range=args.wc_loss_weight_range_2,
                        weight_blur=args.wc_loss_weight_blur_2
                    )
    else:
        raise ValueError(f"'{loss_name}' is not a supported loss.")

    return criterion


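# Illustrative sketch (not part of the original module): names are matched
# case-insensitively, and the "_2" suffix selects the second-branch weights,
# mirroring how train() builds the criteria for the residual design model.
#
#   criterion = get_loss(args.loss, args)
#   criterion_2 = get_loss(args.loss_2 + "_2", args)
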
def get_optimizer(optimizer_name, model, lr, args):
    """
    Returns an optimizer for the given model.

    Supported optimizers:
    - Adam
    - AdamW

    Parameters:
    - optimizer_name (str):
        Name of the optimizer.
    - model (nn.Module):
        Model whose parameters should be optimized.
    - lr (float):
        Learning rate.
    - args:
        Parsed command-line arguments with optimizer configuration.

    Returns:
    - optimizer (torch.optim.Optimizer): Instantiated optimizer.
    """
    optimizer_name = optimizer_name.lower()

    weight_decay_rate = args.weight_decay_rate if args.weight_decay else 0

    if optimizer_name == "adam":
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=weight_decay_rate)
    elif optimizer_name == "adamw":
        optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=weight_decay_rate)
    else:
        raise ValueError(f"'{optimizer_name}' is not a supported optimizer.")

    return optimizer


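# Illustrative sketch (not part of the original module): weight decay is only
# applied when args.weight_decay is enabled; both Adam variants use
# betas=(0.5, 0.999).
#
#   optimizer = get_optimizer("adamw", model, lr=args.lr, args=args)
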
def get_scheduler(scheduler_name, optimizer, args):
    """
    Returns a learning rate scheduler for the given optimizer.

    Supported schedulers:
    - StepLR
    - CosineAnnealingLR

    Parameters:
    - scheduler_name (str):
        Name of the scheduler.
    - optimizer:
        Optimizer whose learning rate will be managed.
    - args:
        Parsed command-line arguments containing scheduler configuration.

    Returns:
    - scheduler (torch.optim.lr_scheduler): Instantiated scheduler.
    """
    scheduler_name = scheduler_name.lower()

    if scheduler_name == "step":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
    elif scheduler_name == "cosine":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    else:
        raise ValueError(f"'{scheduler_name}' is not a supported scheduler.")

    return scheduler


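# Illustrative sketch (not part of the original module): "step" decays the LR
# by a factor of 0.9 every 10 epochs, "cosine" anneals it over args.epochs.
#
#   scheduler = get_scheduler("cosine", optimizer, args)
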
def get_params(args, device, n_model_params, current_save_name, checkpoint_save_dir):
    return {
            # General
            "mode": args.mode,
            "device": str(device),

            # Training
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "loss_function": args.loss,
            "optimizer": args.optimizer,
            "weight_decay": args.weight_decay,
            "weight_decay_rate": args.weight_decay_rate,
            "gradient_clipping": args.gradient_clipping,
            "gradient_clipping_threshold": args.gradient_clipping_threshold,
            "scheduler": args.scheduler,
            "use_warm_up": args.use_warm_up,
            "warm_up_start_lr": args.warm_up_start_lr,
            "warm_up_step_duration": args.warm_up_step_duration,
            "use_amp": args.activate_amp,
            "amp_scaler": args.amp_scaler,
            "save_only_best_model": args.save_only_best_model,

            # Loss
            "wc_loss_silog_lambda": args.wc_loss_silog_lambda,
            "wc_loss_weight_silog": args.wc_loss_weight_silog,
            "wc_loss_weight_grad": args.wc_loss_weight_grad,
            "wc_loss_weight_ssim": args.wc_loss_weight_ssim,
            "wc_loss_weight_edge_aware": args.wc_loss_weight_edge_aware,
            "wc_loss_weight_l1": args.wc_loss_weight_l1,
            "wc_loss_weight_var": args.wc_loss_weight_var,
            "wc_loss_weight_range": args.wc_loss_weight_range,
            "wc_loss_weight_blur": args.wc_loss_weight_blur,

            # Model
            "model": args.model,
            "n_model_params": n_model_params,
            "resfcn_in_channels": args.resfcn_in_channels,
            "resfcn_hidden_channels": args.resfcn_hidden_channels,
            "resfcn_out_channels": args.resfcn_out_channels,
            "resfcn_num_blocks": args.resfcn_num_blocks,

            "pix2pix_in_channels": args.pix2pix_in_channels,
            "pix2pix_hidden_channels": args.pix2pix_hidden_channels,
            "pix2pix_out_channels": args.pix2pix_out_channels,
            "pix2pix_second_loss_lambda": args.pix2pix_second_loss_lambda,

            "physicsformer_in_channels": args.physicsformer_in_channels,
            "physicsformer_out_channels": args.physicsformer_out_channels,
            "physicsformer_img_size": args.physicsformer_img_size,
            "physicsformer_patch_size": args.physicsformer_patch_size,
            "physicsformer_embedded_dim": args.physicsformer_embedded_dim,
            "physicsformer_num_blocks": args.physicsformer_num_blocks,
            "physicsformer_heads": args.physicsformer_heads,
            "physicsformer_mlp_dim": args.physicsformer_mlp_dim,
            "physicsformer_dropout": args.physicsformer_dropout,

            # Data
            "data_variation": args.data_variation,
            "input_type": args.input_type,
            "output_type": args.output_type,
            "fake_rgb_output": args.fake_rgb_output,
            "make_14_dividable_size": args.make_14_dividable_size,

            # Experiment tracking
            "experiment_name": args.experiment_name,
            "run_name": current_save_name,
            "tensorboard_path": args.tensorboard_path,
            "save_path": args.save_path,
            "checkpoint_save_dir": checkpoint_save_dir,
            "cmap": args.cmap,

            # >> Residual Model <<
            "base_model": args.base_model,
            "complex_model": args.complex_model,
            "combine_mode": args.combine_mode,

            # ---- Loss (2nd branch)
            "loss_2": args.loss_2,
            "wc_loss_silog_lambda_2": args.wc_loss_silog_lambda_2,
            "wc_loss_weight_silog_2": args.wc_loss_weight_silog_2,
            "wc_loss_weight_grad_2": args.wc_loss_weight_grad_2,
            "wc_loss_weight_ssim_2": args.wc_loss_weight_ssim_2,
            "wc_loss_weight_edge_aware_2": args.wc_loss_weight_edge_aware_2,
            "wc_loss_weight_l1_2": args.wc_loss_weight_l1_2,
            "wc_loss_weight_var_2": args.wc_loss_weight_var_2,
            "wc_loss_weight_range_2": args.wc_loss_weight_range_2,
            "wc_loss_weight_blur_2": args.wc_loss_weight_blur_2,

            # ---- ResFCN Model 2
            "resfcn_2_in_channels": args.resfcn_2_in_channels,
            "resfcn_2_hidden_channels": args.resfcn_2_hidden_channels,
            "resfcn_2_out_channels": args.resfcn_2_out_channels,
            "resfcn_2_num_blocks": args.resfcn_2_num_blocks,

            # ---- Pix2Pix Model 2
            "pix2pix_2_in_channels": args.pix2pix_2_in_channels,
            "pix2pix_2_hidden_channels": args.pix2pix_2_hidden_channels,
            "pix2pix_2_out_channels": args.pix2pix_2_out_channels,
            "pix2pix_2_second_loss_lambda": args.pix2pix_2_second_loss_lambda,

            # ---- PhysicsFormer Model 2
            "physicsformer_in_channels_2": args.physicsformer_in_channels_2,
            "physicsformer_out_channels_2": args.physicsformer_out_channels_2,
            "physicsformer_img_size_2": args.physicsformer_img_size_2,
            "physicsformer_patch_size_2": args.physicsformer_patch_size_2,
            "physicsformer_embedded_dim_2": args.physicsformer_embedded_dim_2,
            "physicsformer_num_blocks_2": args.physicsformer_num_blocks_2,
            "physicsformer_heads_2": args.physicsformer_heads_2,
            "physicsformer_mlp_dim_2": args.physicsformer_mlp_dim_2,
            "physicsformer_dropout_2": args.physicsformer_dropout_2,
        }

def backward_model(model, x, y, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold=None):
    """
    Performs a backward pass, optimizer step, and mixed-precision handling.

    Parameters:
    - model (nn.Module):
        Model to train.
    - x (torch.Tensor):
        Input tensor.
    - y (torch.Tensor):
        Target tensor.
    - optimizer:
        Optimizer or tuple of optimizers (for Pix2Pix/ResidualDesignModel).
    - criterion:
        Loss function.
    - device (torch.device):
        Device to perform computation on.
    - epoch (int):
        Current training epoch.
    - amp_scaler (GradScaler):
        Gradient scaler for mixed-precision training.
    - gradient_clipping_threshold (float, optional):
        Maximum allowed gradient norm for clipping.

    Returns:
    - loss (torch.Tensor): Computed loss value for the batch.
    """
    if isinstance(model, Pix2Pix):
        # GAN training: the discriminator is updated every second epoch,
        # the generator every epoch.
        if epoch is None or epoch % 2 == 0:
            model.discriminator_step(x, y, optimizer[1], amp_scaler, device, gradient_clipping_threshold)
        loss, _, _ = model.generator_step(x, y, optimizer[0], amp_scaler, device, gradient_clipping_threshold)
    elif amp_scaler:
        # reset gradients
        optimizer.zero_grad()

        with autocast(device_type=device.type):
            y_predict = model(x)
            loss = criterion(y_predict, y)
        amp_scaler.scale(loss).backward()
        if gradient_clipping_threshold:
            # Unscale first!
            amp_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping_threshold)
        amp_scaler.step(optimizer)
        amp_scaler.update()
    else:
        # reset gradients
        optimizer.zero_grad()

        y_predict = model(x)
        loss = criterion(y_predict, y)
        loss.backward()
        if gradient_clipping_threshold:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping_threshold)
        optimizer.step()
    return loss


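# Illustrative sketch (not part of the original module): one optimization step
# on a single batch, without AMP or gradient clipping.
#
#   loss = backward_model(model, x, y, optimizer, criterion, device,
#                         epoch=epoch, amp_scaler=None)
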
def train_one_epoch(model, loader, optimizer, criterion, device, epoch=None, amp_scaler=None, gradient_clipping_threshold=None):
    """
    Runs one full epoch of training and returns the average loss.

    Parameters:
    - model (nn.Module):
        Model to train.
    - loader (DataLoader):
        Data loader containing training batches.
    - optimizer:
        Optimizer or tuple of optimizers.
    - criterion:
        Loss function or tuple of losses (for multi-stage models).
    - device (torch.device):
        Device to use for training.
    - epoch (int, optional):
        Current epoch index.
    - amp_scaler (GradScaler, optional):
        Mixed-precision scaler.
    - gradient_clipping_threshold (float, optional):
        Max allowed gradient norm.

    Returns:
    - avg_loss (float): Average training loss for the epoch.
    """
    # change to train mode -> calc gradients
    model.train()

    total_loss = 0.0
    # for x, y in tqdm(loader, desc=f"Epoch {epoch:03}", leave=True, ascii=True, mininterval=5):
    for x, y in loader:

        if isinstance(model, ResidualDesignModel):
            base_input, complex_input = x
            base_target, complex_target, target_ = y

            # Baseline
            base_input = base_input.to(device)
            base_target = base_target.to(device)
            base_loss = backward_model(model=model.base_model, x=base_input, y=base_target, optimizer=optimizer[0], criterion=criterion[0], device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
            # del base_input
            # torch.cuda.empty_cache()

            # Complex
            complex_input = complex_input.to(device)
            complex_target = complex_target.to(device)
            complex_loss = backward_model(model=model.complex_model, x=complex_input, y=complex_target, optimizer=optimizer[1], criterion=criterion[1], device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
            # del complex_input
            # torch.cuda.empty_cache()

            # Fusion
            target_ = target_.to(device)
            combine_loss = model.combine_net.backward(base_target, complex_target, target_)

            model.backward(base_target, complex_target, target_)

            model.last_base_loss = base_loss
            model.last_complex_loss = complex_loss
            model.last_combined_loss = combine_loss
            loss = base_loss + complex_loss + combine_loss
        else:
            x, y = x.to(device), y.to(device)
            loss = backward_model(model=model, x=x, y=y, optimizer=optimizer, criterion=criterion, device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
        total_loss += loss.item()
    return total_loss / len(loader)


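# Illustrative sketch (not part of the original module): a bare training loop,
# assuming model, loaders, optimizer, and criterion were built as in train().
#
#   for epoch in range(1, args.epochs + 1):
#       avg_loss = train_one_epoch(model, train_loader, optimizer, criterion,
#                                  device, epoch=epoch)
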
@torch.no_grad()
def evaluate(model, loader, criterion, device, writer=None, epoch=None, save_path=None, cmap="gray", use_tqdm=True):
    """
    Evaluates the model on the validation set and logs results to TensorBoard or MLflow.

    Parameters:
    - model (nn.Module):
        Model to evaluate.
    - loader (DataLoader):
        Validation data loader.
    - criterion:
        Loss function used for evaluation.
    - device (torch.device):
        Device to perform inference on.
    - writer (SummaryWriter, optional):
        TensorBoard writer for visualization.
    - epoch (int, optional):
        Current epoch index for logging.
    - save_path (str, optional):
        Directory to save sample images.
    - cmap (str):
        Colormap for saved images.
    - use_tqdm (bool):
        Whether to use a tqdm progress bar.

    Returns:
    - avg_loss (float): Average validation loss.
    """
    model.eval()

    total_loss = 0.0
    is_first_round = True

    if use_tqdm:
        validation_iter = tqdm(loader, desc="Validation", ascii=True, mininterval=3)
    else:
        validation_iter = loader

    for x, y in validation_iter:
        x, y = x.to(device), y.to(device)
        y_predict = model(x)
        total_loss += criterion(y_predict, y).item()

        if is_first_round:
            if writer:
                # Convert to grid
                img_grid_input = torchvision.utils.make_grid(x[:4].cpu(), normalize=True, scale_each=True)
                max_val = max(y_predict.max().item(), 1e-8)
                if y_predict.ndim == 3:  # e.g., [B, H, W]
                    img_grid_pred = torchvision.utils.make_grid(y_predict[:4].unsqueeze(1).float().cpu() / max_val)
                else:  # [B, 1, H, W]
                    img_grid_pred = torchvision.utils.make_grid(y_predict[:4].float().cpu() / max_val)

                max_val = max(y.max().item(), 1e-8)
                if y.ndim == 3:
                    img_grid_gt = torchvision.utils.make_grid(y[:4].unsqueeze(1).float().cpu() / max_val)
                else:
                    img_grid_gt = torchvision.utils.make_grid(y[:4].float().cpu() / max_val)

                # Log to TensorBoard
                writer.add_image("Input", img_grid_input, epoch)
                writer.add_image("Prediction", img_grid_pred, epoch)
                writer.add_image("GroundTruth", img_grid_gt, epoch)

            if save_path:
                # os.makedirs(save_path, exist_ok=True)
                prediction_path = os.path.join(save_path, f"{epoch}_prediction.png")
                plt.imsave(prediction_path,
                        y_predict[0].detach().cpu().numpy().squeeze(), cmap=cmap)
                mlflow.log_artifact(prediction_path, artifact_path="images")

                input_path = os.path.join(save_path, f"{epoch}_input.png")
                plt.imsave(input_path,
                        x[0][0].detach().cpu().numpy().squeeze(), cmap=cmap)
                mlflow.log_artifact(input_path, artifact_path="images")

                ground_truth_path = os.path.join(save_path, f"{epoch}_ground_truth.png")
                plt.imsave(ground_truth_path,
                        y[0].detach().cpu().numpy().squeeze(), cmap=cmap)
                mlflow.log_artifact(ground_truth_path, artifact_path="images")

                # alternative direct save:
                # import numpy as np
                # import io
                # from PIL import Image

                # img = (y[0].detach().cpu().squeeze().numpy() * 255).astype(np.uint8)
                # img_pil = Image.fromarray(img)

                # buf = io.BytesIO()
                # img_pil.save(buf, format='PNG')
                # buf.seek(0)

                # mlflow.log_image(image=img_pil, artifact_file=f"images/{epoch}_ground_truth.png")

        is_first_round = False

    return total_loss / len(loader)


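# Illustrative sketch (not part of the original module): a plain validation
# pass; without writer/save_path, no images are logged.
#
#   val_loss = evaluate(model, val_loader, criterion, device, use_tqdm=False)
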
# ---------------------------
#        > Train Main <
# ---------------------------
def train(args=None):
    """
    Main training loop for image-to-image tasks.

    Workflow:
    1. Initializes the training and validation datasets based on model type.
    2. Constructs the model and its loss functions.
    3. Configures optimizers, learning rate schedulers, and optional warm-up phases.
    4. Enables mixed precision (AMP) if selected.
    5. Sets up MLflow experiment tracking and TensorBoard visualization.
    6. Executes the epoch loop:
        - Trains the model for one epoch (`train_one_epoch()`).
        - Optionally evaluates on the validation set.
        - Logs metrics and learning rates.
        - Updates the scheduler.
        - Saves checkpoints (best or periodic).
    7. Logs the trained model and experiment results to MLflow upon completion.

    Parameters:
    - args : argparse.Namespace, optional
        Parsed command-line arguments containing all training configurations.
        If None, the function will automatically call `parse_args()` to obtain them.

    Returns:
    - None: The function performs training and logging in-place without returning values.

    Notes:
    - Automatically handles model-specific configurations (e.g., Pix2Pix discriminator, ResidualDesignModel branches).
    - Uses `prime.get_time()` to generate time-stamped run names.
    - Supports gradient clipping and various learning rate schedulers.

    Logging:
    - **MLflow**: Stores metrics, hyperparameters, checkpoints, and final model.
    - **TensorBoard**: Logs training/validation losses, learning rates, and sub-loss components.
    """
    print("\n---> Welcome to Image-to-Image Training <---")

    print("\nChecking your Hardware:")
    print(prime.get_hardware())

    # Parse arguments
    if args is None:
        args = parse_args()

    CURRENT_SAVE_NAME = prime.get_time(pattern="YEAR-MONTH-DAY_HOUR_MINUTE_", time_zone="Europe/Berlin") + args.run_name

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Dataset loading
    train_dataset, val_dataset, train_loader, val_loader = get_data(args)

    # Loss
    criterion = get_loss(loss_name=args.loss, args=args)

    if args.model.lower() == "residual_design_model":
        criterion = [
            criterion,
            get_loss(loss_name=args.loss_2 + "_2", args=args)
        ]

    # Model Loading
    model = get_model(args, device, criterion=criterion)

    # get parameter count
    n_model_params = sum(p.numel() for p in model.parameters())

    INPUT_CHANNELS = model.get_input_channels()

    # Optimizer
    if args.model.lower() == "pix2pix":
        optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.generator, lr=args.lr, args=args),
                     get_optimizer(optimizer_name=args.optimizer_2, model=model.discriminator, lr=args.lr, args=args)]
    elif args.model.lower() == "residual_design_model":
        optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model, lr=args.lr, args=args),
                     get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model, lr=args.lr, args=args)]

        if args.base_model.lower() == "pix2pix":
            optimizer[0] = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model.generator, lr=args.lr, args=args),
                            get_optimizer(optimizer_name=args.optimizer, model=model.base_model.discriminator, lr=args.lr, args=args)]
        if args.complex_model.lower() == "pix2pix":
            optimizer[1] = [get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.generator, lr=args.lr, args=args),
                            get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.discriminator, lr=args.lr, args=args)]
    else:
        optimizer = get_optimizer(optimizer_name=args.optimizer, model=model, lr=args.lr, args=args)

    # Scheduler
    if args.model.lower() == "residual_design_model":

        if args.base_model.lower() == "pix2pix":
            scheduler_1 = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][0], args=args),
                           get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][1], args=args)]
        else:
            scheduler_1 = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args)

        if args.complex_model.lower() == "pix2pix":
            scheduler_2 = [get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][0], args=args),
                           get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][1], args=args)]
        else:
            scheduler_2 = get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)

        scheduler = [scheduler_1, scheduler_2]
    elif args.model.lower() == "pix2pix":
        scheduler = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args),
                     get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)]
    else:
        scheduler = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer, args=args)

    # Warm-Up Scheduler
    if args.use_warm_up:
        if isinstance(scheduler, (tuple, list)):
            new_scheduler = []
            for cur_scheduler in scheduler:
                new_scheduler += [WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr, optimizer=cur_scheduler.optimizer, scheduler=cur_scheduler, step_duration=args.warm_up_step_duration)]

            scheduler = new_scheduler
        else:
            if scheduler is not None:
                cur_optimizer = scheduler.optimizer
            else:
                cur_optimizer = None
            scheduler = WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr, optimizer=cur_optimizer, scheduler=scheduler, step_duration=args.warm_up_step_duration)

    # AMP Scaler
    if not args.activate_amp:
        amp_scaler = None
    elif args.amp_scaler.lower() == "none":
        amp_scaler = DummyScaler()
    elif args.amp_scaler.lower() == "grad":
        amp_scaler = GradScaler()
    else:
        raise ValueError(f"'{args.amp_scaler}' is not a supported scaler.")

    # setup checkpoint saving (recreate the run directory from scratch)
    checkpoint_save_dir = os.path.join(args.checkpoint_save_dir, args.experiment_name, CURRENT_SAVE_NAME)
    os.makedirs(checkpoint_save_dir, exist_ok=True)
    shutil.rmtree(checkpoint_save_dir)
    os.makedirs(checkpoint_save_dir, exist_ok=True)

    # setup intermediate image saving (recreate the run directory from scratch)
    save_path = os.path.join(args.save_path, args.experiment_name, CURRENT_SAVE_NAME)
    os.makedirs(save_path, exist_ok=True)
    shutil.rmtree(save_path)
    os.makedirs(save_path, exist_ok=True)

    # setup gradient clipping
    if args.gradient_clipping:
        gradient_clipping_threshold = args.gradient_clipping_threshold
    else:
        gradient_clipping_threshold = None

    mlflow.set_experiment(args.experiment_name)
    # same as:
    #   mlflow.create_experiment(args.experiment_name)
    #   mlflow.get_experiment_by_name(args.experiment_name)

    # Start MLflow run
    with mlflow.start_run(run_name=CURRENT_SAVE_NAME):

        # TensorBoard writer (recreate the log directory from scratch)
        tensorboard_path = os.path.join(args.tensorboard_path, args.experiment_name, CURRENT_SAVE_NAME)
        os.makedirs(tensorboard_path, exist_ok=True)
        shutil.rmtree(tensorboard_path)
        os.makedirs(tensorboard_path, exist_ok=True)
        writer = SummaryWriter(log_dir=tensorboard_path)

        # Log hyperparameters
        params = get_params(args=args,
                            device=device,
                            n_model_params=n_model_params,
                            current_save_name=CURRENT_SAVE_NAME,
                            checkpoint_save_dir=checkpoint_save_dir)

        # mlflow.log_params(vars(args))
        mlflow.log_params(params)

        params_text = "\n".join([f"{k}: {v}" for k, v in params.items()])
        writer.add_text("hyperparameters", params_text, 0)

        print(f"Train dataset size: {len(train_dataset)} | Validation dataset size: {len(val_dataset)}")
        mlflow.log_metrics({
            "train_dataset_size": len(train_dataset),
            "val_dataset_size": len(val_dataset)
        })


        # log model architecture
        #    -> create dummy input data for that
        if isinstance(INPUT_CHANNELS, (list, tuple)):
            dummy_inference_data = [torch.ones(size=(1, channels, 256, 256)).to(device) for channels in INPUT_CHANNELS]
        else:
            dummy_inference_data = torch.ones(size=(1, INPUT_CHANNELS, 256, 256)).to(device)
        writer.add_graph(model, dummy_inference_data)

        # Run Training
        last_best_loss = float("inf")
        try:
            epoch_iter = tqdm(range(1, args.epochs + 1), desc="Epochs", ascii=True)
            for epoch in epoch_iter:
                start_time = time.time()
                train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold)
                duration = time.time() - start_time
                if epoch % args.validation_interval == 0:
                    val_loss = evaluate(model, val_loader, criterion, device, writer=writer, epoch=epoch, save_path=save_path, cmap=args.cmap, use_tqdm=False)
                else:
                    val_loss = float("inf")

                val_str = f"{val_loss:.4f}" if epoch % args.validation_interval == 0 else "N/A"
                epoch_iter.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=val_str, time_needed=f"{duration:.2f}s")
                # tqdm.write(f"\n\n[Epoch {epoch:02}/{args.epochs}] -> Train Loss: {train_loss:.4f} | Val Loss: {val_str} | Time: {duration:.2f}")

                # Hint: TensorBoard and MLflow do not like spaces in tags!

                # Log to TensorBoard
                writer.add_scalar("Time/epoch_duration", duration, epoch)
                writer.add_scalar("Loss/train", train_loss, epoch)
                writer.add_scalar("Loss/val", val_loss, epoch)
                if isinstance(scheduler, list):
                    if isinstance(scheduler[0], list):
                        writer.add_scalar("LR/generator", scheduler[0][0].get_last_lr()[0], epoch)
                        writer.add_scalar("LR/discriminator", scheduler[0][1].get_last_lr()[0], epoch)
                    else:
                        name = "generator" if args.model.lower() == "pix2pix" else "base_model"
                        writer.add_scalar(f"LR/{name}", scheduler[0].get_last_lr()[0], epoch)

                    if isinstance(scheduler[1], list):
                        writer.add_scalar("LR/generator", scheduler[1][0].get_last_lr()[0], epoch)
                        writer.add_scalar("LR/discriminator", scheduler[1][1].get_last_lr()[0], epoch)
                    else:
                        name = "discriminator" if args.model.lower() == "pix2pix" else "complex_model"
                        writer.add_scalar(f"LR/{name}", scheduler[1].get_last_lr()[0], epoch)
                else:
                    writer.add_scalar("LR", scheduler.get_last_lr()[0], epoch)

                # Log to MLflow
                if isinstance(scheduler, (list, tuple)):
                    metrics = {
                        "train_loss": train_loss,
                        "val_loss": val_loss,
                    }
                    for idx, cur_scheduler in enumerate(scheduler):
                        if isinstance(cur_scheduler, list):
                            # nested Pix2Pix branch: log the generator scheduler's LR
                            cur_scheduler = cur_scheduler[0]
                        metrics[f"lr_{idx}"] = cur_scheduler.get_last_lr()[0]
                    mlflow.log_metrics(metrics, step=epoch)
                else:
                    mlflow.log_metrics({
                        "train_loss": train_loss,
                        "val_loss": val_loss,
                        "lr": scheduler.get_last_lr()[0]
                    }, step=epoch)

                # add sub losses / loss components
                if args.model.lower() in ["pix2pix", "residual_design_model"]:
                    losses = model.get_dict()
                    for name, value in losses.items():
                        writer.add_scalar(f"LossComponents/{name}", value, epoch)
                    mlflow.log_metrics(losses, step=epoch)

                if isinstance(criterion, (list, tuple)):
                    if args.loss in ["weighted_combined"]:
                        losses = criterion[0].get_dict()
                        for name, value in losses.items():
                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
                        mlflow.log_metrics(losses, step=epoch)

                    if args.loss_2 in ["weighted_combined"] and args.model.lower() in ["residual_design_model"]:
                        losses = criterion[1].get_dict()
                        for name, value in losses.items():
                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
                        mlflow.log_metrics(losses, step=epoch)
                else:
                    if args.loss in ["weighted_combined"]:
                        losses = criterion.get_dict()
                        for name, value in losses.items():
                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
                        mlflow.log_metrics(losses, step=epoch)

                # Step scheduler(s), including nested lists
                # (Pix2Pix and residual design branches)
                if isinstance(scheduler, list):
                    for cur_scheduler in scheduler:
                        if isinstance(cur_scheduler, list):
                            for sub_scheduler in cur_scheduler:
                                sub_scheduler.step()
                        else:
                            cur_scheduler.step()
                else:
                    scheduler.step()

                # Save Checkpoint
                if args.save_only_best_model:
                    if val_loss < last_best_loss or (last_best_loss == float("inf") and epoch == args.epochs):
                        last_best_loss = val_loss
                        checkpoint_path = os.path.join(checkpoint_save_dir, "best_checkpoint.pth")
                        save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)

                        # Log model checkpoint path
                        mlflow.log_artifact(checkpoint_path)
                elif epoch % args.checkpoint_interval == 0 or epoch == args.epochs:
                    checkpoint_path = os.path.join(checkpoint_save_dir, f"epoch_{epoch}.pth")
                    save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)

                    # Log model checkpoint path
                    mlflow.log_artifact(checkpoint_path)

            # Log final model
            try:
                mlflow.pytorch.log_model(model.cpu(), name="model", input_example=dummy_inference_data.cpu().numpy())
            except Exception as e:
                # e.g. multi-input models, where dummy_inference_data is a list
                print(e)
                mlflow.pytorch.log_model(model, name="model")
            mlflow.end_run()
        finally:
            writer.close()

        print("Training completed.")



if __name__ == "__main__":
    train()
def get_data(args):
72def get_data(args):
73    if args.model.lower() == "residual_design_model":
74        train_dataset = PhysGenResidualDataset(variation=args.data_variation, mode="train", 
75                                               fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
76                                               reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
77        
78        val_dataset = PhysGenDataset(variation=args.data_variation, mode="validation", input_type="osm", output_type="standard", 
79                                    fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
80                                    reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
81    else:
82        train_dataset = PhysGenDataset(variation=args.data_variation, mode="train", input_type=args.input_type, output_type=args.output_type, 
83                                    fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
84                                    reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
85        
86        val_dataset = PhysGenDataset(variation=args.data_variation, mode="validation", input_type=args.input_type, output_type=args.output_type, 
87                                    fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
88                                    reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
89
90    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
91    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
92
93    return train_dataset, val_dataset, train_loader, val_loader
def get_loss(loss_name, args):
 97def get_loss(loss_name, args):
 98    """
 99    Returns a loss function instance based on the provided loss name.
100
101    Supported losses:
102    - L1 / L1_2
103    - CrossEntropy / CrossEntropy_2
104    - WeightedCombined / WeightedCombined_2
105
106    Parameter:
107    - loss_name (str):
108        Name of the loss function.
109    - args:
110        Parsed command-line arguments with configured loss weights.
111
112    Returns:
113    - criterion (nn.Module): Instantiated loss function.
114    """
115    loss_name = loss_name.lower()
116
117    if loss_name == "l1":
118        criterion = nn.L1Loss()
119    elif loss_name == "l1_2":
120        criterion = nn.L1Loss()
121    elif loss_name == "crossentropy":
122        criterion = nn.CrossEntropyLoss()
123    elif loss_name == "crossentropy_2":
124        criterion = nn.CrossEntropyLoss()
125    elif loss_name == "weighted_combined":
126        criterion = WeightedCombinedLoss( 
127                        silog_lambda=args.wc_loss_silog_lambda, 
128                        weight_silog=args.wc_loss_weight_silog, 
129                        weight_grad=args.wc_loss_weight_grad,
130                        weight_ssim=args.wc_loss_weight_ssim,
131                        weight_edge_aware=args.wc_loss_weight_edge_aware,
132                        weight_l1=args.wc_loss_weight_l1,
133                        weight_var=args.wc_loss_weight_var,
134                        weight_range=args.wc_loss_weight_range,
135                        weight_blur=args.wc_loss_weight_blur
136                    )    
137    elif loss_name == "weighted_combined_2":
138        criterion = WeightedCombinedLoss( 
139                        silog_lambda=args.wc_loss_silog_lambda_2, 
140                        weight_silog=args.wc_loss_weight_silog_2, 
141                        weight_grad=args.wc_loss_weight_grad_2,
142                        weight_ssim=args.wc_loss_weight_ssim_2,
143                        weight_edge_aware=args.wc_loss_weight_edge_aware_2,
144                        weight_l1=args.wc_loss_weight_l1_2,
145                        weight_var=args.wc_loss_weight_var_2,
146                        weight_range=args.wc_loss_weight_range_2,
147                        weight_blur=args.wc_loss_weight_blur_2
148                    )    
149    else:
150        raise ValueError(f"'{loss_name}' is not a supported loss.")
151    
152    return criterion

Returns a loss function instance based on the provided loss name.

Supported losses:

  • L1 / L1_2
  • CrossEntropy / CrossEntropy_2
  • WeightedCombined / WeightedCombined_2

Parameter:

  • loss_name (str): Name of the loss function.
  • args: Parsed command-line arguments with configured loss weights.

Returns:

  • criterion (nn.Module): Instantiated loss function.
def get_optimizer(optimizer_name, model, lr, args):
156def get_optimizer(optimizer_name, model, lr, args):
157    """
158    Returns an optimizer for the given model.
159
160    Supported optimizers:
161    - Adam
162    - AdamW
163
164    Parameter:
165    - optimizer_name (str):
166        Name of the optimizer.
167    - model (nn.Module):
168        Model whose parameters should be optimized.
169    - lr (float):
170        Learning rate.
171    - args:
172        Parsed command-line arguments with optimizer configuration.
173
174    Returns:
175    - optimizer (torch.optim.Optimizer): Instantiated optimizer.
176    """
177    optimizer_name = optimizer_name.lower()
178
179    weight_decay_rate = args.weight_decay_rate if args.weight_decay else 0
180
181    if args.optimizer.lower() == "adam":
182        optimizer = optim.Adam(model.parameters(), lr=lr,  betas=(0.5, 0.999), weight_decay=weight_decay_rate)
183    elif args.optimizer.lower() == "adamw":
184        optimizer = optim.AdamW(model.parameters(), lr=lr,  betas=(0.5, 0.999), weight_decay=weight_decay_rate)
185    else:
186        raise ValueError(f"'{optimizer_name}' is not a supported optimizer.")
187    
188    return optimizer

Returns an optimizer for the given model.

Supported optimizers:

  • Adam
  • AdamW

Parameter:

  • optimizer_name (str): Name of the optimizer.
  • model (nn.Module): Model whose parameters should be optimized.
  • lr (float): Learning rate.
  • args: Parsed command-line arguments with optimizer configuration.

Returns:

  • optimizer (torch.optim.Optimizer): Instantiated optimizer.
def get_scheduler(scheduler_name, optimizer, args):
192def get_scheduler(scheduler_name, optimizer, args):
193    """
194    Returns a learning rate scheduler for the given optimizer.
195
196    Supported schedulers:
197    - StepLR
198    - CosineAnnealingLR
199
200    Parameter:
201    - scheduler_name (str):
202        Name of the scheduler.
203    - optimizer:
204        Optimizer whose learning rate will be managed.
205    - args:
206        Parsed command-line arguments containing scheduler configuration.
207
208    Returns:
209    - scheduler (torch.optim.lr_scheduler): Instantiated scheduler.
210    """
211    scheduler_name = scheduler_name.lower()
212
213    if scheduler_name == "step":
214        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
215    elif scheduler_name == "cosine":
216        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
217    else:
218        raise ValueError(f"'{scheduler_name}' is not a supported scheduler.")
219    
220    return scheduler

Returns a learning rate scheduler for the given optimizer.

Supported schedulers:

  • StepLR
  • CosineAnnealingLR

Parameter:

  • scheduler_name (str): Name of the scheduler.
  • optimizer: Optimizer whose learning rate will be managed.
  • args: Parsed command-line arguments containing scheduler configuration.

Returns:

  • scheduler (torch.optim.lr_scheduler): Instantiated scheduler.
def get_params(args, device, n_model_params, current_save_name, checkpoint_save_dir):
224def get_params(args, device, n_model_params, current_save_name, checkpoint_save_dir):
225    return {
226            # General
227            "mode": args.mode,
228            "device": str(device),
229
230            # Training
231            "epochs": args.epochs,
232            "batch_size": args.batch_size,
233            "learning_rate": args.lr,
234            "loss_function": args.loss,
235            "optimizer": args.optimizer,
236            "weight_decay": args.weight_decay,
237            "weight_decay_rate": args.weight_decay_rate,
238            "gradient_clipping": args.gradient_clipping,
239            "gradient_clipping_threshold": args.gradient_clipping_threshold,
240            "scheduler": args.scheduler,
241            "use_warm_up": args.use_warm_up,
242            "warm_up_start_lr": args.warm_up_start_lr,
243            "warm_up_step_duration": args.warm_up_step_duration,
244            "use_amp": args.activate_amp,
245            "amp_scaler": args.amp_scaler,
246            "save_only_best_model": args.save_only_best_model,
247
248            # Loss
249            "wc_loss_silog_lambda": args.wc_loss_silog_lambda,
250            "wc_loss_weight_silog": args.wc_loss_weight_silog,
251            "wc_loss_weight_grad": args.wc_loss_weight_grad,
252            "wc_loss_weight_ssim": args.wc_loss_weight_ssim,
253            "wc_loss_weight_edge_aware": args.wc_loss_weight_edge_aware,
254            "wc_loss_weight_l1": args.wc_loss_weight_l1,
255            "wc_loss_weight_var": args.wc_loss_weight_var,
256            "wc_loss_weight_range": args.wc_loss_weight_range,
257            "wc_loss_weight_blur": args.wc_loss_weight_blur,
258
259            # Model
260            "model": args.model,
261            "n_model_params": n_model_params,
262            "resfcn_in_channels": args.resfcn_in_channels,
263            "resfcn_hidden_channels": args.resfcn_hidden_channels,
264            "resfcn_out_channels": args.resfcn_out_channels,
265            "resfcn_num_blocks": args.resfcn_num_blocks,
266
267            "pix2pix_in_channels": args.pix2pix_in_channels,
268            "pix2pix_hidden_channels": args.pix2pix_hidden_channels,
269            "pix2pix_out_channels": args.pix2pix_out_channels,
270            "pix2pix_second_loss_lambda": args.pix2pix_second_loss_lambda,
271
272            "physicsformer_in_channels": args.physicsformer_in_channels,
273            "physicsformer_out_channels": args.physicsformer_out_channels,
274            "physicsformer_img_size": args.physicsformer_img_size,
275            "physicsformer_patch_size": args.physicsformer_patch_size,
276            "physicsformer_embedded_dim": args.physicsformer_embedded_dim,
277            "physicsformer_num_blocks": args.physicsformer_num_blocks,
278            "physicsformer_heads": args.physicsformer_heads,
279            "physicsformer_mlp_dim": args.physicsformer_mlp_dim,
280            "physicsformer_dropout": args.physicsformer_dropout,
281
282            # Data
283            "data_variation": args.data_variation,
284            "input_type": args.input_type,
285            "output_type": args.output_type,
286            "fake_rgb_output": args.fake_rgb_output,
287            "make_14_dividable_size": args.make_14_dividable_size,
288
289            # Experiment tracking
290            "experiment_name": args.experiment_name,
291            "run_name": current_save_name, # CURRENT_SAVE_NAME,
292            "tensorboard_path": args.tensorboard_path,
293            "save_path": args.save_path,
294            "checkpoint_save_dir": checkpoint_save_dir,
295            "cmap": args.cmap,
296
297            # >> Residual Model <<
298            "base_model": args.base_model,
299            "complex_model": args.complex_model,
300            "combine_mode": args.combine_mode,
301
302            # ---- Loss (2nd branch)
303            "loss_2": args.loss_2,
304            "wc_loss_silog_lambda_2": args.wc_loss_silog_lambda_2,
305            "wc_loss_weight_silog_2": args.wc_loss_weight_silog_2,
306            "wc_loss_weight_grad_2": args.wc_loss_weight_grad_2,
307            "wc_loss_weight_ssim_2": args.wc_loss_weight_ssim_2,
308            "wc_loss_weight_edge_aware_2": args.wc_loss_weight_edge_aware_2,
309            "wc_loss_weight_l1_2": args.wc_loss_weight_l1_2,
310            "wc_loss_weight_var_2": args.wc_loss_weight_var_2,
311            "wc_loss_weight_range_2": args.wc_loss_weight_range_2,
312            "wc_loss_weight_blur_2": args.wc_loss_weight_blur_2,
313
314            # ---- ResFCN Model 2
315            "resfcn_2_in_channels": args.resfcn_2_in_channels,
316            "resfcn_2_hidden_channels": args.resfcn_2_hidden_channels,
317            "resfcn_2_out_channels": args.resfcn_2_out_channels,
318            "resfcn_2_num_blocks": args.resfcn_2_num_blocks,
319
320            # ---- Pix2Pix Model 2
321            "pix2pix_2_in_channels": args.pix2pix_2_in_channels,
322            "pix2pix_2_hidden_channels": args.pix2pix_2_hidden_channels,
323            "pix2pix_2_out_channels": args.pix2pix_2_out_channels,
324            "pix2pix_2_second_loss_lambda": args.pix2pix_2_second_loss_lambda,
325
326            # ---- PhysicsFormer Model 2
327            "physicsformer_in_channels_2": args.physicsformer_in_channels_2,
328            "physicsformer_out_channels_2": args.physicsformer_out_channels_2,
329            "physicsformer_img_size_2": args.physicsformer_img_size_2,
330            "physicsformer_patch_size_2": args.physicsformer_patch_size_2,
331            "physicsformer_embedded_dim_2": args.physicsformer_embedded_dim_2,
332            "physicsformer_num_blocks_2": args.physicsformer_num_blocks_2,
333            "physicsformer_heads_2": args.physicsformer_heads_2,
334            "physicsformer_mlp_dim_2": args.physicsformer_mlp_dim_2,
335            "physicsformer_dropout_2": args.physicsformer_dropout_2,
336        }
def backward_model( model, x, y, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold=None):
339def backward_model(model, x, y, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold=None):
340    """
341    Performs a backward pass, optimizer step, and mixed-precision handling.
342
343    Parameter:
344    - model (nn.Module):
345        Model to train.
346    - x (torch.Tensor):
347        Input tensor.
348    - y (torch.Tensor):
349        Target tensor.
350    - optimizer:
351        Optimizer or tuple of optimizers (for Pix2Pix/ResidualDesignModel).
352    - criterion:
353        Loss function.
354    - device (torch.device):
355        Device to perform computation on.
356    - epoch (int):
357        Current training epoch.
358    - amp_scaler (GradScaler):
359        Gradient scaler for mixed-precision training.
360    - gradient_clipping_threshold (float, optional):
361        Maximum allowed gradient norm for clipping.
362
363    Returns:
364    - loss (torch.Tensor): Computed loss value for the batch.
365    """
366    if isinstance(model, Pix2Pix):
367        if epoch is None or epoch % 2 == 0:
368            model.discriminator_step(x, y, optimizer[1], amp_scaler, device, gradient_clipping_threshold)
369        loss, _, _ = model.generator_step(x, y, optimizer[0], amp_scaler, device, gradient_clipping_threshold)
370    elif amp_scaler:
371        # reset gradients 
372        optimizer.zero_grad()
373
374        with autocast(device_type=device.type):
375            y_predict = model(x)
376            loss = criterion(y_predict, y)
377        amp_scaler.scale(loss).backward()
378        if gradient_clipping_threshold:
379            # Unscale first!
380            amp_scaler.unscale_(optimizer)
381            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping_threshold)
382        amp_scaler.step(optimizer)
383        amp_scaler.update()
384    else:
385        # reset gradients 
386        optimizer.zero_grad()
387
388        y_predict = model(x)
389        loss = criterion(y_predict, y)
390        loss.backward()
391        if gradient_clipping_threshold:
392            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping_threshold)
393        optimizer.step()
394    return loss
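
Usage sketch (illustrative, not part of the module): a stand-in model, a fake batch, and arbitrary hyperparameters drive one optimization step through backward_model, first in plain FP32 and then with a GradScaler; backward_model itself is assumed to be imported from this module.

    import torch
    from torch import nn, optim
    from torch.amp import GradScaler

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Conv2d(3, 1, kernel_size=3, padding=1).to(device)  # stand-in model
    x = torch.randn(4, 3, 64, 64, device=device)                  # fake input batch
    y = torch.randn(4, 1, 64, 64, device=device)                  # fake target batch
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.L1Loss()

    # FP32 step: amp_scaler=None takes the plain branch; gradients clipped at norm 1.0
    loss = backward_model(model, x, y, optimizer, criterion, device,
                          epoch=1, amp_scaler=None, gradient_clipping_threshold=1.0)

    # AMP step: the scaler scales the loss, unscales before clipping,
    # then steps the optimizer and updates its scale factor
    loss = backward_model(model, x, y, optimizer, criterion, device,
                          epoch=1, amp_scaler=GradScaler(), gradient_clipping_threshold=1.0)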

398def train_one_epoch(model, loader, optimizer, criterion, device, epoch=None, amp_scaler=None, gradient_clipping_threshold=None):
399    """
400    Runs one full epoch of training and returns the average loss.
401
402    Parameters:
403    - model (nn.Module):
404        Model to train.
405    - loader (DataLoader):
406        Data loader containing training batches.
407    - optimizer:
408        Optimizer or tuple of optimizers.
409    - criterion:
410        Loss function or tuple of losses (for multi-stage models).
411    - device (torch.device):
412        Device to use for training.
413    - epoch (int, optional):
414        Current epoch index.
415    - amp_scaler (GradScaler, optional):
416        Mixed-precision scaler.
417    - gradient_clipping_threshold (float, optional):
418        Max allowed gradient norm.
419
420    Returns:
421    - avg_loss (float): Average training loss for the epoch.
422    """
423    # switch to train mode (dropout / batch-norm use training behavior)
424    model.train()
425
426    total_loss = 0.0
427    # for x, y in tqdm(loader, desc=f"Epoch {epoch:03}", leave=True, ascii=True, mininterval=5):
428    for x, y in loader:
429
430        if not isinstance(model, ResidualDesignModel):
431            x, y = x.to(device), y.to(device)
432
433        if isinstance(model, ResidualDesignModel):
434            base_input, complex_input = x
435            base_target, complex_target, target_ = y
436
437            # Baseline
438            base_input = base_input.to(device)
439            base_target = base_target.to(device)
440            base_loss = backward_model(model=model.base_model, x=base_input, y=base_target, optimizer=optimizer[0], criterion=criterion[0], device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
441            # del base_input
442            # torch.cuda.empty_cache()
443
444            # Complex
445            complex_input = complex_input.to(device)
446            complex_target = complex_target.to(device)
447            complex_loss = backward_model(model=model.complex_model, x=complex_input, y=complex_target, optimizer=optimizer[1], criterion=criterion[1], device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
448            # del complex_input
449            # torch.cuda.empty_cache()
450
451            # Fusion
452            target_ = target_.to(device)
453            combine_loss = model.combine_net.backward(base_target, complex_target, target_)
454
455            model.backward(base_target, complex_target, target_)
456
457            model.last_base_loss = base_loss
458            model.last_complex_loss = complex_loss
459            model.last_combined_loss = combine_loss
460            loss = base_loss + complex_loss + combine_loss
461        else:
462            loss = backward_model(model=model, x=x, y=y, optimizer=optimizer, criterion=criterion, device=device, epoch=epoch, amp_scaler=amp_scaler, gradient_clipping_threshold=gradient_clipping_threshold)
463        total_loss += loss.item()
464    return total_loss / len(loader)
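
For ResidualDesignModel, the loader must yield ((base_input, complex_input), (base_target, complex_target, target)) batches as unpacked above; every other model takes a plain (x, y) loader. A minimal sketch with a toy loader (all names below are stand-ins, not the module's PhysGen classes):

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loader = DataLoader(
        TensorDataset(torch.randn(32, 3, 64, 64), torch.randn(32, 1, 64, 64)),
        batch_size=8, shuffle=True)
    model = nn.Conv2d(3, 1, kernel_size=3, padding=1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(1, 4):
        avg_loss = train_one_epoch(model, loader, optimizer, criterion, device,
                                   epoch=epoch)  # amp_scaler / clipping default to None
        print(f"epoch {epoch}: train loss {avg_loss:.4f}")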

468@torch.no_grad()
469def evaluate(model, loader, criterion, device, writer=None, epoch=None, save_path=None, cmap="gray", use_tqdm=True):
470    """
471    Evaluates the model on the validation set and logs results to TensorBoard or MLflow.
472
473    Parameters:
474    - model (nn.Module):
475        Model to evaluate.
476    - loader (DataLoader):
477        Validation data loader.
478    - criterion:
479        Loss function used for evaluation.
480    - device (torch.device):
481        Device to perform inference on.
482    - writer (SummaryWriter, optional):
483        TensorBoard writer for visualization.
484    - epoch (int, optional):
485        Current epoch index for logging.
486    - save_path (str, optional):
487        Directory to save sample images.
488    - cmap (str):
489        Colormap for saved images.
490    - use_tqdm (bool):
491        Whether to use tqdm progress bar.
492
493    Returns:
494    - avg_loss (float): Average validation loss.
495    """
496    model.eval()
497
498    total_loss = 0.0
499    is_first_round = True
500
501    if use_tqdm:
502        validation_iter = tqdm(loader, desc="Validation", ascii=True, mininterval=3)
503    else:
504        validation_iter = loader
505
506    for x, y in validation_iter:
507        x, y = x.to(device), y.to(device)
508        y_predict = model(x)
509        total_loss += criterion(y_predict, y).item()
510
511        if is_first_round:
512            if writer:
513                # Convert to grid
514                img_grid_input = torchvision.utils.make_grid(x[:4].cpu(), normalize=True, scale_each=True)
515                max_val = max(y_predict.max().item(), 1e-8)
516                if y_predict.ndim == 3:  # e.g., [B, H, W]
517                    img_grid_pred = torchvision.utils.make_grid(y_predict[:4].unsqueeze(1).float().cpu() / max_val)
518                else:  # [B, 1, H, W]
519                    img_grid_pred = torchvision.utils.make_grid(y_predict[:4].float().cpu() / max_val)
520
521                max_val = max(y.max().item(), 1e-8)
522                if y.ndim == 3:
523                    img_grid_gt = torchvision.utils.make_grid(y[:4].unsqueeze(1).float().cpu() / max_val)
524                else:
525                    img_grid_gt = torchvision.utils.make_grid(y[:4].float().cpu() / max_val)
526
527                # Log to TensorBoard
528                writer.add_image("Input", img_grid_input, epoch)
529                writer.add_image("Prediction", img_grid_pred, epoch)
530                writer.add_image("GroundTruth", img_grid_gt, epoch)
531
532            if save_path:
533                # os.makedirs(save_path, exist_ok=True)
534                prediction_path = os.path.join(save_path, f"{epoch}_prediction.png")
535                plt.imsave(prediction_path, 
536                        y_predict[0].detach().cpu().numpy().squeeze(), cmap=cmap)
537                mlflow.log_artifact(prediction_path, artifact_path="images")
538
539                input_path = os.path.join(save_path, f"{epoch}_input.png")
540                plt.imsave(input_path, 
541                        x[0][0].detach().cpu().numpy().squeeze(), cmap=cmap)
542                mlflow.log_artifact(input_path, artifact_path="images")
543
544                ground_truth_path = os.path.join(save_path, f"{epoch}_ground_truth.png")
545                plt.imsave(ground_truth_path, 
546                        y[0].detach().cpu().numpy().squeeze(), cmap=cmap)
547                mlflow.log_artifact(ground_truth_path, artifact_path="images")
548
549                # alternative direct save:
550                # import numpy as np
551                # import io
552                # from PIL import Image
553
554                # img = (y[0].detach().cpu().squeeze().numpy() * 255).astype(np.uint8)
555                # img_pil = Image.fromarray(img)
556
557                # buf = io.BytesIO()
558                # img_pil.save(buf, format='PNG')
559                # buf.seek(0)
560
561                # mlflow.log_image(image=img_pil, artifact_file=f"images/{epoch}_ground_truth.png")
562                
563        is_first_round = False
564
565    return total_loss / len(loader)
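
Standalone validation sketch (the toy model, loader, and log directory are assumptions): with save_path=None no MLflow artifacts are written, so no active MLflow run is required; the first batch is still rendered to TensorBoard via the writer.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.tensorboard import SummaryWriter

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    val_loader = DataLoader(
        TensorDataset(torch.randn(8, 3, 64, 64), torch.randn(8, 1, 64, 64)),
        batch_size=4)
    model = nn.Conv2d(3, 1, kernel_size=3, padding=1).to(device)
    writer = SummaryWriter(log_dir="runs/eval_demo")  # hypothetical log dir

    val_loss = evaluate(model, val_loader, nn.L1Loss(), device,
                        writer=writer, epoch=1, use_tqdm=False)
    writer.close()
    print(f"val loss: {val_loss:.4f}")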

572def train(args=None):
573    """
574    Main training loop for image-to-image tasks.
575
576    Workflow:
577    1. Initializes the training and validation datasets based on model type.
578    2. Constructs the model and its loss functions.
579    3. Configures optimizers, learning rate schedulers, and optional warm-up phases.
580    4. Enables mixed precision (AMP) if selected.
581    5. Sets up MLflow experiment tracking and TensorBoard visualization.
582    6. Executes the epoch loop:
583        - Trains the model for one epoch (`train_one_epoch()`).
584        - Optionally evaluates on the validation set.
585        - Logs metrics and learning rates.
586        - Updates the scheduler.
587        - Saves checkpoints (best or periodic).
588    7. Logs the trained model and experiment results to MLflow upon completion.
589
590    Parameters:
591    - args : argparse.Namespace, optional
592        Parsed command-line arguments containing all training configurations.
593        If None, the function will automatically call `parse_args()` to obtain them.
594    
595    Returns:
596    - None: The function performs training and logging in-place without returning values.
597
598    Notes:
599    - Automatically handles model-specific configurations (e.g., Pix2Pix discriminator, ResidualDesignModel branches).
600    - Uses `prime.get_time()` to generate time-stamped run names.
601    - Supports gradient clipping and various learning rate schedulers.
602
603    Logging:
604    - **MLflow**: Stores metrics, hyperparameters, checkpoints, and final model.
605    - **TensorBoard**: Logs training/validation losses, learning rates, and sub-loss components.
606    """
607    print("\n---> Welcome to Image-to-Image Training <---")
608
609    print("\nChecking your Hardware:")
610    print(prime.get_hardware())
611
612    # Parse arguments
613    if args is None:
614        args = parse_args()
615
616    CURRENT_SAVE_NAME = prime.get_time(pattern="YEAR-MONTH-DAY_HOUR_MINUTE_", time_zone="Europe/Berlin") + args.run_name
617
618    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
619
620    # Dataset loading
621    train_dataset, val_dataset, train_loader, val_loader = get_data(args)
622
623    # Loss
624    criterion = get_loss(loss_name=args.loss, args=args)
625    
626    if args.model.lower() == "residual_design_model":
627        criterion = [
628            criterion,
629            get_loss(loss_name=args.loss_2+"_2", args=args)
630        ]
631
632    # Model Loading
633    model = get_model(args, device, criterion=criterion)
634
635    # count model parameters
636    n_model_params = 0
637    for cur_model_param in model.parameters():
638        n_model_params += cur_model_param.numel()
639
640    INPUT_CHANNELS = model.get_input_channels()
641
642    # Optimizer
643    if args.model.lower() == "pix2pix":
644        optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.generator, lr=args.lr, args=args),
645                     get_optimizer(optimizer_name=args.optimizer_2, model=model.discriminator, lr=args.lr, args=args)]
646    elif args.model.lower() == "residual_design_model":
647        optimizer = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model, lr=args.lr, args=args),
648                     get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model, lr=args.lr, args=args)]
649        
650        if args.base_model.lower() == "pix2pix":
651            optimizer[0] = [get_optimizer(optimizer_name=args.optimizer, model=model.base_model.generator, lr=args.lr, args=args), 
652                            get_optimizer(optimizer_name=args.optimizer, model=model.base_model.discriminator, lr=args.lr, args=args)]
653        if args.complex_model.lower() == "pix2pix":
654            optimizer[1] = [get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.generator, lr=args.lr, args=args),
655                            get_optimizer(optimizer_name=args.optimizer_2, model=model.complex_model.discriminator, lr=args.lr, args=args)]
656    else:
657        optimizer = get_optimizer(optimizer_name=args.optimizer, model=model, lr=args.lr, args=args)
658
659    # Scheduler
660    if args.model.lower() == "residual_design_model":
661
662        if args.base_model.lower() == "pix2pix":
663            scheduler_1 = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][0], args=args),
664                           get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0][1], args=args)]
665        else:
666            scheduler_1 = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args)
667            
668        if args.complex_model.lower() == "pix2pix":
669            scheduler_2 = [get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][0], args=args),
670                           get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1][1], args=args)]
671        else:
672            scheduler_2 = get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)
673
674        scheduler = [scheduler_1, scheduler_2]
675    elif args.model.lower() == "pix2pix":
676        scheduler = [get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer[0], args=args),
677                     get_scheduler(scheduler_name=args.scheduler_2, optimizer=optimizer[1], args=args)]
678    else:
679        scheduler = get_scheduler(scheduler_name=args.scheduler, optimizer=optimizer, args=args)
680
681    # Warm-Up Scheduler
682    if args.use_warm_up:
683        if isinstance(scheduler, (tuple, list)):
684            new_scheduler = []
685            for cur_scheduler in scheduler:
686                new_scheduler += [WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr, optimizer=cur_scheduler.optimizer, scheduler=cur_scheduler, step_duration=args.warm_up_step_duration)]
687
688            scheduler = new_scheduler
689        else:
690            if scheduler is not None:
691                cur_optimizer = scheduler.optimizer
692            else:
693                cur_optimizer = None
694            scheduler = WarmUpScheduler(start_lr=args.warm_up_start_lr, end_lr=args.lr, optimizer=cur_optimizer, scheduler=scheduler, step_duration=args.warm_up_step_duration)
695
696    # AMP Scaler
697    if not args.activate_amp:
698        amp_scaler = None
699    elif args.amp_scaler.lower() == "none":
700        amp_scaler = DummyScaler()
701    elif args.amp_scaler.lower() == "grad":
702        amp_scaler = GradScaler()
703    else:
704        raise ValueError(f"'{args.amp_scaler}' is not a supported scaler.")
705
706    # setup checkpoint saving
707    checkpoint_save_dir = os.path.join(args.checkpoint_save_dir, args.experiment_name, CURRENT_SAVE_NAME)
708    # clear leftovers from a previous run with the same name, then recreate
709    shutil.rmtree(checkpoint_save_dir, ignore_errors=True)
710    os.makedirs(checkpoint_save_dir, exist_ok=True)
711
712    # setup intermediate image saving
713    save_path = os.path.join(args.save_path, args.experiment_name, CURRENT_SAVE_NAME)
714    # clear leftovers from a previous run with the same name, then recreate
715    shutil.rmtree(save_path, ignore_errors=True)
716    os.makedirs(save_path, exist_ok=True)
717
718    # setup gradient clipping
719    if args.gradient_clipping:
720        gradient_clipping_threshold = args.gradient_clipping_threshold
721    else:
722        gradient_clipping_threshold = None
723
724    mlflow.set_experiment(args.experiment_name)
725    # roughly equivalent to creating the experiment if missing and selecting it:
726    #     mlflow.create_experiment(args.experiment_name)
727    #     mlflow.get_experiment_by_name(args.experiment_name)
728
729    # Start MLflow run
730    with mlflow.start_run(run_name=CURRENT_SAVE_NAME):
731
732        # TensorBoard writer
733        tensorboard_path = os.path.join(args.tensorboard_path, args.experiment_name, CURRENT_SAVE_NAME)
734        # clear leftovers from a previous run with the same name, then recreate
735        shutil.rmtree(tensorboard_path, ignore_errors=True)
736        os.makedirs(tensorboard_path, exist_ok=True)
737        writer = SummaryWriter(log_dir=tensorboard_path)
738
739        # Log hyperparameters
740        params = get_params(args=args,
741                             device=device, 
742                             n_model_params=n_model_params, 
743                             current_save_name=CURRENT_SAVE_NAME, 
744                             checkpoint_save_dir=checkpoint_save_dir)
745        
746        # mlflow.log_params(vars(args))
747        mlflow.log_params(params)
748
749        params_text = "\n".join([f"{k}: {v}" for k, v in params.items()])
750        writer.add_text("hyperparameters", params_text, 0)
751
752        print(f"Train dataset size: {len(train_dataset)} | Validation dataset size: {len(val_dataset)}")
753        mlflow.log_metrics({
754            "train_dataset_size": len(train_dataset),
755            "val_dataset_size": len(val_dataset)
756        })
757
758
759        # log model architecture
760        #    -> build a dummy input batch for the graph trace
761        if isinstance(INPUT_CHANNELS, (list, tuple)):
762            dummy_inference_data = [torch.ones(size=(1, channels, 256, 256)).to(device) for channels in INPUT_CHANNELS]
763        else:
764            dummy_inference_data = torch.ones(size=(1, INPUT_CHANNELS, 256, 256)).to(device)
765        writer.add_graph(model, dummy_inference_data)
766
767        # Run Training
768        last_best_loss = float("inf")
769        try:
770            epoch_iter = tqdm(range(1, args.epochs + 1), desc="Epochs", ascii=True)
771            for epoch in epoch_iter:
772                start_time = time.time()
773                train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, amp_scaler, gradient_clipping_threshold)
774                duration = time.time() - start_time
775                if epoch % args.validation_interval == 0:
776                    val_loss = evaluate(model, val_loader, criterion[0] if isinstance(criterion, (list, tuple)) else criterion, device, writer=writer, epoch=epoch, save_path=save_path, cmap=args.cmap, use_tqdm=False)
777                else:
778                    val_loss = float("inf")
779
780                val_str = f"{val_loss:.4f}" if epoch % args.validation_interval == 0 else "N/A"
781                epoch_iter.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=val_str, time_needed=f"{duration:.2f}s")
782                # tqdm.write(f" -> Train Loss: {train_loss:.4f} | Val Loss: {val_str} | Time: {duration:.2f}")
783                # \n\n[Epoch {epoch:02}/{args.epochs}]
784                
785                # Hint: TensorBoard and MLflow do not like spaces in tags!
786
787                # Log to TensorBoard
788                writer.add_scalar("Time/epoch_duration", duration, epoch)
789                writer.add_scalar("Loss/train", train_loss, epoch)
790                writer.add_scalar("Loss/val", val_loss, epoch)
791                if isinstance(scheduler, list):
792                    if isinstance(scheduler[0], list):
793                        writer.add_scalar("LR/base_generator", scheduler[0][0].get_last_lr()[0], epoch)
794                        writer.add_scalar("LR/base_discriminator", scheduler[0][1].get_last_lr()[0], epoch)
795                    else:
796                        name = "generator" if args.model.lower() == "pix2pix" else "base_model"
797                        writer.add_scalar(f"LR/{name}", scheduler[0].get_last_lr()[0], epoch)
798
799                    if isinstance(scheduler[1], list):
800                        writer.add_scalar("LR/complex_generator", scheduler[1][0].get_last_lr()[0], epoch)
801                        writer.add_scalar("LR/complex_discriminator", scheduler[1][1].get_last_lr()[0], epoch)
802                    else:
803                        name = "discriminator" if args.model.lower() == "pix2pix" else "complex_model"
804                        writer.add_scalar(f"LR/{name}", scheduler[1].get_last_lr()[0], epoch)
805                else:
806                    writer.add_scalar("LR", scheduler.get_last_lr()[0], epoch)
807
808                # Log to MLflow
809                if isinstance(scheduler, (list, tuple)):
810                    metrics = {
811                        "train_loss": train_loss,
812                        "val_loss": val_loss,
813                    }
814                    # one LR metric per (sub-)scheduler; nested lists hold GAN sub-schedulers
815                    for idx, cur_scheduler in enumerate(scheduler):
816                        metrics[f"lr_{idx}"] = (cur_scheduler[0] if isinstance(cur_scheduler, (list, tuple)) else cur_scheduler).get_last_lr()[0]
817                    mlflow.log_metrics(metrics, step=epoch)
818                else:
819                    mlflow.log_metrics({
820                        "train_loss": train_loss,
821                        "val_loss": val_loss,
822                        "lr": scheduler.get_last_lr()[0]
823                    }, step=epoch)
824
825                # add sub losses / loss components
826                if args.model.lower() in ["pix2pix", "residual_design_model"]:
827                    losses = model.get_dict()
828                    for name, value in losses.items():
829                        writer.add_scalar(f"LossComponents/{name}", value, epoch)
830                    mlflow.log_metrics(losses, step=epoch)
831
832                if isinstance(criterion, (list, tuple)):
833                    if args.loss in ["weighted_combined"]:
834                        losses = criterion[0].get_dict()
835                        for name, value in losses.items():
836                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
837                        mlflow.log_metrics(losses, step=epoch)
838
839                    if args.loss_2 in ["weighted_combined"] and args.model.lower() in ["residual_design_model"]:
840                        losses = criterion[1].get_dict()
841                        for name, value in losses.items():
842                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
843                        mlflow.log_metrics(losses, step=epoch)
844                else:
845                    if args.loss in ["weighted_combined"]:
846                        losses = criterion.get_dict()
847                        for name, value in losses.items():
848                            writer.add_scalar(f"LossComponents/{name}", value, epoch)
849                        mlflow.log_metrics(losses, step=epoch)
850
851                
852
853                # Step scheduler(s); scheduler may be a (nested) list for GAN / residual setups
854                for cur_scheduler in (scheduler if isinstance(scheduler, (list, tuple)) else [scheduler]):
855                    if isinstance(cur_scheduler, (list, tuple)):
856                        for s in cur_scheduler: s.step()
857                    else:
858                        cur_scheduler.step()
859
860                # Save Checkpoint
861                if args.save_only_best_model:
862                    if val_loss < last_best_loss or (last_best_loss == float("inf") and epoch == args.epochs):
863                        last_best_loss = val_loss
864                        checkpoint_path = os.path.join(checkpoint_save_dir, "best_checkpoint.pth")
865                        save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)
866
867                        # Log model checkpoint path
868                        mlflow.log_artifact(checkpoint_path)
869                elif epoch % args.checkpoint_interval == 0 or epoch == args.epochs:
870                    checkpoint_path = os.path.join(checkpoint_save_dir, f"epoch_{epoch}.pth")
871                    save_checkpoint(args, model, optimizer, scheduler, epoch, checkpoint_path)
872
873                    # Log model checkpoint path
874                    mlflow.log_artifact(checkpoint_path)
875
876            # Log final model
877            try:  # .cpu()/.numpy() fail if dummy_inference_data is a list (multi-input); fall back without input_example
878                mlflow.pytorch.log_model(model.cpu(), name="model", input_example=dummy_inference_data.cpu().numpy())
879            except Exception as e:
880                print(e)
881                mlflow.pytorch.log_model(model, name="model")
882            mlflow.end_run()
883        finally:
884            writer.close()
885
886        print("Training completed.")
