image_to_image.models.pix2pix

Module to define a Pix2Pix model: a U-Net CNN generator trained with an adversarial (PatchGAN) loss and an optional pixel-wise loss.

Functions:

  • unet_down_block
  • unet_up_block

Classes:

  • MMC
  • UNetGenerator
  • Discriminator
  • Pix2Pix

By Tobia Ippolito

  1"""
  2Module to define a Pix2Pix Model. 
  3A UNet CNN Generator combined with a small generative loss.
  4
  5Functions:
  6- unet_down_block
  7- unet_up_block
  8
  9Classes:
 10- MMC
 11- UNetGenerator
 12- Discriminator
 13- Pix2Pix
 14
 15By Tobia Ippolito
 16"""
 17# ---------------------------
 18#        > Imports <
 19# ---------------------------
 20import torch
 21import torch.nn as nn
 22import torch.nn.functional as F
 23from torch.amp import autocast
 24
 25
 26
# ---------------------------
#        > Generator <
# ---------------------------
class MMC(nn.Module):  # MinMaxClamping
    """
    Min-Max Clamping Module.

    Clamps input tensor values between a specified minimum and maximum.

    Parameter:
    - min (float):
        Minimum allowed value (default=0.0).
    - max (float):
        Maximum allowed value (default=1.0).

    Usage:
    - Can be used at the output layer of a generator to ensure predictions remain in a valid range.
    """
    def __init__(self, min=0.0, max=1.0):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        """
        Forward pass.

        Parameter:
        - x (torch.tensor):
            Input tensor.

        Returns:
        - torch.tensor: Clamped tensor with values between `min` and `max`.
        """
        return torch.clamp(x, self.min, self.max)

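A quick usage sketch (values illustrative, not from the original module): clamping keeps out-of-range activations inside [0, 1]:

clamp = MMC(min=0.0, max=1.0)
print(clamp(torch.tensor([-0.5, 0.3, 1.7])))  # tensor([0.0000, 0.3000, 1.0000])
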
def unet_down_block(in_channels=1, out_channels=1, normalize=True):
    """
    Creates a U-Net downsampling block.

    Parameter:
    - in_channels (int):
        Number of input channels.
    - out_channels (int):
        Number of output channels.
    - normalize (bool):
        Whether to apply instance normalization.

    Returns:
    - nn.Sequential: Convolutional downsampling block with LeakyReLU activation.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
    if normalize:
        # layers += [nn.BatchNorm2d(out_channels)]
        layers += [nn.InstanceNorm2d(out_channels, affine=True)]
    layers += [nn.LeakyReLU(0.2, inplace=True)]
    return nn.Sequential(*layers)

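A shape sanity check (sizes assumed for illustration): the kernel-4, stride-2, padding-1 convolution halves the spatial resolution:

down = unet_down_block(in_channels=1, out_channels=64)
print(down(torch.randn(1, 1, 256, 256)).shape)  # torch.Size([1, 64, 128, 128])
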
def unet_up_block(in_channels=1, out_channels=1, dropout=0.0):
    """
    Creates a U-Net upsampling block.

    Parameter:
    - in_channels (int):
        Number of input channels.
    - out_channels (int):
        Number of output channels.
    - dropout (float):
        Dropout probability (default=0).

    Returns:
    - nn.Sequential: Transposed convolutional block with ReLU activation and optional dropout.
    """
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        # nn.BatchNorm2d(out_channels),
        nn.InstanceNorm2d(out_channels, affine=True),
        nn.ReLU(inplace=True)
    ]

    if dropout:
        layers += [nn.Dropout(dropout)]
    return nn.Sequential(*layers)

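The mirror-image check (sizes again assumed): the stride-2 transposed convolution doubles the spatial resolution:

up = unet_up_block(in_channels=64, out_channels=32, dropout=0.5)
print(up(torch.randn(1, 64, 128, 128)).shape)  # torch.Size([1, 32, 256, 256])
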
class UNetGenerator(nn.Module):
    """
    U-Net Generator for image-to-image translation.

    Architecture:
    - 8 downsampling blocks (encoder)
    - 8 upsampling blocks (decoder) with skip connections
    - Sigmoid activation at output for [0,1] pixel normalization

    Parameter:
    - input_channels (int):
        Number of input image channels.
    - output_channels (int):
        Number of output image channels.
    - hidden_channels (int):
        Base hidden channels for first layer.
    """
    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64):
        super().__init__()
        # Encoder (spatial sizes in the comments assume a 256x256 input)
        self.down1 = unet_down_block(input_channels, hidden_channels, normalize=False)      # 128
        self.down2 = unet_down_block(hidden_channels, hidden_channels*2)                    # 64
        self.down3 = unet_down_block(hidden_channels*2, hidden_channels*4)                  # 32
        self.down4 = unet_down_block(hidden_channels*4, hidden_channels*8)                  # 16
        self.down5 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 8
        self.down6 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 4
        self.down7 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 2
        self.down8 = unet_down_block(hidden_channels*8, hidden_channels*8, normalize=False) # 1

        # Decoder (skip connections double the input channels of up2..up8)
        self.up1 = unet_up_block(hidden_channels*8, hidden_channels*8, dropout=0.5)
        self.up2 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
        self.up3 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
        self.up4 = unet_up_block(hidden_channels*16, hidden_channels*8)
        self.up5 = unet_up_block(hidden_channels*16, hidden_channels*4)
        self.up6 = unet_up_block(hidden_channels*8, hidden_channels*2)
        self.up7 = unet_up_block(hidden_channels*4, hidden_channels)
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(hidden_channels*2, output_channels, 4, 2, 1),
            # nn.Tanh()
            ## MMC(min=0.0, max=1.0)
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Forward pass of the U-Net generator.

        Parameter:
        - x (torch.tensor):
            Input image tensor (batch_size, input_channels, H, W).

        Returns:
        - torch.tensor: Generated output tensor (batch_size, output_channels, H, W).
        """
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8)
        u2 = self.up2(torch.cat([u1, d7], 1))
        u3 = self.up3(torch.cat([u2, d6], 1))
        u4 = self.up4(torch.cat([u3, d5], 1))
        u5 = self.up5(torch.cat([u4, d4], 1))
        u6 = self.up6(torch.cat([u5, d3], 1))
        u7 = self.up7(torch.cat([u6, d2], 1))
        u8 = self.up8(torch.cat([u7, d1], 1))

        return u8

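An end-to-end shape check (input size assumed; the eight stride-2 halvings require side lengths divisible by 256, e.g. 256x256):

gen = UNetGenerator(input_channels=1, output_channels=1, hidden_channels=64)
with torch.no_grad():
    print(gen(torch.randn(1, 1, 256, 256)).shape)  # torch.Size([1, 1, 256, 256])
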
# ---------------------------
#      > Discriminator <
# ---------------------------
# PatchGAN
class Discriminator(nn.Module):
    """
    PatchGAN Discriminator for Pix2Pix GAN.

    Parameter:
    - input_channels (int):
        Number of input channels (typically input + target channels concatenated).
    - hidden_channels (int):
        Base hidden channels for first layer.

    Architecture:
    - 5 convolutional blocks with LeakyReLU and batch normalization.
    - Outputs a 2D patch map of predictions.
    """
    def __init__(self, input_channels=6, hidden_channels=64):
        """
        Initializes a PatchGAN discriminator.

        The discriminator evaluates input-target image pairs to determine
        if they are real or generated (fake). It progressively downsamples
        the spatial dimensions while increasing the number of feature channels.

        Parameters:
        - input_channels (int):
            Number of input channels, typically input + target concatenated (default=6).
        - hidden_channels (int):
            Number of channels in the first convolutional layer; doubled in subsequent layers (default=64).
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels*4, hidden_channels*8, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels*8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_channels*8, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x, y):
        """
        Forward pass of the discriminator.

        Parameter:
        - x (torch.tensor):
            Input image tensor.
        - y (torch.tensor):
            Target or generated image tensor.

        Returns:
        - torch.tensor: PatchGAN output tensor predicting real/fake for each patch.
        """
        # concatenate input and target channels
        return self.model(torch.cat([x, y], dim=1))

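A patch-map sketch (shapes assumed): for 256x256 single-channel pairs the discriminator emits one logit per input patch (roughly a 70x70 receptive field per output location), here a 30x30 map:

disc = Discriminator(input_channels=2, hidden_channels=64)
x = torch.randn(4, 1, 256, 256)   # source images
y = torch.randn(4, 1, 256, 256)   # real or generated counterparts
print(disc(x, y).shape)  # torch.Size([4, 1, 30, 30])
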
# ---------------------------
#         > Pix2Pix <
# ---------------------------
class Pix2Pix(nn.Module):
    """
    Pix2Pix GAN for image-to-image translation.

    Components:
    - Generator: U-Net generator producing synthetic images.
    - Discriminator: PatchGAN discriminator evaluating real vs fake images.
    - Adversarial loss: Binary cross-entropy.
    - Optional second loss for pixel-wise supervision.

    Parameter:
    - input_channels (int):
        Number of input channels.
    - output_channels (int):
        Number of output channels.
    - hidden_channels (int):
        Base hidden channels for both generator and discriminator.
    - second_loss (nn.Module):
        Optional secondary loss (default: L1Loss).
    - lambda_second (float):
        Weight for secondary loss in generator optimization.
    """
    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64,
                 second_loss=nn.L1Loss(), lambda_second=100):
        """
        Initializes the Pix2Pix GAN model.

        Components:
        - Generator: U-Net architecture for producing synthetic images.
        - Discriminator: PatchGAN for evaluating real vs. fake images.
        - Adversarial loss: Binary cross-entropy to train the generator to fool the discriminator.
        - Optional secondary loss: Pixel-wise supervision (default: L1Loss).

        Parameter:
        - input_channels (int):
            Number of channels in the input images (default=1).
        - output_channels (int):
            Number of channels in the output images (default=1).
        - hidden_channels (int):
            Base number of hidden channels in the generator and discriminator (default=64).
        - second_loss (nn.Module):
            Optional secondary loss for the generator (default: nn.L1Loss()).
        - lambda_second (float):
            Weight applied to the secondary loss in generator optimization (default=100).
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        self.generator = UNetGenerator(input_channels=input_channels,
                                       output_channels=output_channels,
                                       hidden_channels=hidden_channels)
        self.discriminator = Discriminator(input_channels=input_channels+output_channels, hidden_channels=hidden_channels)

        self.adversarial_loss = nn.BCEWithLogitsLoss()
        self.second_loss = second_loss
        self.lambda_second = lambda_second

        self.last_generator_loss = float("inf")
        self.last_generator_adversarial_loss = float("inf")
        self.last_generator_second_loss = float("inf")
        self.last_discriminator_loss = float("inf")

    def get_input_channels(self):
        """
        Returns the number of input channels used by the model.

        Returns:
        - int:
            Number of input channels expected by the model.
        """
        return self.input_channels

    def get_output_channels(self):
        """
        Returns the number of output channels produced by the model.

        Returns:
        - int:
            Number of output channels the model generates.
        """
        return self.output_channels

    def get_dict(self):
        """
        Returns a dictionary with the most recent loss values.

        Returns:
        - dict: Generator loss components (total, adversarial, secondary) and discriminator loss.

        Notes:
        - Useful for logging or monitoring training progress.
        """
        return {
                "loss_generator": self.last_generator_loss,
                "loss_generator_adversarial": self.last_generator_adversarial_loss,
                "loss_generator_second": self.last_generator_second_loss,
                "loss_discriminator": self.last_discriminator_loss
               }

    def forward(self, x):
        """
        Forward pass through the generator.

        Parameter:
        - x (torch.tensor):
            Input tensor.

        Returns:
        - torch.tensor: Generated output image.
        """
        return self.generator(x)

    def generator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
        """
        Performs a single optimization step for the generator.

        This includes:
        - Forward pass through the generator and discriminator.
        - Computing adversarial loss (generator tries to fool the discriminator).
        - Computing optional secondary loss (e.g., L1 or MSE).
        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.

        Parameters:
        - x (torch.tensor):
            Input tensor for the generator (e.g., source image).
        - y (torch.tensor):
            Target tensor for supervised secondary loss.
        - optimizer (torch.optim.Optimizer):
            Optimizer for the generator parameters.
        - amp_scaler (torch.amp.GradScaler or None):
            Automatic mixed precision scaler.
        - device (torch.device):
            Device for AMP autocast.
        - gradient_clipping_threshold (float or None):
            Max norm for gradient clipping; if None, no clipping.

        Returns:
        - tuple(torch.tensor, torch.tensor, torch.tensor):
            - Total generator loss (adversarial + secondary).
            - Adversarial loss component.
            - Secondary loss component (weighted by `lambda_second`).

        Notes:
        - If AMP is enabled, gradients are scaled and unscaled appropriately.
        - `last_generator_loss`, `last_generator_adversarial_loss`, and `last_generator_second_loss` are updated.
        """
        if amp_scaler:
            with autocast(device_type=device.type):
                # make predictions
                fake_y = self.generator(x)

                discriminator_fake = self.discriminator(x, fake_y)

                # calc loss -> discriminator thinks it is real?
                loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
                loss_second = self.second_loss(fake_y, y) * self.lambda_second
                loss_total = loss_adversarial + loss_second

            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
            optimizer.zero_grad(set_to_none=True)
            amp_scaler.scale(loss_total).backward()
            if gradient_clipping_threshold:
                # Unscale first!
                amp_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
            amp_scaler.step(optimizer)
            amp_scaler.update()
        else:
            # make predictions
            fake_y = self.generator(x)

            discriminator_fake = self.discriminator(x, fake_y)

            # calc loss -> discriminator thinks it is real?
            loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
            loss_second = self.second_loss(fake_y, y) * self.lambda_second
            loss_total = loss_adversarial + loss_second
            optimizer.zero_grad()
            loss_total.backward()
            if gradient_clipping_threshold:
                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
            optimizer.step()

        self.last_generator_loss = loss_total.item()
        self.last_generator_adversarial_loss = loss_adversarial.item()
        self.last_generator_second_loss = loss_second.item()

        return loss_total, loss_adversarial, loss_second

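    # Loss summary (editorial comment; mirrors the steps above and below):
    #   generator step:      L_G = BCEWithLogits(D(x, G(x)), 1) + lambda_second * second_loss(G(x), y)
    #   discriminator step:  L_D = 0.5 * (BCEWithLogits(D(x, y), 1) + BCEWithLogits(D(x, G(x).detach()), 0))
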
    def discriminator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
        """
        Performs a single optimization step for the discriminator.

        This includes:
        - Forward pass through the discriminator for both real and fake samples.
        - Computing adversarial loss (binary cross-entropy) for real vs fake patches.
        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.

        Parameters:
        - x (torch.tensor):
            Input tensor (e.g., source image).
        - y (torch.tensor):
            Target tensor (real image) for the discriminator.
        - optimizer (torch.optim.Optimizer):
            Optimizer for the discriminator parameters.
        - amp_scaler (torch.amp.GradScaler or None):
            Automatic mixed precision scaler.
        - device (torch.device):
            Device for AMP autocast.
        - gradient_clipping_threshold (float or None):
            Max norm for gradient clipping; if None, no clipping.

        Returns:
        - tuple(torch.tensor, torch.tensor, torch.tensor):
            - Total discriminator loss (mean of real and fake losses).
            - Loss for real samples.
            - Loss for fake samples.

        Notes:
        - Fake images are detached from the generator to prevent updating its weights.
        - `last_discriminator_loss` is updated.
        - Supports AMP and optional gradient clipping for stability.
        """
        if amp_scaler:
            with autocast(device_type=device.type):
                # make predictions
                fake_y = self.generator(x).detach()  # don't update generator!!

                discriminator_real = self.discriminator(x, y)
                discriminator_fake = self.discriminator(x, fake_y)

                # calc loss -> 1: predictions = real, 0: predictions = fake
                loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9) for one-sided label smoothing
                loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
                loss_total = (loss_real + loss_fake) * 0.5

            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
            optimizer.zero_grad(set_to_none=True)
            amp_scaler.scale(loss_total).backward()
            if gradient_clipping_threshold:
                # Unscale first!
                amp_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
            amp_scaler.step(optimizer)
            amp_scaler.update()
        else:
            # make predictions
            fake_y = self.generator(x).detach()

            discriminator_real = self.discriminator(x, y)
            discriminator_fake = self.discriminator(x, fake_y)

            # calc loss -> 1: predictions = real, 0: predictions = fake
            loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9) for one-sided label smoothing
            loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
            loss_total = (loss_real + loss_fake) * 0.5
            optimizer.zero_grad()
            loss_total.backward()
            if gradient_clipping_threshold:
                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
            optimizer.step()

        self.last_discriminator_loss = loss_total.item()

        return loss_total, loss_real, loss_fake
class MMC(torch.nn.modules.module.Module):
31class MMC(nn.Module):  # MinMaxClamping
32    """
33    Min-Max Clamping Module.
34
35    Clamps input tensor values between a specified minimum and maximum.
36
37    Parameter:
38    - min (float): 
39        Minimum allowed value (default=0.0).
40    - max (float): 
41        Maximum allowed value (default=1.0).
42
43    Usage:
44    - Can be used at the output layer of a generator to ensure predictions remain in a valid range.
45    """
46    def __init__(self, min=0.0, max=1.0):
47        super().__init__()
48        self.min = min
49        self.max = max
50
51    def forward(self, x):
52        """
53        Forward pass.
54
55        Parameter:
56        - x (torch.tensor): 
57            Input tensor.
58
59        Returns:
60        - torch.tensor: Clamped tensor with values between `min` and `max`.
61        """
62        return torch.clamp(x, self.min, self.max)

Min-Max Clamping Module.

Clamps input tensor values between a specified minimum and maximum.

Parameter:

  • min (float): Minimum allowed value (default=0.0).
  • max (float): Maximum allowed value (default=1.0).

Usage:

  • Can be used at the output layer of a generator to ensure predictions remain in a valid range.
MMC(min=0.0, max=1.0)
46    def __init__(self, min=0.0, max=1.0):
47        super().__init__()
48        self.min = min
49        self.max = max

Initialize internal Module state, shared by both nn.Module and ScriptModule.

min
max
def forward(self, x):
51    def forward(self, x):
52        """
53        Forward pass.
54
55        Parameter:
56        - x (torch.tensor): 
57            Input tensor.
58
59        Returns:
60        - torch.tensor: Clamped tensor with values between `min` and `max`.
61        """
62        return torch.clamp(x, self.min, self.max)

Forward pass.

Parameter:

  • x (torch.tensor): Input tensor.

Returns:

  • torch.tensor: Clamped tensor with values between min and max.
def unet_down_block(in_channels=1, out_channels=1, normalize=True):
64def unet_down_block(in_channels=1, out_channels=1, normalize=True):
65    """
66    Creates a U-Net downsampling block.
67
68    Parameter:
69    - in_channels (int): 
70        Number of input channels.
71    - out_channels (int): 
72        Number of output channels.
73    - normalize (bool): 
74        Whether to apply instance normalization.
75
76    Returns:
77    - nn.Sequential: Convolutional downsampling block with LeakyReLU activation.
78    """
79    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
80    if normalize:
81        # layers += [nn.BatchNorm2d(out_channels)]
82        layers += [nn.InstanceNorm2d(out_channels, affine=True)]
83    layers += [nn.LeakyReLU(0.2, inplace=True)]
84    return nn.Sequential(*layers)

Creates a U-Net downsampling block.

Parameter:

  • in_channels (int): Number of input channels.
  • out_channels (int): Number of output channels.
  • normalize (bool): Whether to apply instance normalization.

Returns:

  • nn.Sequential: Convolutional downsampling block with LeakyReLU activation.
def unet_up_block(in_channels=1, out_channels=1, dropout=0.0):
 86def unet_up_block(in_channels=1, out_channels=1, dropout=0.0):
 87    """
 88    Creates a U-Net upsampling block.
 89
 90    Parameter:
 91    - in_channels (int): 
 92        Number of input channels.
 93    - out_channels (int): 
 94        Number of output channels.
 95    - dropout (float): 
 96        Dropout probability (default=0).
 97
 98    Returns:
 99    - nn.Sequential: Transposed convolutional block with ReLU activation and optional dropout.
100    """
101    layers = [
102        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
103        # nn.BatchNorm2d(out_channels),
104        nn.InstanceNorm2d(out_channels, affine=True),
105        nn.ReLU(inplace=True)
106    ]
107    
108    if dropout:
109        layers += [nn.Dropout(dropout)]
110    return nn.Sequential(*layers)

Creates a U-Net upsampling block.

Parameter:

  • in_channels (int): Number of input channels.
  • out_channels (int): Number of output channels.
  • dropout (float): Dropout probability (default=0).

Returns:

  • nn.Sequential: Transposed convolutional block with ReLU activation and optional dropout.
class UNetGenerator(torch.nn.modules.module.Module):
112class UNetGenerator(nn.Module):
113    """
114    U-Net Generator for image-to-image translation.
115
116    Architecture:
117    - 8 downsampling blocks (encoder)
118    - 8 upsampling blocks (decoder) with skip connections
119    - Sigmoid activation at output for [0,1] pixel normalization
120
121    Parameter:
122    - input_channels (int): 
123        Number of input image channels.
124    - output_channels (int): 
125        Number of output image channels.
126    - hidden_channels (int): 
127        Base hidden channels for first layer.
128    """
129    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64):
130        super().__init__()
131        # Encoder
132        self.down1 = unet_down_block(input_channels, hidden_channels, normalize=False) # 128
133        self.down2 = unet_down_block(hidden_channels, hidden_channels*2)                    # 64
134        self.down3 = unet_down_block(hidden_channels*2, hidden_channels*4)                  # 32
135        self.down4 = unet_down_block(hidden_channels*4, hidden_channels*8)                  # 16
136        self.down5 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 8
137        self.down6 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 4
138        self.down7 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 2
139        self.down8 = unet_down_block(hidden_channels*8, hidden_channels*8, normalize=False) # 1
140
141        # Decoder
142        self.up1 = unet_up_block(hidden_channels*8, hidden_channels*8, dropout=0.5)
143        self.up2 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
144        self.up3 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
145        self.up4 = unet_up_block(hidden_channels*16, hidden_channels*8)
146        self.up5 = unet_up_block(hidden_channels*16, hidden_channels*4)
147        self.up6 = unet_up_block(hidden_channels*8, hidden_channels*2)
148        self.up7 = unet_up_block(hidden_channels*4, hidden_channels)
149        self.up8 = nn.Sequential(
150            nn.ConvTranspose2d(hidden_channels*2, output_channels, 4, 2, 1),
151            # nn.Tanh()
152            ## MMC(min=0.0, max=1.0)
153            nn.Sigmoid()
154        )
155
156    def forward(self, x):
157        """
158        Forward pass of the U-Net generator.
159
160        Parameter:
161        - x (torch.tensor): 
162            Input image tensor (batch_size, input_channels, H, W).
163
164        Returns:
165        - torch.tensor: Generated output tensor (batch_size, output_channels, H, W).
166        """
167        d1 = self.down1(x)
168        d2 = self.down2(d1)
169        d3 = self.down3(d2)
170        d4 = self.down4(d3)
171        d5 = self.down5(d4)
172        d6 = self.down6(d5)
173        d7 = self.down7(d6)
174        d8 = self.down8(d7)
175
176        u1 = self.up1(d8)
177        u2 = self.up2(torch.cat([u1, d7], 1))
178        u3 = self.up3(torch.cat([u2, d6], 1))
179        u4 = self.up4(torch.cat([u3, d5], 1))
180        u5 = self.up5(torch.cat([u4, d4], 1))
181        u6 = self.up6(torch.cat([u5, d3], 1))
182        u7 = self.up7(torch.cat([u6, d2], 1))
183        u8 = self.up8(torch.cat([u7, d1], 1))
184
185        return u8

U-Net Generator for image-to-image translation.

Architecture:

  • 8 downsampling blocks (encoder)
  • 8 upsampling blocks (decoder) with skip connections
  • Sigmoid activation at output for [0,1] pixel normalization

Parameter:

  • input_channels (int): Number of input image channels.
  • output_channels (int): Number of output image channels.
  • hidden_channels (int): Base hidden channels for first layer.
UNetGenerator(input_channels=1, output_channels=1, hidden_channels=64)
129    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64):
130        super().__init__()
131        # Encoder
132        self.down1 = unet_down_block(input_channels, hidden_channels, normalize=False) # 128
133        self.down2 = unet_down_block(hidden_channels, hidden_channels*2)                    # 64
134        self.down3 = unet_down_block(hidden_channels*2, hidden_channels*4)                  # 32
135        self.down4 = unet_down_block(hidden_channels*4, hidden_channels*8)                  # 16
136        self.down5 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 8
137        self.down6 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 4
138        self.down7 = unet_down_block(hidden_channels*8, hidden_channels*8)                  # 2
139        self.down8 = unet_down_block(hidden_channels*8, hidden_channels*8, normalize=False) # 1
140
141        # Decoder
142        self.up1 = unet_up_block(hidden_channels*8, hidden_channels*8, dropout=0.5)
143        self.up2 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
144        self.up3 = unet_up_block(hidden_channels*16, hidden_channels*8, dropout=0.5)
145        self.up4 = unet_up_block(hidden_channels*16, hidden_channels*8)
146        self.up5 = unet_up_block(hidden_channels*16, hidden_channels*4)
147        self.up6 = unet_up_block(hidden_channels*8, hidden_channels*2)
148        self.up7 = unet_up_block(hidden_channels*4, hidden_channels)
149        self.up8 = nn.Sequential(
150            nn.ConvTranspose2d(hidden_channels*2, output_channels, 4, 2, 1),
151            # nn.Tanh()
152            ## MMC(min=0.0, max=1.0)
153            nn.Sigmoid()
154        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

down1
down2
down3
down4
down5
down6
down7
down8
up1
up2
up3
up4
up5
up6
up7
up8
def forward(self, x):
156    def forward(self, x):
157        """
158        Forward pass of the U-Net generator.
159
160        Parameter:
161        - x (torch.tensor): 
162            Input image tensor (batch_size, input_channels, H, W).
163
164        Returns:
165        - torch.tensor: Generated output tensor (batch_size, output_channels, H, W).
166        """
167        d1 = self.down1(x)
168        d2 = self.down2(d1)
169        d3 = self.down3(d2)
170        d4 = self.down4(d3)
171        d5 = self.down5(d4)
172        d6 = self.down6(d5)
173        d7 = self.down7(d6)
174        d8 = self.down8(d7)
175
176        u1 = self.up1(d8)
177        u2 = self.up2(torch.cat([u1, d7], 1))
178        u3 = self.up3(torch.cat([u2, d6], 1))
179        u4 = self.up4(torch.cat([u3, d5], 1))
180        u5 = self.up5(torch.cat([u4, d4], 1))
181        u6 = self.up6(torch.cat([u5, d3], 1))
182        u7 = self.up7(torch.cat([u6, d2], 1))
183        u8 = self.up8(torch.cat([u7, d1], 1))
184
185        return u8

Forward pass of the U-Net generator.

Parameter:

  • x (torch.tensor): Input image tensor (batch_size, input_channels, H, W).

Returns:

  • torch.tensor: Generated output tensor (batch_size, output_channels, H, W).
class Discriminator(torch.nn.modules.module.Module):
193class Discriminator(nn.Module):
194    """
195    PatchGAN Discriminator for Pix2Pix GAN.
196
197    Parameter:
198    - input_channels (int): 
199        Number of input channels (typically input + target channels concatenated).
200    - hidden_channels (int): 
201        Base hidden channels for first layer.
202
203    Architecture:
204    - 5 convolutional blocks with LeakyReLU and batch normalization.
205    - Outputs a 2D patch map of predictions.
206    """
207    def __init__(self, input_channels=6, hidden_channels=64):
208        """
209        Initializes a PatchGAN discriminator.
210
211        The discriminator evaluates input-target image pairs to determine
212        if they are real or generated (fake). It progressively downsamples
213        the spatial dimensions while increasing the number of feature channels.
214
215        Parameters:
216        - input_channels (int): 
217            Number of input channels, typically input + target concatenated (default=6).
218        - hidden_channels (int): 
219            Number of channels in the first convolutional layer; doubled in subsequent layers (default=64).
220        """
221        super().__init__()
222        self.model = nn.Sequential(
223            nn.Conv2d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
224            nn.LeakyReLU(0.2, inplace=True),
225
226            nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=4, stride=2, padding=1, bias=False),
227            nn.BatchNorm2d(hidden_channels*2),
228            nn.LeakyReLU(0.2, inplace=True),
229
230            nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=4, stride=2, padding=1, bias=False),
231            nn.BatchNorm2d(hidden_channels*4),
232            nn.LeakyReLU(0.2, inplace=True),
233
234            nn.Conv2d(hidden_channels*4, hidden_channels*8, kernel_size=4, stride=1, padding=1, bias=False),
235            nn.BatchNorm2d(hidden_channels*8),
236            nn.LeakyReLU(0.2, inplace=True),
237
238            nn.Conv2d(hidden_channels*8, 1, kernel_size=4, stride=1, padding=1)
239        )
240
241    def forward(self, x, y):
242        """
243        Forward pass of the discriminator.
244
245        Parameter:
246        - x (torch.tensor): 
247            Input image tensor.
248        - y (torch.tensor): 
249            Target or generated image tensor.
250
251        Returns:
252        - torch.tensor: PatchGAN output tensor predicting real/fake for each patch.
253        """
254        # concatenate input and target channels
255        return self.model(torch.cat([x, y], dim=1))

PatchGAN Discriminator for Pix2Pix GAN.

Parameter:

  • input_channels (int): Number of input channels (typically input + target channels concatenated).
  • hidden_channels (int): Base hidden channels for first layer.

Architecture:

  • 5 convolutional blocks with LeakyReLU and batch normalization.
  • Outputs a 2D patch map of predictions.
Discriminator(input_channels=6, hidden_channels=64)
207    def __init__(self, input_channels=6, hidden_channels=64):
208        """
209        Initializes a PatchGAN discriminator.
210
211        The discriminator evaluates input-target image pairs to determine
212        if they are real or generated (fake). It progressively downsamples
213        the spatial dimensions while increasing the number of feature channels.
214
215        Parameters:
216        - input_channels (int): 
217            Number of input channels, typically input + target concatenated (default=6).
218        - hidden_channels (int): 
219            Number of channels in the first convolutional layer; doubled in subsequent layers (default=64).
220        """
221        super().__init__()
222        self.model = nn.Sequential(
223            nn.Conv2d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
224            nn.LeakyReLU(0.2, inplace=True),
225
226            nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=4, stride=2, padding=1, bias=False),
227            nn.BatchNorm2d(hidden_channels*2),
228            nn.LeakyReLU(0.2, inplace=True),
229
230            nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=4, stride=2, padding=1, bias=False),
231            nn.BatchNorm2d(hidden_channels*4),
232            nn.LeakyReLU(0.2, inplace=True),
233
234            nn.Conv2d(hidden_channels*4, hidden_channels*8, kernel_size=4, stride=1, padding=1, bias=False),
235            nn.BatchNorm2d(hidden_channels*8),
236            nn.LeakyReLU(0.2, inplace=True),
237
238            nn.Conv2d(hidden_channels*8, 1, kernel_size=4, stride=1, padding=1)
239        )

Initializes a PatchGAN discriminator.

The discriminator evaluates input-target image pairs to determine if they are real or generated (fake). It progressively downsamples the spatial dimensions while increasing the number of feature channels.

Parameters:

  • input_channels (int): Number of input channels, typically input + target concatenated (default=6).
  • hidden_channels (int): Number of channels in the first convolutional layer; doubled in subsequent layers (default=64).
model
def forward(self, x, y):
241    def forward(self, x, y):
242        """
243        Forward pass of the discriminator.
244
245        Parameter:
246        - x (torch.tensor): 
247            Input image tensor.
248        - y (torch.tensor): 
249            Target or generated image tensor.
250
251        Returns:
252        - torch.tensor: PatchGAN output tensor predicting real/fake for each patch.
253        """
254        # concatenate input and target channels
255        return self.model(torch.cat([x, y], dim=1))

Forward pass of the discriminator.

Parameter:

  • x (torch.tensor): Input image tensor.
  • y (torch.tensor): Target or generated image tensor.

Returns:

  • torch.tensor: PatchGAN output tensor predicting real/fake for each patch.
class Pix2Pix(torch.nn.modules.module.Module):
262class Pix2Pix(nn.Module):
263    """
264    Pix2Pix GAN for image-to-image translation.
265
266    Components:
267    - Generator: U-Net generator producing synthetic images.
268    - Discriminator: PatchGAN discriminator evaluating real vs fake images.
269    - Adversarial loss: Binary cross-entropy.
270    - Optional second loss for pixel-wise supervision.
271
272    Parameter:
273    - input_channels (int): 
274        Number of input channels.
275    - output_channels (int): 
276        Number of output channels.
277    - hidden_channels (int): 
278        Base hidden channels for both generator and discriminator.
279    - second_loss (nn.Module): 
280        Optional secondary loss (default: L1Loss).
281    - lambda_second (float): 
282        Weight for secondary loss in generator optimization.
283    """
284    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64, 
285                 second_loss=nn.L1Loss(), lambda_second=100):
286        """
287        Initializes the Pix2Pix GAN model.
288
289        Components:
290        - Generator: U-Net architecture for producing synthetic images.
291        - Discriminator: PatchGAN for evaluating real vs. fake images.
292        - Adversarial loss: Binary cross-entropy to train the generator to fool the discriminator.
293        - Optional secondary loss: Pixel-wise supervision (default: L1Loss).
294
295        Parameter:
296        - input_channels (int): 
297            Number of channels in the input images (default=1).
298        - output_channels (int): 
299            Number of channels in the output images (default=1).
300        - hidden_channels (int): 
301            Base number of hidden channels in the generator and discriminator (default=64).
302        - second_loss (nn.Module): 
303            Optional secondary loss for the generator (default: nn.L1Loss()).
304        - lambda_second (float): 
305            Weight applied to the secondary loss in generator optimization (default=100).
306        """
307        super().__init__()
308        self.input_channels = input_channels
309        self.output_channels = output_channels
310
311        self.generator = UNetGenerator(input_channels=input_channels, 
312                                       output_channels=output_channels, 
313                                       hidden_channels=hidden_channels)
314        self.discriminator = Discriminator(input_channels=input_channels+output_channels, hidden_channels=hidden_channels)
315
316        self.adversarial_loss = nn.BCEWithLogitsLoss()
317        self.second_loss = second_loss
318        self.lambda_second = lambda_second
319
320        self.last_generator_loss = float("inf")
321        self.last_generator_adversarial_loss = float("inf")
322        self.last_generator_second_loss = float("inf")
323        self.last_discriminator_loss = float("inf")
324
325    def get_input_channels(self):
326        """
327        Returns the number of input channels used by the model.
328
329        Returns:
330        - int: 
331            Number of input channels expected by the model.
332        """
333        return self.input_channels
334    
335    def get_output_channels(self):
336        """
337        Returns the number of output channels produced by the model.
338
339        Returns:
340        - int: 
341            Number of output channels the model generates
342        """
343        return self.output_channels
344
345    def get_dict(self):
346        """
347        Returns a dictionary with the most recent loss values.
348
349        Returns:
350        - dict: Loss components (base, complex).
351
352        Notes:
353        - Useful for logging or monitoring training progress.
354        """
355        return {
356                f"loss_generator": self.last_generator_loss, 
357                f"loss_generator_adversarial": self.last_generator_adversarial_loss, 
358                f"loss_generator_second": self.last_generator_second_loss,
359                f"loss_discriminator": self.last_discriminator_loss
360               }
361
362    def forward(self, x):
363        """
364        Forward pass through the generator.
365
366        Parameter:
367        - x (torch.tensor): 
368            Input tensor.
369
370        Returns:
371        - torch.tensor: Generated output image.
372        """
373        return self.generator(x)
374
375    def generator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
376        """
377        Performs a single optimization step for the generator.
378
379        This includes:
380        - Forward pass through the generator and discriminator.
381        - Computing adversarial loss (generator tries to fool the discriminator).
382        - Computing optional secondary loss (e.g., L1 or MSE).
383        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.
384
385        Parameters:
386        - x (torch.tensor): 
387            Input tensor for the generator (e.g., source image).
388        - y (torch.tensor): 
389            Target tensor for supervised secondary loss.
390        - optimizer (torch.optim.Optimizer): 
391            Optimizer for the generator parameters.
392        - amp_scaler (torch.cuda.amp.GradScaler or None): 
393            Automatic mixed precision scaler.
394        - device (torch.device): 
395            Device for AMP autocast.
396        - gradient_clipping_threshold (float or None): 
397            Max norm for gradient clipping; if None, no clipping.
398
399        Returns:
400        - tuple(torch.tensor, torch.tensor, torch.tensor):
401            - Total generator loss (adversarial + secondary).
402            - Adversarial loss component.
403            - Secondary loss component (weighted by `lambda_second`).
404
405        Notes:
406        - If AMP is enabled, gradients are scaled and unscaled appropriately.
407        - `last_generator_loss`, `last_generator_adversarial_loss`, and `last_generator_second_loss` are updated.
408        """
409        if amp_scaler:
410            with autocast(device_type=device.type):
411                # make predictions
412                fake_y = self.generator(x)
413
414                discriminator_fake = self.discriminator(x, fake_y)
415
416                # calc loss -> discriminator thinks it is real?
417                loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
418                loss_second = self.second_loss(fake_y, y) * self.lambda_second
419                loss_total = loss_adversarial + loss_second
420
421            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
422    
423            optimizer.zero_grad(set_to_none=True)
424            amp_scaler.scale(loss_total).backward()
425            if gradient_clipping_threshold:
426                # Unscale first!
427                amp_scaler.unscale_(optimizer)
428                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
429            amp_scaler.step(optimizer)
430            amp_scaler.update()
431        else:
432            # make predictions
433            fake_y = self.generator(x)
434
435            discriminator_fake = self.discriminator(x, fake_y)
436
437            # calc loss -> discriminator thinks it is real?
438            loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
439            loss_second = self.second_loss(fake_y, y) * self.lambda_second
440            loss_total = loss_adversarial + loss_second
441            optimizer.zero_grad()
442            loss_total.backward()
443            if gradient_clipping_threshold:
444                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
445            optimizer.step()
446
447        self.last_generator_loss = loss_total.item()
448        self.last_generator_adversarial_loss = loss_adversarial.item()
449        self.last_generator_second_loss = loss_second.item()
450
451        return loss_total, loss_adversarial, loss_second
452
453    def discriminator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
454        """
455        Performs a single optimization step for the discriminator.
456
457        This includes:
458        - Forward pass through the discriminator for both real and fake samples.
459        - Computing adversarial loss (binary cross-entropy) for real vs fake patches.
460        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.
461
462        Parameters:
463        - x (torch.tensor): 
464            Input tensor (e.g., source image).
465        - y (torch.tensor): 
466            Target tensor (real image) for the discriminator.
467        - optimizer (torch.optim.Optimizer): 
468            Optimizer for the discriminator parameters.
469        - amp_scaler (torch.cuda.amp.GradScaler or None): 
470            Automatic mixed precision scaler.
471        - device (torch.device): 
472            Device for AMP autocast.
473        - gradient_clipping_threshold (float or None): 
474            Max norm for gradient clipping; if None, no clipping.
475
476        Returns:
477        - tuple(torch.tensor, torch.tensor, torch.tensor):
478            - Total discriminator loss (mean of real and fake losses).
479            - Loss for real samples.
480            - Loss for fake samples.
481
482        Notes:
483        - Fake images are detached from the generator to prevent updating its weights.
484        - `last_discriminator_loss` is updated.
485        - Supports AMP and optional gradient clipping for stability.
486        """
487        if amp_scaler:
488            with autocast(device_type=device.type): 
489                # make predictions
490                fake_y = self.generator(x).detach()  # don't update generator!!
491
492                discriminator_real = self.discriminator(x, y)
493                discriminator_fake = self.discriminator(x, fake_y)
494
495                # calc loss -> 1: predictions = real, 0: predictions = fake
496                loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9)
497                loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
498                loss_total = (loss_real + loss_fake) * 0.5
499
500            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
501            optimizer.zero_grad(set_to_none=True)
502            amp_scaler.scale(loss_total).backward()
503            if gradient_clipping_threshold:
504                # Unscale first!
505                amp_scaler.unscale_(optimizer)
506                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
507            amp_scaler.step(optimizer)
508            amp_scaler.update()
509        else:
510            # make predictions
511            fake_y = self.generator(x).detach()
512
513            discriminator_real = self.discriminator(x, y)
514            discriminator_fake = self.discriminator(x, fake_y)
515
516            # calc loss -> 1: predictions = real, 0: predictions = fake
517            loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9)
518            loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
519            loss_total = (loss_real + loss_fake) * 0.5
520            optimizer.zero_grad()
521            loss_total.backward()
522            if gradient_clipping_threshold:
523                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
524            optimizer.step()
525
526        self.last_discriminator_loss = loss_total.item()
527
528        return loss_total, loss_real, loss_fake

Pix2Pix GAN for image-to-image translation.

Components:

  • Generator: U-Net generator producing synthetic images.
  • Discriminator: PatchGAN discriminator evaluating real vs fake images.
  • Adversarial loss: Binary cross-entropy.
  • Optional second loss for pixel-wise supervision.

Parameter:

  • input_channels (int): Number of input channels.
  • output_channels (int): Number of output channels.
  • hidden_channels (int): Base hidden channels for both generator and discriminator.
  • second_loss (nn.Module): Optional secondary loss (default: L1Loss).
  • lambda_second (float): Weight for secondary loss in generator optimization.
Pix2Pix( input_channels=1, output_channels=1, hidden_channels=64, second_loss=L1Loss(), lambda_second=100)
284    def __init__(self, input_channels=1, output_channels=1, hidden_channels=64, 
285                 second_loss=nn.L1Loss(), lambda_second=100):
286        """
287        Initializes the Pix2Pix GAN model.
288
289        Components:
290        - Generator: U-Net architecture for producing synthetic images.
291        - Discriminator: PatchGAN for evaluating real vs. fake images.
292        - Adversarial loss: Binary cross-entropy to train the generator to fool the discriminator.
293        - Optional secondary loss: Pixel-wise supervision (default: L1Loss).
294
295        Parameter:
296        - input_channels (int): 
297            Number of channels in the input images (default=1).
298        - output_channels (int): 
299            Number of channels in the output images (default=1).
300        - hidden_channels (int): 
301            Base number of hidden channels in the generator and discriminator (default=64).
302        - second_loss (nn.Module): 
303            Optional secondary loss for the generator (default: nn.L1Loss()).
304        - lambda_second (float): 
305            Weight applied to the secondary loss in generator optimization (default=100).
306        """
307        super().__init__()
308        self.input_channels = input_channels
309        self.output_channels = output_channels
310
311        self.generator = UNetGenerator(input_channels=input_channels, 
312                                       output_channels=output_channels, 
313                                       hidden_channels=hidden_channels)
314        self.discriminator = Discriminator(input_channels=input_channels+output_channels, hidden_channels=hidden_channels)
315
316        self.adversarial_loss = nn.BCEWithLogitsLoss()
317        self.second_loss = second_loss
318        self.lambda_second = lambda_second
319
320        self.last_generator_loss = float("inf")
321        self.last_generator_adversarial_loss = float("inf")
322        self.last_generator_second_loss = float("inf")
323        self.last_discriminator_loss = float("inf")

Initializes the Pix2Pix GAN model.

Components:

  • Generator: U-Net architecture for producing synthetic images.
  • Discriminator: PatchGAN for evaluating real vs. fake images.
  • Adversarial loss: Binary cross-entropy to train the generator to fool the discriminator.
  • Optional secondary loss: Pixel-wise supervision (default: L1Loss).

Parameter:

  • input_channels (int): Number of channels in the input images (default=1).
  • output_channels (int): Number of channels in the output images (default=1).
  • hidden_channels (int): Base number of hidden channels in the generator and discriminator (default=64).
  • second_loss (nn.Module): Optional secondary loss for the generator (default: nn.L1Loss()).
  • lambda_second (float): Weight applied to the secondary loss in generator optimization (default=100).
input_channels
output_channels
generator
discriminator
adversarial_loss
second_loss
lambda_second
last_generator_loss
last_generator_adversarial_loss
last_generator_second_loss
last_discriminator_loss
def get_input_channels(self):
325    def get_input_channels(self):
326        """
327        Returns the number of input channels used by the model.
328
329        Returns:
330        - int: 
331            Number of input channels expected by the model.
332        """
333        return self.input_channels

Returns the number of input channels used by the model.

Returns:

  • int: Number of input channels expected by the model.
def get_output_channels(self):
335    def get_output_channels(self):
336        """
337        Returns the number of output channels produced by the model.
338
339        Returns:
340        - int: 
341            Number of output channels the model generates
342        """
343        return self.output_channels

Returns the number of output channels produced by the model.

Returns:

  • int: Number of output channels the model generates
def get_dict(self):
345    def get_dict(self):
346        """
347        Returns a dictionary with the most recent loss values.
348
349        Returns:
350        - dict: Loss components (base, complex).
351
352        Notes:
353        - Useful for logging or monitoring training progress.
354        """
355        return {
356                f"loss_generator": self.last_generator_loss, 
357                f"loss_generator_adversarial": self.last_generator_adversarial_loss, 
358                f"loss_generator_second": self.last_generator_second_loss,
359                f"loss_discriminator": self.last_discriminator_loss
360               }

Returns a dictionary with the most recent loss values.

Returns:

  • dict: Loss components (base, complex).

Notes:

  • Useful for logging or monitoring training progress.
def forward(self, x):
362    def forward(self, x):
363        """
364        Forward pass through the generator.
365
366        Parameter:
367        - x (torch.tensor): 
368            Input tensor.
369
370        Returns:
371        - torch.tensor: Generated output image.
372        """
373        return self.generator(x)

Forward pass through the generator.

Parameter:

  • x (torch.tensor): Input tensor.

Returns:

  • torch.tensor: Generated output image.
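
A short inference sketch, continuing from the construction example above (256x256 is assumed because an 8-level U-Net halves the spatial size eight times):

    import torch

    model.eval()               # disable dropout in the decoder blocks
    with torch.no_grad():      # no gradients needed at inference time
        x = torch.randn(1, model.get_input_channels(), 256, 256)
        y_hat = model(x)       # delegates to self.generator(x)
    # y_hat has the same spatial size as x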
def generator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
375    def generator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
376        """
377        Performs a single optimization step for the generator.
378
379        This includes:
380        - Forward pass through the generator and discriminator.
381        - Computing adversarial loss (generator tries to fool the discriminator).
382        - Computing optional secondary loss (e.g., L1 or MSE).
383        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.
384
385        Parameters:
386        - x (torch.tensor): 
387            Input tensor for the generator (e.g., source image).
388        - y (torch.tensor): 
389            Target tensor for supervised secondary loss.
390        - optimizer (torch.optim.Optimizer): 
391            Optimizer for the generator parameters.
392        - amp_scaler (torch.amp.GradScaler or None): 
393            Automatic mixed precision scaler.
394        - device (torch.device): 
395            Device for AMP autocast.
396        - gradient_clipping_threshold (float or None): 
397            Max norm for gradient clipping; if None, no clipping.
398
399        Returns:
400        - tuple(torch.tensor, torch.tensor, torch.tensor):
401            - Total generator loss (adversarial + secondary).
402            - Adversarial loss component.
403            - Secondary loss component (weighted by `lambda_second`).
404
405        Notes:
406        - If AMP is enabled, gradients are scaled and unscaled appropriately.
407        - `last_generator_loss`, `last_generator_adversarial_loss`, and `last_generator_second_loss` are updated.
408        """
409        if amp_scaler:
410            with autocast(device_type=device.type):
411                # make predictions
412                fake_y = self.generator(x)
413
414                discriminator_fake = self.discriminator(x, fake_y)
415
416                # calc loss -> discriminator thinks it is real?
417                loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
418                loss_second = self.second_loss(fake_y, y) * self.lambda_second
419                loss_total = loss_adversarial + loss_second
420
421            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
422
423            optimizer.zero_grad(set_to_none=True)
424            amp_scaler.scale(loss_total).backward()
425            if gradient_clipping_threshold:
426                # Unscale first!
427                amp_scaler.unscale_(optimizer)
428                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
429            amp_scaler.step(optimizer)
430            amp_scaler.update()
431        else:
432            # make predictions
433            fake_y = self.generator(x)
434
435            discriminator_fake = self.discriminator(x, fake_y)
436
437            # calc loss -> discriminator thinks it is real?
438            loss_adversarial = self.adversarial_loss(discriminator_fake, torch.ones_like(discriminator_fake))
439            loss_second = self.second_loss(fake_y, y) * self.lambda_second
440            loss_total = loss_adversarial + loss_second
441            optimizer.zero_grad()
442            loss_total.backward()
443            if gradient_clipping_threshold:
444                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=gradient_clipping_threshold)
445            optimizer.step()
446
447        self.last_generator_loss = loss_total.item()
448        self.last_generator_adversarial_loss = loss_adversarial.item()
449        self.last_generator_second_loss = loss_second.item()
450
451        return loss_total, loss_adversarial, loss_second

Performs a single optimization step for the generator.

This includes:

  • Forward pass through the generator and discriminator.
  • Computing adversarial loss (generator tries to fool the discriminator).
  • Computing optional secondary loss (e.g., L1 or MSE).
  • Backpropagation and optimizer step, optionally with AMP and gradient clipping.

Parameters:

  • x (torch.tensor): Input tensor for the generator (e.g., source image).
  • y (torch.tensor): Target tensor for supervised secondary loss.
  • optimizer (torch.optim.Optimizer): Optimizer for the generator parameters.
  • amp_scaler (torch.amp.GradScaler or None): Automatic mixed precision scaler.
  • device (torch.device): Device for AMP autocast.
  • gradient_clipping_threshold (float or None): Max norm for gradient clipping; if None, no clipping.

Returns:

  • tuple(torch.tensor, torch.tensor, torch.tensor):
    • Total generator loss (adversarial + secondary).
    • Adversarial loss component.
    • Secondary loss component (weighted by lambda_second).

Notes:

  • If AMP is enabled, gradients are scaled and unscaled appropriately.
  • last_generator_loss, last_generator_adversarial_loss, and last_generator_second_loss are updated.
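
A hedged call sketch for the AMP path (device, optimizer settings, and tensor shapes are illustrative assumptions, not requirements of the module):

    import torch

    device = torch.device("cuda")
    model = Pix2Pix().to(device)
    opt_g = torch.optim.Adam(model.generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    scaler_g = torch.amp.GradScaler()  # pass amp_scaler=None to take the non-AMP branch

    x = torch.randn(8, 1, 256, 256, device=device)  # source batch (shape assumed)
    y = torch.randn(8, 1, 256, 256, device=device)  # paired target batch

    loss_total, loss_adv, loss_sec = model.generator_step(
        x, y, optimizer=opt_g, amp_scaler=scaler_g,
        device=device, gradient_clipping_threshold=1.0)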
def discriminator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
453    def discriminator_step(self, x, y, optimizer, amp_scaler, device, gradient_clipping_threshold):
454        """
455        Performs a single optimization step for the discriminator.
456
457        This includes:
458        - Forward pass through the discriminator for both real and fake samples.
459        - Computing adversarial loss (binary cross-entropy) for real vs fake patches.
460        - Backpropagation and optimizer step, optionally with AMP and gradient clipping.
461
462        Parameters:
463        - x (torch.tensor): 
464            Input tensor (e.g., source image).
465        - y (torch.tensor): 
466            Target tensor (real image) for the discriminator.
467        - optimizer (torch.optim.Optimizer): 
468            Optimizer for the discriminator parameters.
469        - amp_scaler (torch.amp.GradScaler or None): 
470            Automatic mixed precision scaler.
471        - device (torch.device): 
472            Device for AMP autocast.
473        - gradient_clipping_threshold (float or None): 
474            Max norm for gradient clipping; if None, no clipping.
475
476        Returns:
477        - tuple(torch.tensor, torch.tensor, torch.tensor):
478            - Total discriminator loss (mean of real and fake losses).
479            - Loss for real samples.
480            - Loss for fake samples.
481
482        Notes:
483        - Fake images are detached from the generator to prevent updating its weights.
484        - `last_discriminator_loss` is updated.
485        - Supports AMP and optional gradient clipping for stability.
486        """
487        if amp_scaler:
488            with autocast(device_type=device.type): 
489                # make predictions
490                fake_y = self.generator(x).detach()  # don't update generator!!
491
492                discriminator_real = self.discriminator(x, y)
493                discriminator_fake = self.discriminator(x, fake_y)
494
495                # calc loss -> 1: predictions = real, 0: predictions = fake
496                loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9)
497                loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
498                loss_total = (loss_real + loss_fake) * 0.5
499
500            # backward pass -> calc gradients and change the weights towards the opposite of gradients via optimizer
501            optimizer.zero_grad(set_to_none=True)
502            amp_scaler.scale(loss_total).backward()
503            if gradient_clipping_threshold:
504                # Unscale first!
505                amp_scaler.unscale_(optimizer)
506                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
507            amp_scaler.step(optimizer)
508            amp_scaler.update()
509        else:
510            # make predictions
511            fake_y = self.generator(x).detach()
512
513            discriminator_real = self.discriminator(x, y)
514            discriminator_fake = self.discriminator(x, fake_y)
515
516            # calc loss -> 1: predictions = real, 0: predictions = fake
517            loss_real = self.adversarial_loss(discriminator_real, torch.ones_like(discriminator_real))  # torch.full_like(discriminator_real, 0.9)
518            loss_fake = self.adversarial_loss(discriminator_fake, torch.zeros_like(discriminator_fake))
519            loss_total = (loss_real + loss_fake) * 0.5
520            optimizer.zero_grad()
521            loss_total.backward()
522            if gradient_clipping_threshold:
523                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=gradient_clipping_threshold)
524            optimizer.step()
525
526        self.last_discriminator_loss = loss_total.item()
527
528        return loss_total, loss_real, loss_fake

Performs a single optimization step for the discriminator.

This includes:

  • Forward pass through the discriminator for both real and fake samples.
  • Computing adversarial loss (binary cross-entropy) for real vs fake patches.
  • Backpropagation and optimizer step, optionally with AMP and gradient clipping.

Parameters:

  • x (torch.tensor): Input tensor (e.g., source image).
  • y (torch.tensor): Target tensor (real image) for the discriminator.
  • optimizer (torch.optim.Optimizer): Optimizer for the discriminator parameters.
  • amp_scaler (torch.amp.GradScaler or None): Automatic mixed precision scaler.
  • device (torch.device): Device for AMP autocast.
  • gradient_clipping_threshold (float or None): Max norm for gradient clipping; if None, no clipping.

Returns:

  • tuple(torch.tensor, torch.tensor, torch.tensor):
    • Total discriminator loss (mean of real and fake losses).
    • Loss for real samples.
    • Loss for fake samples.

Notes:

  • Fake images are detached from the generator to prevent updating its weights.
  • last_discriminator_loss is updated.
  • Supports AMP and optional gradient clipping for stability.
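
Putting both steps together, a sketch of one training pass (the dataloader, learning rates, and alternating order are assumptions following the usual pix2pix recipe, not values prescribed by the module):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Pix2Pix().to(device)

    opt_g = torch.optim.Adam(model.generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(model.discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

    # One scaler per optimizer, since each step method calls update() internally.
    use_amp = device.type == "cuda"
    scaler_g = torch.amp.GradScaler() if use_amp else None
    scaler_d = torch.amp.GradScaler() if use_amp else None

    for x, y in dataloader:  # paired (source, target) batches; dataloader is assumed
        x, y = x.to(device), y.to(device)

        # 1) Update the discriminator on real vs. generated pairs.
        model.discriminator_step(x, y, optimizer=opt_d, amp_scaler=scaler_d,
                                 device=device, gradient_clipping_threshold=None)

        # 2) Update the generator (adversarial + weighted secondary loss).
        model.generator_step(x, y, optimizer=opt_g, amp_scaler=scaler_g,
                             device=device, gradient_clipping_threshold=None)

    print(model.get_dict())  # most recent loss components, ready for logging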