image_to_image.losses.weighted_combined_loss

Module to define a Weighted Combined Loss.

Functions:

  • calc_weight_map

Classes:

  • WeightedCombinedLoss

By Tobia Ippolito
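
A minimal usage sketch (the tensor shapes and values below are illustrative assumptions, not part of the module):

    import torch
    from image_to_image.losses.weighted_combined_loss import WeightedCombinedLoss

    criterion = WeightedCombinedLoss(weight_ssim=2.0)    # per-component weights are optional
    pred = torch.rand(4, 1, 64, 64, requires_grad=True)  # e.g. predicted depth maps
    target = torch.rand(4, 1, 64, 64)                    # ground truth
    loss = criterion(pred, target)                       # uniform weight map by default
    loss.backward()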

  1"""
  2Module to define a Weighted Combine Loss.
  3
  4Functions:
  5- calc_weight_map
  6
  7Classes:
  8- WeightedCombinedLoss
  9
 10By Tobia Ippolito
 11"""
 12# ---------------------------
 13#        > Imports <
 14# ---------------------------
 15import torch
 16import torch.nn as nn
 17import torch.nn.functional as F
 18
 19import kornia
 20
 21
 22
 23# ---------------------------
 24#   > Loss Implementation <
 25# ---------------------------
class WeightedCombinedLoss(nn.Module):
    """
    Computes a weighted combination of multiple loss functions for image-to-image tasks.

    Supported losses:
        - SILog loss
        - Gradient L1 loss
        - SSIM loss
        - Edge-aware loss
        - L1 loss
        - Variance loss
        - Range loss
        - Blur loss

    The losses can be weighted individually, and average losses are tracked across steps.
    """
    def __init__(self,
                 silog_lambda=0.5,
                 weight_silog=0.5,
                 weight_grad=10.0,
                 weight_ssim=5.0,
                 weight_edge_aware=10.0,
                 weight_l1=1.0,
                 weight_var=1.0,
                 weight_range=1.0,
                 weight_blur=1.0):
        """
        Initializes the WeightedCombinedLoss with optional weights for each component.

        Parameters:
        - silog_lambda (float):
            SILog lambda parameter; controls the degree of scale invariance.
        - weight_silog (float):
            Weight for SILog loss.
        - weight_grad (float):
            Weight for gradient L1 loss.
        - weight_ssim (float):
            Weight for SSIM loss.
        - weight_edge_aware (float):
            Weight for edge-aware loss.
        - weight_l1 (float):
            Weight for L1 loss.
        - weight_var (float):
            Weight for variance loss.
        - weight_range (float):
            Weight for range loss.
        - weight_blur (float):
            Weight for blur loss.
        """
        super().__init__()
        self.silog_lambda = silog_lambda
        self.weight_silog = weight_silog
        self.weight_grad = weight_grad
        self.weight_ssim = weight_ssim
        self.weight_edge_aware = weight_edge_aware
        self.weight_l1 = weight_l1
        self.weight_var = weight_var
        self.weight_range = weight_range
        self.weight_blur = weight_blur

        # Running sums of the individual losses (averaged in get_avg_losses)
        self.avg_loss_silog = 0
        self.avg_loss_grad = 0
        self.avg_loss_ssim = 0
        self.avg_loss_l1 = 0
        self.avg_loss_edge_aware = 0
        self.avg_loss_var = 0
        self.avg_loss_range = 0
        self.avg_loss_blur = 0
        self.steps = 0

        # Instantiate SSIMLoss module
        self.ssim_module = kornia.losses.SSIMLoss(window_size=11, reduction='mean')
        # self.ssim_module = kornia.losses.MS_SSIMLoss(reduction='mean')


    def silog_loss(self, pred, target, weight_map):
        # Scale-invariant logarithmic (SILog) loss on the weighted log-difference
        eps = 1e-6
        pred = torch.clamp(pred, min=eps)
        target = torch.clamp(target, min=eps)

        diff_log = torch.log(target) - torch.log(pred)
        diff_log = diff_log * weight_map

        # sqrt(E[d^2] - lambda * E[d]^2); the clamp guards against a slightly
        # negative argument caused by floating-point error
        loss = torch.sqrt(torch.clamp(torch.mean(diff_log ** 2) -
                                      self.silog_lambda * torch.mean(diff_log) ** 2,
                                      min=eps))
        return loss

    def gradient_l1_loss(self, pred, target, weight_map):
        # Create channel dimension if missing
        if pred.ndim == 3:
            pred = pred.unsqueeze(1)
        if target.ndim == 3:
            target = target.unsqueeze(1)
        if weight_map.ndim == 3:
            weight_map = weight_map.unsqueeze(1)

        # Gradient in x-direction (horizontal -> dim=3)
        pred_grad_x = pred[:, :, :, 1:] - pred[:, :, :, :-1]
        target_grad_x = target[:, :, :, 1:] - target[:, :, :, :-1]

        # Gradient in y-direction (vertical -> dim=2)
        pred_grad_y = pred[:, :, 1:, :] - pred[:, :, :-1, :]
        target_grad_y = target[:, :, 1:, :] - target[:, :, :-1, :]

        # Weight each gradient by the product of the weights of the two pixels it spans
        weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1]
        weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :]

        loss_x = torch.mean(torch.abs(pred_grad_x - target_grad_x) * weight_x)
        loss_y = torch.mean(torch.abs(pred_grad_y - target_grad_y) * weight_y)

        return loss_x + loss_y

    def ssim_loss(self, pred, target, weight_map):
        # kornia's SSIMLoss already returns a dissimilarity (a loss), so no
        # manual "1 - SSIM" is needed; weight_map is unused for this component
        if pred.ndim == 3:
            pred = pred.unsqueeze(1)
        if target.ndim == 3:
            target = target.unsqueeze(1)

        return self.ssim_module(pred, target)

    def edge_aware_loss(self, pred, target, weight_map):
        if pred.ndim == 3:
            pred = pred.unsqueeze(1)
        if target.ndim == 3:
            target = target.unsqueeze(1)
        if weight_map.ndim == 3:
            weight_map = weight_map.unsqueeze(1)

        pred_grad_x = pred[:, :, :, :-1] - pred[:, :, :, 1:]
        pred_grad_y = pred[:, :, :-1, :] - pred[:, :, 1:, :]

        target_grad_x = torch.mean(torch.abs(target[:, :, :, :-1] - target[:, :, :, 1:]), 1, keepdim=True)
        target_grad_y = torch.mean(torch.abs(target[:, :, :-1, :] - target[:, :, 1:, :]), 1, keepdim=True)

        weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1]
        weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :]

        # Suppress prediction gradients where the target has strong (weighted)
        # edges -> encourages smoothness everywhere except at target edges
        pred_grad_x = pred_grad_x * torch.exp(-target_grad_x * weight_x)
        pred_grad_y = pred_grad_y * torch.exp(-target_grad_y * weight_y)

        return pred_grad_x.abs().mean() + pred_grad_y.abs().mean()

    def l1_loss(self, pred, target, weight_map):
        loss = torch.abs(target - pred) * weight_map
        return loss.mean()

    def variance_loss(self, pred, target):
        # Match the global variance of prediction and target
        pred_var = torch.var(pred)
        target_var = torch.var(target)
        return F.mse_loss(pred_var, target_var)

    def range_loss(self, pred, target):
        # Match the global min/max value range of prediction and target
        pred_min, pred_max = torch.min(pred), torch.max(pred)
        target_min, target_max = torch.min(target), torch.max(target)

        min_loss = F.mse_loss(pred_min, target_min)
        max_loss = F.mse_loss(pred_max, target_max)

        return min_loss + max_loss

    def blur_loss(self, pred, target):
        # Compare Laplacian responses to penalize differences in sharpness
        laplacian_kernel = torch.tensor([[[[0, 1, 0],
                                           [1, -4, 1],
                                           [0, 1, 0]]]], dtype=pred.dtype, device=pred.device)

        if pred.ndim == 3:
            pred = pred.unsqueeze(1)
        if target.ndim == 3:
            target = target.unsqueeze(1)

        pred_lap = F.conv2d(pred, laplacian_kernel, padding=1)
        target_lap = F.conv2d(target, laplacian_kernel, padding=1)

        return F.l1_loss(pred_lap, target_lap)

    def forward(self, pred, target, weight_map=None, should_calc_weight_map=False):
        """
        Computes the weighted combined loss between prediction and target.

        Parameters:
        - pred (torch.Tensor):
            Predicted output tensor.
        - target (torch.Tensor):
            Ground truth tensor.
        - weight_map (torch.Tensor or None):
            Optional pixel-wise weighting map.
        - should_calc_weight_map (bool):
            If True and weight_map is None, calculates a weight map from target.

        Returns:
        - torch.Tensor: Weighted sum of all losses.
        """
        if weight_map is None:
            if should_calc_weight_map:
                weight_map = calc_weight_map(target)
            else:
                # No mask/weight map given -> weight every pixel equally
                weight_map = torch.ones_like(pred)

        loss_silog = self.silog_loss(pred, target, weight_map)
        loss_grad = self.gradient_l1_loss(pred, target, weight_map)
        loss_ssim = self.ssim_loss(pred, target, weight_map)
        loss_l1 = self.l1_loss(pred, target, weight_map)
        loss_edge_aware = self.edge_aware_loss(pred, target, weight_map)
        loss_var = self.variance_loss(pred, target)
        loss_range = self.range_loss(pred, target)
        loss_blur = self.blur_loss(pred, target)

        # Reset the running averages every 25 steps
        if self.steps > 24:
            self.step()

        # Detach before accumulating so the stored values do not keep the
        # computation graph alive
        self.avg_loss_silog += loss_silog.detach()
        self.avg_loss_grad += loss_grad.detach()
        self.avg_loss_ssim += loss_ssim.detach()
        self.avg_loss_l1 += loss_l1.detach()
        self.avg_loss_edge_aware += loss_edge_aware.detach()
        self.avg_loss_var += loss_var.detach()
        self.avg_loss_range += loss_range.detach()
        self.avg_loss_blur += loss_blur.detach()
        self.steps += 1

        total_loss = (
            self.weight_silog * loss_silog +
            self.weight_grad * loss_grad +
            self.weight_ssim * loss_ssim +
            self.weight_edge_aware * loss_edge_aware +
            self.weight_l1 * loss_l1 +
            self.weight_var * loss_var +
            self.weight_range * loss_range +
            self.weight_blur * loss_blur
        )

        return total_loss

    def step(self, epoch=None):
        """
        Resets the running averages of all tracked losses.
        """
        self.avg_loss_silog = 0
        self.avg_loss_grad = 0
        self.avg_loss_ssim = 0
        self.avg_loss_l1 = 0
        self.avg_loss_edge_aware = 0
        self.avg_loss_var = 0
        self.avg_loss_range = 0
        self.avg_loss_blur = 0
        self.steps = 0

    def get_avg_losses(self):
        """
        Returns the running average of all individual losses.

        Returns:
        - tuple: (avg_loss_silog, avg_loss_grad, avg_loss_ssim, avg_loss_l1,
                  avg_loss_edge_aware, avg_loss_var, avg_loss_range, avg_loss_blur)
        """
        # Guard against division by zero if no step has been tracked yet
        steps = max(self.steps, 1)
        return (self.avg_loss_silog / steps,
                self.avg_loss_grad / steps,
                self.avg_loss_ssim / steps,
                self.avg_loss_l1 / steps,
                self.avg_loss_edge_aware / steps,
                self.avg_loss_var / steps,
                self.avg_loss_range / steps,
                self.avg_loss_blur / steps
               )

    def get_dict(self):
        """
        Returns a dictionary of average losses and their corresponding weights.

        Returns:
        - dict: All loss components with their weights.
        """
        loss_silog, loss_grad, loss_ssim, loss_l1, loss_edge_aware, loss_var, loss_range, loss_blur = self.get_avg_losses()
        return {
                "loss_silog": loss_silog,
                "loss_grad": loss_grad,
                "loss_ssim": loss_ssim,
                "loss_l1": loss_l1,
                "loss_edge_aware": loss_edge_aware,
                "loss_var": loss_var,
                "loss_range": loss_range,
                "loss_blur": loss_blur,
                "weight_loss_silog": self.weight_silog,
                "weight_loss_grad": self.weight_grad,
                "weight_loss_ssim": self.weight_ssim,
                "weight_loss_l1": self.weight_l1,
                "weight_loss_edge_aware": self.weight_edge_aware,
                "weight_loss_var": self.weight_var,
                "weight_loss_range": self.weight_range,
                "weight_loss_blur": self.weight_blur
               }

def calc_weight_map(target):
    """
    Calculates a per-pixel weighting map for a target tensor based on unique value frequencies.

    Less frequent values are given higher weights to emphasize their contribution in loss computations.

    Parameter:
    - target (torch.Tensor):
        Ground truth tensor.

    Returns:
    - torch.Tensor: Weight map tensor of the same shape as target.
    """
    values, counts = torch.unique(target.flatten(), return_counts=True)

    # Inverse-frequency weighting: rare values receive large weights
    weights = {values[idx].item(): 255.0 / counts[idx].item() for idx in range(len(values))}

    weights_map = torch.zeros_like(target, dtype=torch.float)
    for cur_value in values:
        cur_value = cur_value.item()
        weights_map[target == cur_value] = weights[cur_value]

    return weights_map
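
A self-contained sketch of how the running-average bookkeeping fits into a training loop (the one-layer model and the synthetic data are placeholders, not part of the module):

    import torch
    import torch.nn as nn
    from image_to_image.losses.weighted_combined_loss import WeightedCombinedLoss

    model = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # stand-in model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = WeightedCombinedLoss()

    for step in range(25):
        inputs = torch.rand(2, 1, 32, 32)               # synthetic batch
        targets = torch.rand(2, 1, 32, 32)
        preds = torch.sigmoid(model(inputs))            # keep predictions in [0, 1]
        loss = criterion(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Per-component running averages plus the configured weights, e.g. for
    # logging; criterion.step() resets the averages
    print(criterion.get_dict())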