image_to_image.models.pix2pix
Module to define a Pix2Pix model: a U-Net CNN generator trained against a PatchGAN discriminator, combined with a weighted pixel-wise loss.
Functions:
- unet_down_block
- unet_up_block
Classes:
- MMC
- UNetGenerator
- Discriminator
- Pix2Pix
By Tobia Ippolito
1""" 2Module to define a Pix2Pix Model. 3A UNet CNN Generator combined with a small generative loss. 4 5Functions: 6- unet_down_block 7- unet_up_block 8 9Classes: 10- MMC 11- UNetGenerator 12- Discriminator 13- Pix2Pix 14 15By Tobia Ippolito 16""" 17# --------------------------- 18# > Imports < 19# --------------------------- 20import torch 21import torch.nn as nn 22import torch.nn.functional as F 23from torch.amp import autocast 24 25 26 27# --------------------------- 28# > Generator < 29# --------------------------- 30class MMC(nn.Module): # MinMaxClamping 31 """ 32 Min-Max Clamping Module. 33 34 Clamps input tensor values between a specified minimum and maximum. 35 36 Parameter: 37 - min (float): 38 Minimum allowed value (default=0.0). 39 - max (float): 40 Maximum allowed value (default=1.0). 41 42 Usage: 43 - Can be used at the output layer of a generator to ensure predictions remain in a valid range. 44 """ 45 def __init__(self, min=0.0, max=1.0): 46 super().__init__() 47 self.min = min 48 self.max = max 49 50 def forward(self, x): 51 """ 52 Forward pass. 53 54 Parameter: 55 - x (torch.tensor): 56 Input tensor. 57 58 Returns: 59 - torch.tensor: Clamped tensor with values between `min` and `max`. 60 """ 61 return torch.clamp(x, self.min, self.max) 62 63def unet_down_block(in_channels=1, out_channels=1, normalize=True): 64 """ 65 Creates a U-Net downsampling block. 66 67 Parameter: 68 - in_channels (int): 69 Number of input channels. 70 - out_channels (int): 71 Number of output channels. 72 - normalize (bool): 73 Whether to apply instance normalization. 74 75 Returns: 76 - nn.Sequential: Convolutional downsampling block with LeakyReLU activation. 77 """ 78 layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)] 79 if normalize: 80 # layers += [nn.BatchNorm2d(out_channels)] 81 layers += [nn.InstanceNorm2d(out_channels, affine=True)] 82 layers += [nn.LeakyReLU(0.2, inplace=True)] 83 return nn.Sequential(*layers) 84 85def unet_up_block(in_channels=1, out_channels=1, dropout=0.0): 86 """ 87 Creates a U-Net upsampling block. 88 89 Parameter: 90 - in_channels (int): 91 Number of input channels. 92 - out_channels (int): 93 Number of output channels. 94 - dropout (float): 95 Dropout probability (default=0). 96 97 Returns: 98 - nn.Sequential: Transposed convolutional block with ReLU activation and optional dropout. 99 """ 100 layers = [ 101 nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False), 102 # nn.BatchNorm2d(out_channels), 103 nn.InstanceNorm2d(out_channels, affine=True), 104 nn.ReLU(inplace=True) 105 ] 106 107 if dropout: 108 layers += [nn.Dropout(dropout)] 109 return nn.Sequential(*layers) 110 111class UNetGenerator(nn.Module): 112 """ 113 U-Net Generator for image-to-image translation. 114 115 Architecture: 116 - 8 downsampling blocks (encoder) 117 - 8 upsampling blocks (decoder) with skip connections 118 - Sigmoid activation at output for [0,1] pixel normalization 119 120 Parameter: 121 - input_channels (int): 122 Number of input image channels. 123 - output_channels (int): 124 Number of output image channels. 125 - hidden_channels (int): 126 Base hidden channels for first layer. 
127 """ 128 def __init__(self, input_channels=1, output_channels=1, hidden_channels=64): 129 super().__init__() 130 # Encoder 131 self.down1 = unet_down_block(input_channels, hidden_channels, normalize=False) # 128 132 self.down2 = unet_down_block(hidden_channels, hidden_channels*2) # 64 133 self.down3 = unet_down_block(hidden_channels*2, hidden_channels*4) # 32 134 self.down4 = unet_down_block(hidden_channels*4, hidden_channels*8) # 16 135 self.down5 = unet_down_block(hidden_channels*8, hidden_channels*8) # 8 136 self.down6 = unet_down_block(hidden_channels*8, hidden_channels*8) # 4 137 self.down7 = unet_down_block(hidden_channels*8, hidden_channels*8) # 2 138 self.down8 = unet_down_block(hidden_channels*8, hidden_channels*8, normalize=False) # 1 139 140 # Decoder 141 self.up1 = unet_up_block(hidden_channels*8, hidden_channels*8, dropout=0.5) 142 self.up2 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5) 143 self.up3 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5) 144 self.up4 = unet_up_block(hidden_channels*16, hidden_channels*8) 145 self.up5 = unet_up_block(hidden_channels*16, hidden_channels*4) 146 self.up6 = unet_up_block(hidden_channels*8, hidden_channels*2) 147 self.up7 = unet_up_block(hidden_channels*4, hidden_channels) 148 self.up8 = nn.Sequential( 149 nn.ConvTranspose2d(hidden_channels*2, output_channels, 4, 2, 1), 150 # nn.Tanh() 151 ## MMC(min=0.0, max=1.0) 152 nn.Sigmoid() 153 ) 154 155 def forward(self, x): 156 """ 157 Forward pass of the U-Net generator. 158 159 Parameter: 160 - x (torch.tensor): 161 Input image tensor (batch_size, input_channels, H, W). 162 163 Returns: 164 - torch.tensor: Generated output tensor (batch_size, output_channels, H, W). 165 """ 166 d1 = self.down1(x) 167 d2 = self.down2(d1) 168 d3 = self.down3(d2) 169 d4 = self.down4(d3) 170 d5 = self.down5(d4) 171 d6 = self.down6(d5) 172 d7 = self.down7(d6) 173 d8 = self.down8(d7) 174 175 u1 = self.up1(d8) 176 u2 = self.up2(torch.cat([u1, d7], 1)) 177 u3 = self.up3(torch.cat([u2, d6], 1)) 178 u4 = self.up4(torch.cat([u3, d5], 1)) 179 u5 = self.up5(torch.cat([u4, d4], 1)) 180 u6 = self.up6(torch.cat([u5, d3], 1)) 181 u7 = self.up7(torch.cat([u6, d2], 1)) 182 u8 = self.up8(torch.cat([u7, d1], 1)) 183 184 return u8 185 186 187 188# --------------------------- 189# > Discriminator < 190# --------------------------- 191# PatchGAN 192class Discriminator(nn.Module): 193 """ 194 PatchGAN Discriminator for Pix2Pix GAN. 195 196 Parameter: 197 - input_channels (int): 198 Number of input channels (typically input + target channels concatenated). 199 - hidden_channels (int): 200 Base hidden channels for first layer. 201 202 Architecture: 203 - 5 convolutional blocks with LeakyReLU and batch normalization. 204 - Outputs a 2D patch map of predictions. 205 """ 206 def __init__(self, input_channels=6, hidden_channels=64): 207 """ 208 Initializes a PatchGAN discriminator. 209 210 The discriminator evaluates input-target image pairs to determine 211 if they are real or generated (fake). It progressively downsamples 212 the spatial dimensions while increasing the number of feature channels. 213 214 Parameters: 215 - input_channels (int): 216 Number of input channels, typically input + target concatenated (default=6). 217 - hidden_channels (int): 218 Number of channels in the first convolutional layer; doubled in subsequent layers (default=64). 
219 """ 220 super().__init__() 221 self.model = nn.Sequential( 222 nn.Conv2d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1), 223 nn.LeakyReLU(0.2, inplace=True), 224 225 nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=4, stride=2, padding=1, bias=False), 226 nn.BatchNorm2d(hidden_channels*2), 227 nn.LeakyReLU(0.2, inplace=True), 228 229 nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=4, stride=2, padding=1, bias=False), 230 nn.BatchNorm2d(hidden_channels*4), 231 nn.LeakyReLU(0.2, inplace=True), 232 233 nn.Conv2d(hidden_channels*4, hidden_channels*8, kernel_size=4, stride=1, padding=1, bias=False), 234 nn.BatchNorm2d(hidden_channels*8), 235 nn.LeakyReLU(0.2, inplace=True), 236 237 nn.Conv2d(hidden_channels*8, 1, kernel_size=4, stride=1, padding=1) 238 ) 239 240 def forward(self, x, y): 241 """ 242 Forward pass of the discriminator. 243 244 Parameter: 245 - x (torch.tensor): 246 Input image tensor. 247 - y (torch.tensor): 248 Target or generated image tensor. 249 250 Returns: 251 - torch.tensor: PatchGAN output tensor predicting real/fake for each patch. 252 """ 253 # concatenate input and target channels 254 return self.model(torch.cat([x, y], dim=1)) 255 256 257 258# --------------------------- 259# > Pix2Pix < 260# --------------------------- 261class Pix2Pix(nn.Module): 262 """ 263 Pix2Pix GAN for image-to-image translation. 264 265 Components: 266 - Generator: U-Net generator producing synthetic images. 267 - Discriminator: PatchGAN discriminator evaluating real vs fake images. 268 - Adversarial loss: Binary cross-entropy. 269 - Optional second loss for pixel-wise supervision. 270 271 Parameter: 272 - input_channels (int): 273 Number of input channels. 274 - output_channels (int): 275 Number of output channels. 276 - hidden_channels (int): 277 Base hidden channels for both generator and discriminator. 278 - second_loss (nn.Module): 279 Optional secondary loss (default: L1Loss). 280 - lambda_second (float): 281 Weight for secondary loss in generator optimization. 282 """ 283 def __init__(self, input_channels=1, output_channels=1, hidden_channels=64, 284 second_loss=nn.L1Loss(), lambda_second=100): 285 """ 286 Initializes the Pix2Pix GAN model. 287 288 Components: 289 - Generator: U-Net architecture for producing synthetic images. 290 - Discriminator: PatchGAN for evaluating real vs. fake images. 291 - Adversarial loss: Binary cross-entropy to train the generator to fool the discriminator. 292 - Optional secondary loss: Pixel-wise supervision (default: L1Loss). 293 294 Parameter: 295 - input_channels (int): 296 Number of channels in the input images (default=1). 297 - output_channels (int): 298 Number of channels in the output images (default=1). 299 - hidden_channels (int): 300 Base number of hidden channels in the generator and discriminator (default=64). 301 - second_loss (nn.Module): 302 Optional secondary loss for the generator (default: nn.L1Loss()). 303 - lambda_second (float): 304 Weight applied to the secondary loss in generator optimization (default=100). 
305 """ 306 super().__init__() 307 self.input_channels = input_channels 308 self.output_channels = output_channels 309 310 self.generator = UNetGenerator(input_channels=input_channels, 311 output_channels=output_channels, 312 hidden_channels=hidden_channels) 313 self.discriminator = Discriminator(input_channels=input_channels+output_channels, hidden_channels=hidden_channels) 314 315 self.adversarial_loss = nn.BCEWithLogitsLoss() 316 self.second_loss = second_loss 317 self.lambda_second = lambda_second 318 319 self.last_generator_loss = float("inf") 320 self.last_generator_adversarial_loss = float("inf") 321 self.last_generator_second_loss = float("inf") 322 self.last_discriminator_loss = float("inf") 323 324 def get_input_channels(self): 325 """ 326 Returns the number of input channels used by the model. 327 328 Returns: 329 - int: 330 Number of input channels expected by the model. 331 """ 332 return self.input_channels 333 334 def get_output_channels(self): 335 """ 336 Returns the number of output channels produced by the model. 337 338 Returns: 339 - int: 340 Number of output channels the model generates 341 """ 342 return self.output_channels 343 344 def get_dict(self): 345 """ 346 Returns a dictionary with the most recent loss values. 347 348 Returns: 349 - dict: Loss components (base, complex). 350 351 Notes: 352 - Useful for logging or monitoring training progress. 353 """ 354 return { 355 f"loss_generator": self.last_generator_loss, 356 f"loss_generator_adversarial": self.last_generator_adversarial_loss, 357 f"loss_generator_second": self.last_generator_second_loss, 358 f"loss_discriminator": self.last_discriminator_loss 359 } 360 361 def forward(self, x): 362 """ 363 Forward pass through the generator. 364 365 Parameter: 366 - x (torch.tensor): 367 Input tensor. 368 369 Returns: 370 - torch.tensor: Generated output image. 371 """ 372 return self.generator(x) 373 374 def generator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold): 375 """ 376 Performs a single optimization step for the generator. 377 378 This includes: 379 - Forward pass through the generator and discriminator. 380 - Computing adversarial loss (generator tries to fool the discriminator). 381 - Computing optional secondary loss (e.g., L1 or MSE). 382 - Backpropagation and optimizer step, optionally with AMP and gradient clipping. 383 384 Parameters: 385 - x (torch.tensor): 386 Input tensor for the generator (e.g., source image). 387 - y (torch.tensor): 388 Target tensor for supervised secondary loss. 389 - optimizer (torch.optim.Optimizer): 390 Optimizer for the generator parameters. 391 - amp_scaler (torch.cuda.amp.GradScaler or None): 392 Automatic mixed precision scaler. 393 - device (torch.device): 394 Device for AMP autocast. 395 - gradient_clipping_threshold (float or None): 396 Max norm for gradient clipping; if None, no clipping. 397 398 Returns: 399 - tuple(torch.tensor, torch.tensor, torch.tensor): 400 - Total generator loss (adversarial + secondary). 401 - Adversarial loss component. 402 - Secondary loss component (weighted by `lambda_second`). 403 404 Notes: 405 - If AMP is enabled, gradients are scaled and unscaled appropriately. 406 - `last_generator_loss`, `last_generator_adversarial_loss`, and `last_generator_second_loss` are updated. 407 """ 408 if amp_scaler: 409 with autocast(device_type=device.type): 410 # make predictions 411 fake_y = self.generator(x) 412 413 discriminator_fake = self.discriminator(x, fake_y) 414 415 # calc loss -> discriminator thinks it is real? 
416 loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake)) 417 loss_second = self.second_loss(fake_y, y) * self.lambda_second 418 loss_total = loss_adversarial + loss_second 419 420 # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer 421 422 optimizer.zero_grad(set_to_none=True) 423 amp_scaler.scale(loss_total).backward() 424 if gradient_clipping_threshold: 425 # Unscale first! 426 amp_scaler.unscale_(optimizer) 427 torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold) 428 amp_scaler.step(optimizer) 429 amp_scaler.update() 430 else: 431 # make predictions 432 fake_y = self.generator(x) 433 434 discriminator_fake = self.discriminator(x, fake_y) 435 436 # calc loss -> discriminator thinks it is real? 437 loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake)) 438 loss_second = self.second_loss(fake_y, y) * self.lambda_second 439 loss_total = loss_adversarial + loss_second 440 optimizer.zero_grad() 441 loss_total.backward() 442 if gradient_clipping_threshold: 443 torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold) 444 optimizer.step() 445 446 self.last_generator_loss = loss_total.item() 447 self.last_generator_adversarial_loss = loss_adversarial.item() 448 self.last_generator_second_loss = loss_second.item() 449 450 return loss_total, loss_adversarial, loss_second 451 452 def discriminator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold): 453 """ 454 Performs a single optimization step for the discriminator. 455 456 This includes: 457 - Forward pass through the discriminator for both real and fake samples. 458 - Computing adversarial loss (binary cross-entropy) for real vs fake patches. 459 - Backpropagation and optimizer step, optionally with AMP and gradient clipping. 460 461 Parameters: 462 - x (torch.tensor): 463 Input tensor (e.g., source image). 464 - y (torch.tensor): 465 Target tensor (real image) for the discriminator. 466 - optimizer (torch.optim.Optimizer): 467 Optimizer for the discriminator parameters. 468 - amp_scaler (torch.cuda.amp.GradScaler or None): 469 Automatic mixed precision scaler. 470 - device (torch.device): 471 Device for AMP autocast. 472 - gradient_clipping_threshold (float or None): 473 Max norm for gradient clipping; if None, no clipping. 474 475 Returns: 476 - tuple(torch.tensor, torch.tensor, torch.tensor): 477 - Total discriminator loss (mean of real and fake losses). 478 - Loss for real samples. 479 - Loss for fake samples. 480 481 Notes: 482 - Fake images are detached from the generator to prevent updating its weights. 483 - `last_discriminator_loss` is updated. 484 - Supports AMP and optional gradient clipping for stability. 485 """ 486 if amp_scaler: 487 with autocast(device_type=device.type): 488 # make predictions 489 fake_y = self.generator(x).detach() # don't update generator!! 
490 491 discriminator_real = self.discriminator(x, y) 492 discriminator_fake = self.discriminator(x, fake_y) 493 494 # calc loss -> 1: predictions = real, 0: predictions = fake 495 loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real)) # torch.full_like(discriminator_real, 0.9) 496 loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake)) 497 loss_total = (loss_real + loss_fake) * 0.5 498 499 # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer 500 optimizer.zero_grad(set_to_none=True) 501 amp_scaler.scale(loss_total).backward() 502 if gradient_clipping_threshold: 503 # Unscale first! 504 amp_scaler.unscale_(optimizer) 505 torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold) 506 amp_scaler.step(optimizer) 507 amp_scaler.update() 508 else: 509 # make predictions 510 fake_y = self.generator(x).detach() 511 512 discriminator_real = self.discriminator(x, y) 513 discriminator_fake = self.discriminator(x, fake_y) 514 515 # calc loss -> 1: predictions = real, 0: predictions = fake 516 loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real)) # torch.full_like(discriminator_real, 0.9) 517 loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake)) 518 loss_total = (loss_real + loss_fake) * 0.5 519 optimizer.zero_grad() 520 loss_total.backward() 521 if gradient_clipping_threshold: 522 torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold) 523 optimizer.step() 524 525 self.last_discriminator_loss = loss_total.item() 526 527 return loss_total, loss_real, loss_fake
```python
class MMC(nn.Module):  # MinMaxClamping
    """
    Min-Max Clamping Module.

    Clamps input tensor values between a specified minimum and maximum.

    Parameter:
    - min (float):
        Minimum allowed value (default=0.0).
    - max (float):
        Maximum allowed value (default=1.0).

    Usage:
    - Can be used at the output layer of a generator to ensure predictions remain in a valid range.
    """
    def __init__(self, min=0.0, max=1.0):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        """
        Forward pass.

        Parameter:
        - x (torch.tensor):
            Input tensor.

        Returns:
        - torch.tensor: Clamped tensor with values between `min` and `max`.
        """
        return torch.clamp(x, self.min, self.max)
```
Min-Max Clamping Module.
Clamps input tensor values between a specified minimum and maximum.
Parameter:
- min (float): Minimum allowed value (default=0.0).
- max (float): Maximum allowed value (default=1.0).
Usage:
- Can be used at the output layer of a generator to ensure predictions remain in a valid range.
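A minimal usage sketch (the import path is taken from the module header above):

```python
import torch
from image_to_image.models.pix2pix import MMC

clamp = MMC(min=0.0, max=1.0)
x = torch.randn(2, 1, 4, 4) * 3                      # values well outside [0, 1]
y = clamp(x)
print(y.min().item() >= 0.0, y.max().item() <= 1.0)  # True True
```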
```python
def unet_down_block(in_channels=1, out_channels=1, normalize=True):
    """
    Creates a U-Net downsampling block.

    Parameter:
    - in_channels (int):
        Number of input channels.
    - out_channels (int):
        Number of output channels.
    - normalize (bool):
        Whether to apply instance normalization.

    Returns:
    - nn.Sequential: Convolutional downsampling block with LeakyReLU activation.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
    if normalize:
        # layers += [nn.BatchNorm2d(out_channels)]
        layers += [nn.InstanceNorm2d(out_channels, affine=True)]
    layers += [nn.LeakyReLU(0.2, inplace=True)]
    return nn.Sequential(*layers)
```
Creates a U-Net downsampling block.
Parameter:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- normalize (bool): Whether to apply instance normalization.
Returns:
- nn.Sequential: Convolutional downsampling block with LeakyReLU activation.
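As a quick shape check (import path assumed from the header above), the stride-2 4x4 convolution with padding 1 halves the spatial resolution:

```python
import torch
from image_to_image.models.pix2pix import unet_down_block

block = unet_down_block(in_channels=1, out_channels=64)
x = torch.randn(1, 1, 64, 64)
print(block(x).shape)  # torch.Size([1, 64, 32, 32]) -- H and W are halved
```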
```python
def unet_up_block(in_channels=1, out_channels=1, dropout=0.0):
    """
    Creates a U-Net upsampling block.

    Parameter:
    - in_channels (int):
        Number of input channels.
    - out_channels (int):
        Number of output channels.
    - dropout (float):
        Dropout probability (default=0).

    Returns:
    - nn.Sequential: Transposed convolutional block with ReLU activation and optional dropout.
    """
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        # nn.BatchNorm2d(out_channels),
        nn.InstanceNorm2d(out_channels, affine=True),
        nn.ReLU(inplace=True)
    ]

    if dropout:
        layers += [nn.Dropout(dropout)]
    return nn.Sequential(*layers)
```
Creates a U-Net upsampling block.
Parameter:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- dropout (float): Dropout probability (default=0).
Returns:
- nn.Sequential: Transposed convolutional block with ReLU activation and optional dropout.
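The transposed convolution inverts the downsampling: it doubles the spatial resolution. In the U-Net below, each decoder block receives the channel-wise concatenation of the previous decoder output and the matching encoder skip, which is why `in_channels` is typically twice the skip's channel count. A minimal sketch:

```python
import torch
from image_to_image.models.pix2pix import unet_up_block

up = unet_up_block(in_channels=128, out_channels=64, dropout=0.5)
d = torch.randn(1, 64, 16, 16)  # previous decoder output
s = torch.randn(1, 64, 16, 16)  # matching encoder skip connection
x = torch.cat([d, s], dim=1)    # concatenation doubles the channels to 128
print(up(x).shape)              # torch.Size([1, 64, 32, 32]) -- H and W are doubled
```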
```python
class UNetGenerator(nn.Module):
    """
    U-Net Generator for image-to-image translation.

    Architecture:
    - 8 downsampling blocks (encoder)
    - 8 upsampling blocks (decoder) with skip connections
    - Sigmoid activation at output for [0,1] pixel normalization

    Parameter:
    - input_channels (int):
        Number of input image channels.
    - output_channels (int):
        Number of output image channels.
    - hidden_channels (int):
        Base hidden channels for first layer.
    """
    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64):
        super().__init__()
        # Encoder
        self.down1 = unet_down_block(input_channels, hidden_channels, normalize=False)  # 128
        self.down2 = unet_down_block(hidden_channels, hidden_channels*2)  # 64
        self.down3 = unet_down_block(hidden_channels*2, hidden_channels*4)  # 32
        self.down4 = unet_down_block(hidden_channels*4, hidden_channels*8)  # 16
        self.down5 = unet_down_block(hidden_channels*8, hidden_channels*8)  # 8
        self.down6 = unet_down_block(hidden_channels*8, hidden_channels*8)  # 4
        self.down7 = unet_down_block(hidden_channels*8, hidden_channels*8)  # 2
        self.down8 = unet_down_block(hidden_channels*8, hidden_channels*8, normalize=False)  # 1

        # Decoder
        self.up1 = unet_up_block(hidden_channels*8, hidden_channels*8, dropout=0.5)
        self.up2 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
        self.up3 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
        self.up4 = unet_up_block(hidden_channels*16, hidden_channels*8)
        self.up5 = unet_up_block(hidden_channels*16, hidden_channels*4)
        self.up6 = unet_up_block(hidden_channels*8, hidden_channels*2)
        self.up7 = unet_up_block(hidden_channels*4, hidden_channels)
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(hidden_channels*2, output_channels, 4, 2, 1),
            # nn.Tanh()
            ## MMC(min=0.0, max=1.0)
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Forward pass of the U-Net generator.

        Parameter:
        - x (torch.tensor):
            Input image tensor (batch_size, input_channels, H, W).

        Returns:
        - torch.tensor: Generated output tensor (batch_size, output_channels, H, W).
        """
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8)
        u2 = self.up2(torch.cat([u1, d7], 1))
        u3 = self.up3(torch.cat([u2, d6], 1))
        u4 = self.up4(torch.cat([u3, d5], 1))
        u5 = self.up5(torch.cat([u4, d4], 1))
        u6 = self.up6(torch.cat([u5, d3], 1))
        u7 = self.up7(torch.cat([u6, d2], 1))
        u8 = self.up8(torch.cat([u7, d1], 1))

        return u8
```
U-Net Generator for image-to-image translation.
Architecture:
- 8 downsampling blocks (encoder)
- 8 upsampling blocks (decoder) with skip connections
- Sigmoid activation at output for [0,1] pixel normalization
Parameter:
- input_channels (int): Number of input image channels.
- output_channels (int): Number of output image channels.
- hidden_channels (int): Base hidden channels for first layer.
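The layer comments in `__init__` (128 ... 1) suggest the encoder is sized for 256x256 inputs: eight stride-2 blocks reduce 256 to a 1x1 bottleneck. A minimal sanity check under that assumption:

```python
import torch
from image_to_image.models.pix2pix import UNetGenerator

gen = UNetGenerator(input_channels=1, output_channels=1, hidden_channels=64)
x = torch.randn(1, 1, 256, 256)
y = gen(x)
print(y.shape)                                       # torch.Size([1, 1, 256, 256])
print(y.min().item() >= 0.0, y.max().item() <= 1.0)  # True True (sigmoid output)
```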
`def forward(self, x)`
Forward pass of the U-Net generator.
Parameter:
- x (torch.tensor): Input image tensor (batch_size, input_channels, H, W).
Returns:
- torch.tensor: Generated output tensor (batch_size, output_channels, H, W).
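To see the U shape that the forward pass walks through, one can trace each block's output with forward hooks; an illustrative sketch:

```python
import torch
from image_to_image.models.pix2pix import UNetGenerator

gen = UNetGenerator()

def report(name):
    def hook(module, inputs, output):
        print(f"{name}: {tuple(output.shape)}")
    return hook

# named_children() yields down1..down8 then up1..up8 in registration order
for name, module in gen.named_children():
    module.register_forward_hook(report(name))

with torch.no_grad():
    gen(torch.randn(1, 1, 256, 256))
```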
```python
class Discriminator(nn.Module):
    """
    PatchGAN Discriminator for Pix2Pix GAN.

    Parameter:
    - input_channels (int):
        Number of input channels (typically input + target channels concatenated).
    - hidden_channels (int):
        Base hidden channels for first layer.

    Architecture:
    - 5 convolutional layers with LeakyReLU activations and batch normalization on the middle layers.
    - Outputs a 2D patch map of predictions.
    """
    def __init__(self, input_channels=6, hidden_channels=64):
        """
        Initializes a PatchGAN discriminator.

        The discriminator evaluates input-target image pairs to determine
        if they are real or generated (fake). It progressively downsamples
        the spatial dimensions while increasing the number of feature channels.

        Parameters:
        - input_channels (int):
            Number of input channels, typically input + target concatenated (default=6).
        - hidden_channels (int):
            Number of channels in the first convolutional layer; doubled in subsequent layers (default=64).
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels*4, hidden_channels*8, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels*8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels*8, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x, y):
        """
        Forward pass of the discriminator.

        Parameter:
        - x (torch.tensor):
            Input image tensor.
        - y (torch.tensor):
            Target or generated image tensor.

        Returns:
        - torch.tensor: PatchGAN output tensor predicting real/fake for each patch.
        """
        # concatenate input and target channels
        return self.model(torch.cat([x, y], dim=1))
```
PatchGAN Discriminator for Pix2Pix GAN.
Parameter:
- input_channels (int): Number of input channels (typically input + target channels concatenated).
- hidden_channels (int): Base hidden channels for first layer.
Architecture:
- 5 convolutional layers with LeakyReLU activations and batch normalization on the middle layers.
- Outputs a 2D patch map of predictions.
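For 256x256 inputs, the three stride-2 layers and the two stride-1 layers shrink the concatenated pair to a 30x30 map of logits, one per overlapping image patch. A shape check assuming single-channel images:

```python
import torch
from image_to_image.models.pix2pix import Discriminator

disc = Discriminator(input_channels=2, hidden_channels=64)  # 1 input + 1 target channel
x = torch.randn(1, 1, 256, 256)  # source image
y = torch.randn(1, 1, 256, 256)  # real or generated target
print(disc(x, y).shape)          # torch.Size([1, 1, 30, 30]) -- one logit per patch
```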
`def __init__(self, input_channels=6, hidden_channels=64)`
Initializes a PatchGAN discriminator.
The discriminator evaluates input-target image pairs to determine if they are real or generated (fake). It progressively downsamples the spatial dimensions while increasing the number of feature channels.
Parameters:
- input_channels (int): Number of input channels, typically input + target concatenated (default=6).
- hidden_channels (int): Number of channels in the first convolutional layer; doubled in subsequent layers (default=64).
`def forward(self, x, y)`
Forward pass of the discriminator.
Parameter:
- x (torch.tensor): Input image tensor.
- y (torch.tensor): Target or generated image tensor.
Returns:
- torch.tensor: PatchGAN output tensor predicting real/fake for each patch.
```python
class Pix2Pix(nn.Module):
    """
    Pix2Pix GAN for image-to-image translation.

    Components:
    - Generator: U-Net generator producing synthetic images.
    - Discriminator: PatchGAN discriminator evaluating real vs fake images.
    - Adversarial loss: Binary cross-entropy.
    - Optional second loss for pixel-wise supervision.

    Parameter:
    - input_channels (int):
        Number of input channels.
    - output_channels (int):
        Number of output channels.
    - hidden_channels (int):
        Base hidden channels for both generator and discriminator.
    - second_loss (nn.Module):
        Optional secondary loss (default: L1Loss).
    - lambda_second (float):
        Weight for secondary loss in generator optimization.
    """
    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64,
                 second_loss=nn.L1Loss(), lambda_second=100):
        """
        Initializes the Pix2Pix GAN model.

        Components:
        - Generator: U-Net architecture for producing synthetic images.
        - Discriminator: PatchGAN for evaluating real vs. fake images.
        - Adversarial loss: Binary cross-entropy to train the generator to fool the discriminator.
        - Optional secondary loss: Pixel-wise supervision (default: L1Loss).

        Parameter:
        - input_channels (int):
            Number of channels in the input images (default=1).
        - output_channels (int):
            Number of channels in the output images (default=1).
        - hidden_channels (int):
            Base number of hidden channels in the generator and discriminator (default=64).
        - second_loss (nn.Module):
            Optional secondary loss for the generator (default: nn.L1Loss()).
        - lambda_second (float):
            Weight applied to the secondary loss in generator optimization (default=100).
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        self.generator = UNetGenerator(input_channels=input_channels,
                                       output_channels=output_channels,
                                       hidden_channels=hidden_channels)
        self.discriminator = Discriminator(input_channels=input_channels+output_channels, hidden_channels=hidden_channels)

        self.adversarial_loss = nn.BCEWithLogitsLoss()
        self.second_loss = second_loss
        self.lambda_second = lambda_second

        self.last_generator_loss = float("inf")
        self.last_generator_adversarial_loss = float("inf")
        self.last_generator_second_loss = float("inf")
        self.last_discriminator_loss = float("inf")

    def get_input_channels(self):
        """
        Returns the number of input channels used by the model.

        Returns:
        - int:
            Number of input channels expected by the model.
        """
        return self.input_channels

    def get_output_channels(self):
        """
        Returns the number of output channels produced by the model.

        Returns:
        - int:
            Number of output channels the model generates.
        """
        return self.output_channels

    def get_dict(self):
        """
        Returns a dictionary with the most recent loss values.

        Returns:
        - dict: Most recent loss values (generator total, generator adversarial, generator secondary, discriminator).

        Notes:
        - Useful for logging or monitoring training progress.
        """
        return {
            "loss_generator": self.last_generator_loss,
            "loss_generator_adversarial": self.last_generator_adversarial_loss,
            "loss_generator_second": self.last_generator_second_loss,
            "loss_discriminator": self.last_discriminator_loss
        }

    def forward(self, x):
        """
        Forward pass through the generator.

        Parameter:
        - x (torch.tensor):
            Input tensor.

        Returns:
        - torch.tensor: Generated output image.
        """
        return self.generator(x)

    def generator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
        """
        Performs a single optimization step for the generator.

        This includes:
        - Forward pass through the generator and discriminator.
        - Computing adversarial loss (generator tries to fool the discriminator).
        - Computing optional secondary loss (e.g., L1 or MSE).
        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.

        Parameters:
        - x (torch.tensor):
            Input tensor for the generator (e.g., source image).
        - y (torch.tensor):
            Target tensor for supervised secondary loss.
        - optimizer (torch.optim.Optimizer):
            Optimizer for the generator parameters.
        - amp_scaler (torch.cuda.amp.GradScaler or None):
            Automatic mixed precision scaler.
        - device (torch.device):
            Device for AMP autocast.
        - gradient_clipping_threshold (float or None):
            Max norm for gradient clipping; if None, no clipping.

        Returns:
        - tuple(torch.tensor, torch.tensor, torch.tensor):
            - Total generator loss (adversarial + secondary).
            - Adversarial loss component.
            - Secondary loss component (weighted by `lambda_second`).

        Notes:
        - If AMP is enabled, gradients are scaled and unscaled appropriately.
        - `last_generator_loss`, `last_generator_adversarial_loss`, and `last_generator_second_loss` are updated.
        """
        if amp_scaler:
            with autocast(device_type=device.type):
                # make predictions
                fake_y = self.generator(x)

                discriminator_fake = self.discriminator(x, fake_y)

                # calc loss -> discriminator thinks it is real?
                loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
                loss_second = self.second_loss(fake_y, y) * self.lambda_second
                loss_total = loss_adversarial + loss_second

            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
            optimizer.zero_grad(set_to_none=True)
            amp_scaler.scale(loss_total).backward()
            if gradient_clipping_threshold:
                # Unscale first!
                amp_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
            amp_scaler.step(optimizer)
            amp_scaler.update()
        else:
            # make predictions
            fake_y = self.generator(x)

            discriminator_fake = self.discriminator(x, fake_y)

            # calc loss -> discriminator thinks it is real?
            loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
            loss_second = self.second_loss(fake_y, y) * self.lambda_second
            loss_total = loss_adversarial + loss_second
            optimizer.zero_grad()
            loss_total.backward()
            if gradient_clipping_threshold:
                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
            optimizer.step()

        self.last_generator_loss = loss_total.item()
        self.last_generator_adversarial_loss = loss_adversarial.item()
        self.last_generator_second_loss = loss_second.item()

        return loss_total, loss_adversarial, loss_second

    def discriminator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
        """
        Performs a single optimization step for the discriminator.

        This includes:
        - Forward pass through the discriminator for both real and fake samples.
        - Computing adversarial loss (binary cross-entropy) for real vs fake patches.
        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.

        Parameters:
        - x (torch.tensor):
            Input tensor (e.g., source image).
        - y (torch.tensor):
            Target tensor (real image) for the discriminator.
        - optimizer (torch.optim.Optimizer):
            Optimizer for the discriminator parameters.
        - amp_scaler (torch.cuda.amp.GradScaler or None):
            Automatic mixed precision scaler.
        - device (torch.device):
            Device for AMP autocast.
        - gradient_clipping_threshold (float or None):
            Max norm for gradient clipping; if None, no clipping.

        Returns:
        - tuple(torch.tensor, torch.tensor, torch.tensor):
            - Total discriminator loss (mean of real and fake losses).
            - Loss for real samples.
            - Loss for fake samples.

        Notes:
        - Fake images are detached from the generator to prevent updating its weights.
        - `last_discriminator_loss` is updated.
        - Supports AMP and optional gradient clipping for stability.
        """
        if amp_scaler:
            with autocast(device_type=device.type):
                # make predictions
                fake_y = self.generator(x).detach()  # don't update generator!!

                discriminator_real = self.discriminator(x, y)
                discriminator_fake = self.discriminator(x, fake_y)

                # calc loss -> 1: predictions = real, 0: predictions = fake
                loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9)
                loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
                loss_total = (loss_real + loss_fake) * 0.5

            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
            optimizer.zero_grad(set_to_none=True)
            amp_scaler.scale(loss_total).backward()
            if gradient_clipping_threshold:
                # Unscale first!
                amp_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
            amp_scaler.step(optimizer)
            amp_scaler.update()
        else:
            # make predictions
            fake_y = self.generator(x).detach()

            discriminator_real = self.discriminator(x, y)
            discriminator_fake = self.discriminator(x, fake_y)

            # calc loss -> 1: predictions = real, 0: predictions = fake
            loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9)
            loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
            loss_total = (loss_real + loss_fake) * 0.5
            optimizer.zero_grad()
            loss_total.backward()
            if gradient_clipping_threshold:
                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
            optimizer.step()

        self.last_discriminator_loss = loss_total.item()

        return loss_total, loss_real, loss_fake
```
Pix2Pix GAN for image-to-image translation.
Components:
- Generator: U-Net generator producing synthetic images.
- Discriminator: PatchGAN discriminator evaluating real vs fake images.
- Adversarial loss: Binary cross-entropy.
- Optional second loss for pixel-wise supervision.
Parameter:
- input_channels (int): Number of input channels.
- output_channels (int): Number of output channels.
- hidden_channels (int): Base hidden channels for both generator and discriminator.
- second_loss (nn.Module): Optional secondary loss (default: L1Loss).
- lambda_second (float): Weight for secondary loss in generator optimization.
`def __init__(self, input_channels=1, output_channels=1, hidden_channels=64, second_loss=nn.L1Loss(), lambda_second=100)`
Initializes the Pix2Pix GAN model.
Components:
- Generator: U-Net architecture for producing synthetic images.
- Discriminator: PatchGAN for evaluating real vs. fake images.
- Adversarial loss: Binary cross-entropy to train the generator to fool the discriminator.
- Optional secondary loss: Pixel-wise supervision (default: L1Loss).
Parameter:
- input_channels (int): Number of channels in the input images (default=1).
- output_channels (int): Number of channels in the output images (default=1).
- hidden_channels (int): Base number of hidden channels in the generator and discriminator (default=64).
- second_loss (nn.Module): Optional secondary loss for the generator (default: nn.L1Loss()).
- lambda_second (float): Weight applied to the secondary loss in generator optimization (default=100).
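A construction sketch; the MSE variant merely illustrates the pluggable `second_loss` (the step methods mention L1 or MSE):

```python
import torch.nn as nn
from image_to_image.models.pix2pix import Pix2Pix

# default configuration: single-channel images, L1 pixel loss weighted by 100
model = Pix2Pix(input_channels=1, output_channels=1, hidden_channels=64)

# the secondary loss is pluggable, e.g. MSE with a smaller weight
model_mse = Pix2Pix(second_loss=nn.MSELoss(), lambda_second=50)
```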
`def get_input_channels(self)`
Returns the number of input channels used by the model.
Returns:
- int: Number of input channels expected by the model.
`def get_output_channels(self)`
Returns the number of output channels produced by the model.
Returns:
- int: Number of output channels the model generates.
`def get_dict(self)`
Returns a dictionary with the most recent loss values.
Returns:
- dict: Most recent loss values (generator total, generator adversarial, generator secondary, discriminator).
Notes:
- Useful for logging or monitoring training progress.
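For example, continuing with the `model` instance from the sketch above, the dict can be dumped after each step:

```python
for name, value in model.get_dict().items():
    print(f"{name}: {value:.4f}")
```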
`def forward(self, x)`
Forward pass through the generator.
Parameter:
- x (torch.tensor): Input tensor.
Returns:
- torch.tensor: Generated output image.
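For inference, note that the U-Net decoder uses `nn.Dropout`, so the model should be switched to eval mode first; a minimal sketch reusing `model` from above:

```python
import torch

model.eval()  # disables the decoder dropout
with torch.no_grad():
    prediction = model(torch.randn(1, 1, 256, 256))
model.train()  # back to training mode
```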
`def generator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold)`
Performs a single optimization step for the generator.
This includes:
- Forward pass through the generator and discriminator.
- Computing adversarial loss (generator tries to fool the discriminator).
- Computing optional secondary loss (e.g., L1 or MSE).
- Backpropagation and optimizer step, optionally with AMP and gradient clipping.
Parameters:
- x (torch.tensor): Input tensor for the generator (e.g., source image).
- y (torch.tensor): Target tensor for supervised secondary loss.
- optimizer (torch.optim.Optimizer): Optimizer for the generator parameters.
- amp_scaler (torch.cuda.amp.GradScaler or None): Automatic mixed precision scaler.
- device (torch.device): Device for AMP autocast.
- gradient_clipping_threshold (float or None): Max norm for gradient clipping; if None, no clipping.
Returns:
- tuple(torch.tensor, torch.tensor, torch.tensor):
- Total generator loss (adversarial + secondary).
- Adversarial loss component.
- Secondary loss component (weighted by `lambda_second`).
Notes:
- If AMP is enabled, gradients are scaled and unscaled appropriately.
- `last_generator_loss`, `last_generator_adversarial_loss`, and `last_generator_second_loss` are updated.
`def discriminator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold)`
Performs a single optimization step for the discriminator.
This includes:
- Forward pass through the discriminator for both real and fake samples.
- Computing adversarial loss (binary cross-entropy) for real vs fake patches.
- Backpropagation and optimizer step, optionally with AMP and gradient clipping.
Parameters:
- x (torch.tensor): Input tensor (e.g., source image).
- y (torch.tensor): Target tensor (real image) for the discriminator.
- optimizer (torch.optim.Optimizer): Optimizer for the discriminator parameters.
- amp_scaler (torch.cuda.amp.GradScaler or None): Automatic mixed precision scaler.
- device (torch.device): Device for AMP autocast.
- gradient_clipping_threshold (float or None): Max norm for gradient clipping; if None, no clipping.
Returns:
- tuple(torch.tensor, torch.tensor, torch.tensor):
- Total discriminator loss (mean of real and fake losses).
- Loss for real samples.
- Loss for fake samples.
Notes:
- Fake images are detached from the generator to prevent updating its weights.
- `last_discriminator_loss` is updated.
- Supports AMP and optional gradient clipping for stability.
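Putting both steps together, an end-to-end training sketch (the `dataloader`, learning rate, and Adam betas are assumptions, not taken from this module; the betas follow the common Pix2Pix recipe):

```python
import torch
from torch.amp import GradScaler
from image_to_image.models.pix2pix import Pix2Pix

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Pix2Pix(input_channels=1, output_channels=1).to(device)

optimizer_g = torch.optim.Adam(model.generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(model.discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
amp_scaler = GradScaler() if device.type == "cuda" else None

for x, y in dataloader:  # assumed to yield (source, target) image batches
    x, y = x.to(device), y.to(device)

    # one discriminator update, then one generator update per batch
    model.discriminator_step(x, y, optimizer_d, amp_scaler, device,
                             gradient_clipping_threshold=None)
    model.generator_step(x, y, optimizer_g, amp_scaler, device,
                         gradient_clipping_threshold=None)

print(model.get_dict())  # most recent loss values for logging
```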