image_to_image.losses.weighted_combined_loss
Module to define a Weighted Combined Loss.
Functions:
- calc_weight_map
Classes:
- WeightedCombinedLoss
By Tobia Ippolito
1""" 2Module to define a Weighted Combine Loss. 3 4Functions: 5- calc_weight_map 6 7Classes: 8- WeightedCombinedLoss 9 10By Tobia Ippolito 11""" 12# --------------------------- 13# > Imports < 14# --------------------------- 15import torch 16import torch.nn as nn 17import torch.nn.functional as F 18 19import kornia 20 21 22 23# --------------------------- 24# > Loss Implementation < 25# --------------------------- 26class WeightedCombinedLoss(nn.Module): 27 """ 28 Computes a weighted combination of multiple loss functions for image-to-image tasks. 29 30 Supported losses: 31 - SILog loss 32 - Gradient L1 loss 33 - SSIM loss 34 - Edge-aware loss 35 - L1 loss 36 - Variance loss 37 - Range loss 38 - Blur loss 39 40 The losses can be weighted individually, and average losses are tracked across steps. 41 """ 42 def __init__(self, 43 silog_lambda=0.5, 44 weight_silog=0.5, 45 weight_grad=10.0, 46 weight_ssim=5.0, 47 weight_edge_aware=10.0, 48 weight_l1=1.0, 49 weight_var=1.0, 50 weight_range=1.0, 51 weight_blur=1.0): 52 """ 53 Initializes the WeightedCombinedLoss with optional weights for each component. 54 55 Parameter: 56 - silog_lambda (float): 57 SILog lambda parameter. 58 - weight_silog (float): 59 Weight for SILog loss. 60 - weight_grad (float): 61 Weight for gradient L1 loss. 62 - weight_ssim (float): 63 Weight for SSIM loss. 64 - weight_edge_aware (float): 65 Weight for edge-aware loss. 66 - weight_l1 (float): 67 Weight for L1 loss. 68 - weight_var (float): 69 Weight for variance loss. 70 - weight_range (float): 71 Weight for range loss. 72 - weight_blur (float): 73 Weight for blur loss. 
74 """ 75 super().__init__() 76 self.silog_lambda = silog_lambda 77 self.weight_silog = weight_silog 78 self.weight_grad = weight_grad 79 self.weight_ssim = weight_ssim 80 self.weight_edge_aware = weight_edge_aware 81 self.weight_l1 = weight_l1 82 self.weight_var = weight_var 83 self.weight_range = weight_range 84 self.weight_blur = weight_blur 85 86 self.avg_loss_silog = 0 87 self.avg_loss_grad = 0 88 self.avg_loss_ssim = 0 89 self.avg_loss_l1 = 0 90 self.avg_loss_edge_aware = 0 91 self.avg_loss_var = 0 92 self.avg_loss_range = 0 93 self.avg_loss_blur = 0 94 self.steps = 0 95 96 # Instantiate SSIMLoss module 97 self.ssim_module = kornia.losses.SSIMLoss(window_size=11, reduction='mean') 98 # self.ssim_module = kornia.losses.MS_SSIMLoss(reduction='mean') 99 100 101 def silog_loss(self, pred, target, weight_map): 102 eps = 1e-6 103 pred = torch.clamp(pred, min=eps) 104 target = torch.clamp(target, min=eps) 105 106 diff_log = torch.log(target) - torch.log(pred) 107 diff_log = diff_log * weight_map 108 109 loss = torch.sqrt(torch.mean(diff_log ** 2) - 110 self.silog_lambda * torch.mean(diff_log) ** 2) 111 return loss 112 113 def gradient_l1_loss(self, pred, target, weight_map): 114 # Create Channel Dimension 115 if pred.ndim == 3: 116 pred = pred.unsqueeze(1) 117 if target.ndim == 3: 118 target = target.unsqueeze(1) 119 if weight_map.ndim == 3: 120 weight_map = weight_map.unsqueeze(1) 121 122 # Gradient in x-direction (horizontal -> dim=3) 123 pred_grad_x = pred[:, :, :, 1:] - pred[:, :, :, :-1] 124 target_grad_x = target[:, :, :, 1:] - target[:, :, :, :-1] 125 126 # Gradient in y-direction (vertical -> dim=2) 127 pred_grad_y = pred[:, :, 1:, :] - pred[:, :, :-1, :] 128 target_grad_y = target[:, :, 1:, :] - target[:, :, :-1, :] 129 130 weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1] 131 weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :] 132 133 loss_x = torch.mean(torch.abs(pred_grad_x - target_grad_x) * weight_x) 134 loss_y = 
torch.mean(torch.abs(pred_grad_y - target_grad_y) * weight_y) 135 136 # loss_x = F.l1_loss(pred_grad_x, target_grad_x) 137 # loss_y = F.l1_loss(pred_grad_y, target_grad_y) 138 139 return loss_x + loss_y 140 141 def ssim_loss(self, pred, target, weight_map): 142 # SSIM returns similarity, so we subtract from 1 143 if pred.ndim == 3: 144 pred = pred.unsqueeze(1) 145 if target.ndim == 3: 146 target = target.unsqueeze(1) 147 148 # self.ssim_module = self.ssim_module.to(pred.device) 149 return self.ssim_module(pred, target) 150 151 def edge_aware_loss(self, pred, target, weight_map): 152 if pred.ndim == 3: 153 pred = pred.unsqueeze(1) 154 if target.ndim == 3: 155 target = target.unsqueeze(1) 156 if weight_map.ndim == 3: 157 weight_map = weight_map.unsqueeze(1) 158 159 pred_grad_x = pred[:, :, :, :-1] - pred[:, :, :, 1:] 160 pred_grad_y = pred[:, :, :-1, :] - pred[:, :, 1:, :] 161 162 target_grad_x = torch.mean(torch.abs(target[:, :, :, :-1] - target[:, :, :, 1:]), 1, keepdim=True) 163 target_grad_y = torch.mean(torch.abs(target[:, :, :-1, :] - target[:, :, 1:, :]), 1, keepdim=True) 164 165 weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1] 166 weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :] 167 168 pred_grad_x *= torch.exp(-target_grad_x* weight_x) 169 pred_grad_y *= torch.exp(-target_grad_y* weight_y) 170 171 # return (pred_grad_y.abs().mean() + target_grad_y.abs().mean()) 172 return (pred_grad_x.abs().mean() + pred_grad_y.abs().mean()) 173 174 def l1_loss(self, pred, target, weight_map): 175 loss = torch.abs(target - pred) * weight_map 176 return loss.mean() 177 178 def variance_loss(self, pred, target): 179 pred_var = torch.var(pred) 180 target_var = torch.var(target) 181 return F.mse_loss(pred_var, target_var) 182 183 def range_loss(self, pred, target): 184 pred_min, pred_max = torch.min(pred), torch.max(pred) 185 target_min, target_max = torch.min(target), torch.max(target) 186 187 min_loss = F.mse_loss(pred_min, target_min) 188 max_loss = 
F.mse_loss(pred_max, target_max) 189 190 return min_loss + max_loss 191 192 def blur_loss(self, pred, target): 193 laplacian_kernel = torch.tensor([[[[0, 1, 0], 194 [1, -4, 1], 195 [0, 1, 0]]]], dtype=pred.dtype, device=pred.device) 196 197 if pred.ndim == 3: 198 pred = pred.unsqueeze(1) 199 if target.ndim == 3: 200 target = target.unsqueeze(1) 201 202 pred_lap = F.conv2d(pred, laplacian_kernel, padding=1) 203 target_lap = F.conv2d(target, laplacian_kernel, padding=1) 204 205 return F.l1_loss(pred_lap, target_lap) 206 207 def blur_loss(self, pred, target): 208 laplacian_kernel = torch.tensor([[[[0, 1, 0], 209 [1, -4, 1], 210 [0, 1, 0]]]], dtype=pred.dtype, device=pred.device) 211 212 if pred.ndim == 3: 213 pred = pred.unsqueeze(1) 214 if target.ndim == 3: 215 target = target.unsqueeze(1) 216 217 pred_lap = F.conv2d(pred, laplacian_kernel, padding=1) 218 target_lap = F.conv2d(target, laplacian_kernel, padding=1) 219 220 return F.l1_loss(pred_lap, target_lap) 221 222 def forward(self, pred, target, weight_map=None, should_calc_weight_map=False): 223 """ 224 Computes the weighted combined loss between prediction and target. 225 226 Parameter: 227 - pred (torch.Tensor): 228 Predicted output tensor. 229 - target (torch.Tensor): 230 Ground truth tensor. 231 - weight_map (torch.Tensor or None): 232 Optional pixel-wise weighting map. 233 - should_calc_weight_map (bool): 234 If True and weight_map is None, calculates a weight map from target. 235 236 Returns: 237 - torch.Tensor: Weighted sum of all losses. 
238 """ 239 if type(weight_map) == type(None): 240 if should_calc_weight_map: 241 weight_map = calc_weight_map(target) 242 else: 243 # no mask/weight-map 244 # FIXME -> right 245 weight_map = torch.ones_like(pred) 246 247 loss_silog = self.silog_loss(pred, target, weight_map) 248 loss_grad = self.gradient_l1_loss(pred, target, weight_map) 249 loss_ssim = self.ssim_loss(pred, target, weight_map) 250 loss_l1 = self.l1_loss(pred, target, weight_map) 251 loss_edge_aware = self.edge_aware_loss(pred, target, weight_map) 252 loss_var = self.variance_loss(pred, target) 253 loss_range = self.range_loss(pred, target) 254 loss_blur = self.blur_loss(pred, target) 255 256 # reset avgs 257 if self.steps > 24: 258 self.step() 259 260 self.avg_loss_silog += loss_silog 261 self.avg_loss_grad += loss_grad 262 self.avg_loss_ssim += loss_ssim 263 self.avg_loss_l1 += loss_l1 264 self.avg_loss_edge_aware += loss_edge_aware 265 self.avg_loss_var += loss_var 266 self.avg_loss_range += loss_range 267 self.avg_loss_blur += loss_blur 268 self.steps += 1 269 270 total_loss = ( 271 self.weight_silog * loss_silog + 272 self.weight_grad * loss_grad + 273 self.weight_ssim * loss_ssim + 274 self.weight_edge_aware * loss_edge_aware + 275 self.weight_l1 * loss_l1 + 276 self.weight_var * loss_var + 277 self.weight_range * loss_range + 278 self.weight_blur * loss_blur 279 ) 280 281 return total_loss 282 283 def step(self, epoch=None): 284 """ 285 Resets the running averages of all tracked losses. 286 """ 287 self.avg_loss_silog = 0 288 self.avg_loss_grad = 0 289 self.avg_loss_ssim = 0 290 self.avg_loss_l1 = 0 291 self.avg_loss_edge_aware = 0 292 self.avg_loss_var = 0 293 self.avg_loss_range = 0 294 self.avg_loss_blur = 0 295 self.steps = 0 296 297 def get_avg_losses(self): 298 """ 299 Returns the running average of all individual losses. 
300 301 Returns: 302 - tuple: (avg_loss_silog, avg_loss_grad, avg_loss_ssim, avg_loss_l1, 303 avg_loss_edge_aware, avg_loss_var, avg_loss_range, avg_loss_blur) 304 """ 305 return (self.avg_loss_silog/self.steps, 306 self.avg_loss_grad/self.steps, 307 self.avg_loss_ssim/self.steps, 308 self.avg_loss_l1/self.steps, 309 self.avg_loss_edge_aware/self.steps, 310 self.avg_loss_var/self.steps, 311 self.avg_loss_range/self.steps, 312 self.avg_loss_blur/self.steps 313 ) 314 315 def get_dict(self): 316 """ 317 Returns a dictionary of average losses and their corresponding weights. 318 319 Returns: 320 - dict: All loss components with their weights. 321 """ 322 loss_silog, loss_grad, loss_ssim, loss_l1, loss_edge_aware, loss_var, loss_range, loss_blur = self.get_avg_losses() 323 return { 324 f"loss_silog": loss_silog, 325 f"loss_grad": loss_grad, 326 f"loss_ssim": loss_ssim, 327 f"loss_L1": loss_l1, 328 f"loss_edge aware": loss_edge_aware, 329 f"loss_var": loss_var, 330 f"loss_range": loss_range, 331 f"loss_blur": loss_blur, 332 f"weight_loss_silog": self.weight_silog, 333 f"weight_loss_grad": self.weight_grad, 334 f"_weight_loss_ssim": self.weight_ssim, 335 f"_weight_loss_L1": self.weight_l1, 336 f"weight_loss_edge_aware": self.weight_edge_aware, 337 f"weight_loss_var": self.weight_var, 338 f"weight_loss_range": self.weight_range, 339 f"weight_loss_blur": self.weight_blur 340 } 341 342def calc_weight_map(target): 343 """ 344 Calculates a per-pixel weighting map for a target tensor based on unique value frequencies. 345 346 Less frequent values are given higher weights to emphasize their contribution in loss computations. 347 348 Parameter: 349 - target (torch.Tensor): 350 Ground truth tensor. 351 352 Returns: 353 - torch.Tensor: Weight map tensor of the same shape as target. 
354 """ 355 values, counts = torch.unique(target.flatten(), return_counts=True) 356 all_counts = counts.sum().float() 357 358 # weight_factor = 2.0 359 # weights = {values[idx].item(): max(torch.exp( ( (1-(counts[idx].item()/all_counts))) *weight_factor), 0.0001) for idx in range(len(values))} 360 361 weights = {values[idx].item(): 255.0/counts[idx].item() for idx in range(len(values))} 362 363 # print(f"Weights:") 364 # for cur_value, cur_counts in list(sorted(weights.items(), key=lambda x:x[0])): 365 # print(' - '+str(round(cur_value, 4))+': '+str(cur_counts.item())) 366 367 weights_map = torch.zeros_like(target, dtype=torch.float) 368 for cur_value in values: 369 cur_value = cur_value.item() 370 weights_map[target == cur_value] = weights[cur_value] 371 372 return weights_map
27class WeightedCombinedLoss(nn.Module): 28 """ 29 Computes a weighted combination of multiple loss functions for image-to-image tasks. 30 31 Supported losses: 32 - SILog loss 33 - Gradient L1 loss 34 - SSIM loss 35 - Edge-aware loss 36 - L1 loss 37 - Variance loss 38 - Range loss 39 - Blur loss 40 41 The losses can be weighted individually, and average losses are tracked across steps. 42 """ 43 def __init__(self, 44 silog_lambda=0.5, 45 weight_silog=0.5, 46 weight_grad=10.0, 47 weight_ssim=5.0, 48 weight_edge_aware=10.0, 49 weight_l1=1.0, 50 weight_var=1.0, 51 weight_range=1.0, 52 weight_blur=1.0): 53 """ 54 Initializes the WeightedCombinedLoss with optional weights for each component. 55 56 Parameter: 57 - silog_lambda (float): 58 SILog lambda parameter. 59 - weight_silog (float): 60 Weight for SILog loss. 61 - weight_grad (float): 62 Weight for gradient L1 loss. 63 - weight_ssim (float): 64 Weight for SSIM loss. 65 - weight_edge_aware (float): 66 Weight for edge-aware loss. 67 - weight_l1 (float): 68 Weight for L1 loss. 69 - weight_var (float): 70 Weight for variance loss. 71 - weight_range (float): 72 Weight for range loss. 73 - weight_blur (float): 74 Weight for blur loss. 
75 """ 76 super().__init__() 77 self.silog_lambda = silog_lambda 78 self.weight_silog = weight_silog 79 self.weight_grad = weight_grad 80 self.weight_ssim = weight_ssim 81 self.weight_edge_aware = weight_edge_aware 82 self.weight_l1 = weight_l1 83 self.weight_var = weight_var 84 self.weight_range = weight_range 85 self.weight_blur = weight_blur 86 87 self.avg_loss_silog = 0 88 self.avg_loss_grad = 0 89 self.avg_loss_ssim = 0 90 self.avg_loss_l1 = 0 91 self.avg_loss_edge_aware = 0 92 self.avg_loss_var = 0 93 self.avg_loss_range = 0 94 self.avg_loss_blur = 0 95 self.steps = 0 96 97 # Instantiate SSIMLoss module 98 self.ssim_module = kornia.losses.SSIMLoss(window_size=11, reduction='mean') 99 # self.ssim_module = kornia.losses.MS_SSIMLoss(reduction='mean') 100 101 102 def silog_loss(self, pred, target, weight_map): 103 eps = 1e-6 104 pred = torch.clamp(pred, min=eps) 105 target = torch.clamp(target, min=eps) 106 107 diff_log = torch.log(target) - torch.log(pred) 108 diff_log = diff_log * weight_map 109 110 loss = torch.sqrt(torch.mean(diff_log ** 2) - 111 self.silog_lambda * torch.mean(diff_log) ** 2) 112 return loss 113 114 def gradient_l1_loss(self, pred, target, weight_map): 115 # Create Channel Dimension 116 if pred.ndim == 3: 117 pred = pred.unsqueeze(1) 118 if target.ndim == 3: 119 target = target.unsqueeze(1) 120 if weight_map.ndim == 3: 121 weight_map = weight_map.unsqueeze(1) 122 123 # Gradient in x-direction (horizontal -> dim=3) 124 pred_grad_x = pred[:, :, :, 1:] - pred[:, :, :, :-1] 125 target_grad_x = target[:, :, :, 1:] - target[:, :, :, :-1] 126 127 # Gradient in y-direction (vertical -> dim=2) 128 pred_grad_y = pred[:, :, 1:, :] - pred[:, :, :-1, :] 129 target_grad_y = target[:, :, 1:, :] - target[:, :, :-1, :] 130 131 weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1] 132 weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :] 133 134 loss_x = torch.mean(torch.abs(pred_grad_x - target_grad_x) * weight_x) 135 loss_y = 
torch.mean(torch.abs(pred_grad_y - target_grad_y) * weight_y) 136 137 # loss_x = F.l1_loss(pred_grad_x, target_grad_x) 138 # loss_y = F.l1_loss(pred_grad_y, target_grad_y) 139 140 return loss_x + loss_y 141 142 def ssim_loss(self, pred, target, weight_map): 143 # SSIM returns similarity, so we subtract from 1 144 if pred.ndim == 3: 145 pred = pred.unsqueeze(1) 146 if target.ndim == 3: 147 target = target.unsqueeze(1) 148 149 # self.ssim_module = self.ssim_module.to(pred.device) 150 return self.ssim_module(pred, target) 151 152 def edge_aware_loss(self, pred, target, weight_map): 153 if pred.ndim == 3: 154 pred = pred.unsqueeze(1) 155 if target.ndim == 3: 156 target = target.unsqueeze(1) 157 if weight_map.ndim == 3: 158 weight_map = weight_map.unsqueeze(1) 159 160 pred_grad_x = pred[:, :, :, :-1] - pred[:, :, :, 1:] 161 pred_grad_y = pred[:, :, :-1, :] - pred[:, :, 1:, :] 162 163 target_grad_x = torch.mean(torch.abs(target[:, :, :, :-1] - target[:, :, :, 1:]), 1, keepdim=True) 164 target_grad_y = torch.mean(torch.abs(target[:, :, :-1, :] - target[:, :, 1:, :]), 1, keepdim=True) 165 166 weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1] 167 weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :] 168 169 pred_grad_x *= torch.exp(-target_grad_x* weight_x) 170 pred_grad_y *= torch.exp(-target_grad_y* weight_y) 171 172 # return (pred_grad_y.abs().mean() + target_grad_y.abs().mean()) 173 return (pred_grad_x.abs().mean() + pred_grad_y.abs().mean()) 174 175 def l1_loss(self, pred, target, weight_map): 176 loss = torch.abs(target - pred) * weight_map 177 return loss.mean() 178 179 def variance_loss(self, pred, target): 180 pred_var = torch.var(pred) 181 target_var = torch.var(target) 182 return F.mse_loss(pred_var, target_var) 183 184 def range_loss(self, pred, target): 185 pred_min, pred_max = torch.min(pred), torch.max(pred) 186 target_min, target_max = torch.min(target), torch.max(target) 187 188 min_loss = F.mse_loss(pred_min, target_min) 189 max_loss = 
F.mse_loss(pred_max, target_max) 190 191 return min_loss + max_loss 192 193 def blur_loss(self, pred, target): 194 laplacian_kernel = torch.tensor([[[[0, 1, 0], 195 [1, -4, 1], 196 [0, 1, 0]]]], dtype=pred.dtype, device=pred.device) 197 198 if pred.ndim == 3: 199 pred = pred.unsqueeze(1) 200 if target.ndim == 3: 201 target = target.unsqueeze(1) 202 203 pred_lap = F.conv2d(pred, laplacian_kernel, padding=1) 204 target_lap = F.conv2d(target, laplacian_kernel, padding=1) 205 206 return F.l1_loss(pred_lap, target_lap) 207 208 def blur_loss(self, pred, target): 209 laplacian_kernel = torch.tensor([[[[0, 1, 0], 210 [1, -4, 1], 211 [0, 1, 0]]]], dtype=pred.dtype, device=pred.device) 212 213 if pred.ndim == 3: 214 pred = pred.unsqueeze(1) 215 if target.ndim == 3: 216 target = target.unsqueeze(1) 217 218 pred_lap = F.conv2d(pred, laplacian_kernel, padding=1) 219 target_lap = F.conv2d(target, laplacian_kernel, padding=1) 220 221 return F.l1_loss(pred_lap, target_lap) 222 223 def forward(self, pred, target, weight_map=None, should_calc_weight_map=False): 224 """ 225 Computes the weighted combined loss between prediction and target. 226 227 Parameter: 228 - pred (torch.Tensor): 229 Predicted output tensor. 230 - target (torch.Tensor): 231 Ground truth tensor. 232 - weight_map (torch.Tensor or None): 233 Optional pixel-wise weighting map. 234 - should_calc_weight_map (bool): 235 If True and weight_map is None, calculates a weight map from target. 236 237 Returns: 238 - torch.Tensor: Weighted sum of all losses. 
239 """ 240 if type(weight_map) == type(None): 241 if should_calc_weight_map: 242 weight_map = calc_weight_map(target) 243 else: 244 # no mask/weight-map 245 # FIXME -> right 246 weight_map = torch.ones_like(pred) 247 248 loss_silog = self.silog_loss(pred, target, weight_map) 249 loss_grad = self.gradient_l1_loss(pred, target, weight_map) 250 loss_ssim = self.ssim_loss(pred, target, weight_map) 251 loss_l1 = self.l1_loss(pred, target, weight_map) 252 loss_edge_aware = self.edge_aware_loss(pred, target, weight_map) 253 loss_var = self.variance_loss(pred, target) 254 loss_range = self.range_loss(pred, target) 255 loss_blur = self.blur_loss(pred, target) 256 257 # reset avgs 258 if self.steps > 24: 259 self.step() 260 261 self.avg_loss_silog += loss_silog 262 self.avg_loss_grad += loss_grad 263 self.avg_loss_ssim += loss_ssim 264 self.avg_loss_l1 += loss_l1 265 self.avg_loss_edge_aware += loss_edge_aware 266 self.avg_loss_var += loss_var 267 self.avg_loss_range += loss_range 268 self.avg_loss_blur += loss_blur 269 self.steps += 1 270 271 total_loss = ( 272 self.weight_silog * loss_silog + 273 self.weight_grad * loss_grad + 274 self.weight_ssim * loss_ssim + 275 self.weight_edge_aware * loss_edge_aware + 276 self.weight_l1 * loss_l1 + 277 self.weight_var * loss_var + 278 self.weight_range * loss_range + 279 self.weight_blur * loss_blur 280 ) 281 282 return total_loss 283 284 def step(self, epoch=None): 285 """ 286 Resets the running averages of all tracked losses. 287 """ 288 self.avg_loss_silog = 0 289 self.avg_loss_grad = 0 290 self.avg_loss_ssim = 0 291 self.avg_loss_l1 = 0 292 self.avg_loss_edge_aware = 0 293 self.avg_loss_var = 0 294 self.avg_loss_range = 0 295 self.avg_loss_blur = 0 296 self.steps = 0 297 298 def get_avg_losses(self): 299 """ 300 Returns the running average of all individual losses. 
301 302 Returns: 303 - tuple: (avg_loss_silog, avg_loss_grad, avg_loss_ssim, avg_loss_l1, 304 avg_loss_edge_aware, avg_loss_var, avg_loss_range, avg_loss_blur) 305 """ 306 return (self.avg_loss_silog/self.steps, 307 self.avg_loss_grad/self.steps, 308 self.avg_loss_ssim/self.steps, 309 self.avg_loss_l1/self.steps, 310 self.avg_loss_edge_aware/self.steps, 311 self.avg_loss_var/self.steps, 312 self.avg_loss_range/self.steps, 313 self.avg_loss_blur/self.steps 314 ) 315 316 def get_dict(self): 317 """ 318 Returns a dictionary of average losses and their corresponding weights. 319 320 Returns: 321 - dict: All loss components with their weights. 322 """ 323 loss_silog, loss_grad, loss_ssim, loss_l1, loss_edge_aware, loss_var, loss_range, loss_blur = self.get_avg_losses() 324 return { 325 f"loss_silog": loss_silog, 326 f"loss_grad": loss_grad, 327 f"loss_ssim": loss_ssim, 328 f"loss_L1": loss_l1, 329 f"loss_edge aware": loss_edge_aware, 330 f"loss_var": loss_var, 331 f"loss_range": loss_range, 332 f"loss_blur": loss_blur, 333 f"weight_loss_silog": self.weight_silog, 334 f"weight_loss_grad": self.weight_grad, 335 f"_weight_loss_ssim": self.weight_ssim, 336 f"_weight_loss_L1": self.weight_l1, 337 f"weight_loss_edge_aware": self.weight_edge_aware, 338 f"weight_loss_var": self.weight_var, 339 f"weight_loss_range": self.weight_range, 340 f"weight_loss_blur": self.weight_blur 341 }
Computes a weighted combination of multiple loss functions for image-to-image tasks.
Supported losses:
- SILog loss
- Gradient L1 loss
- SSIM loss
- Edge-aware loss
- L1 loss
- Variance loss
- Range loss
- Blur loss
The losses can be weighted individually, and average losses are tracked across steps.
43 def __init__(self, 44 silog_lambda=0.5, 45 weight_silog=0.5, 46 weight_grad=10.0, 47 weight_ssim=5.0, 48 weight_edge_aware=10.0, 49 weight_l1=1.0, 50 weight_var=1.0, 51 weight_range=1.0, 52 weight_blur=1.0): 53 """ 54 Initializes the WeightedCombinedLoss with optional weights for each component. 55 56 Parameter: 57 - silog_lambda (float): 58 SILog lambda parameter. 59 - weight_silog (float): 60 Weight for SILog loss. 61 - weight_grad (float): 62 Weight for gradient L1 loss. 63 - weight_ssim (float): 64 Weight for SSIM loss. 65 - weight_edge_aware (float): 66 Weight for edge-aware loss. 67 - weight_l1 (float): 68 Weight for L1 loss. 69 - weight_var (float): 70 Weight for variance loss. 71 - weight_range (float): 72 Weight for range loss. 73 - weight_blur (float): 74 Weight for blur loss. 75 """ 76 super().__init__() 77 self.silog_lambda = silog_lambda 78 self.weight_silog = weight_silog 79 self.weight_grad = weight_grad 80 self.weight_ssim = weight_ssim 81 self.weight_edge_aware = weight_edge_aware 82 self.weight_l1 = weight_l1 83 self.weight_var = weight_var 84 self.weight_range = weight_range 85 self.weight_blur = weight_blur 86 87 self.avg_loss_silog = 0 88 self.avg_loss_grad = 0 89 self.avg_loss_ssim = 0 90 self.avg_loss_l1 = 0 91 self.avg_loss_edge_aware = 0 92 self.avg_loss_var = 0 93 self.avg_loss_range = 0 94 self.avg_loss_blur = 0 95 self.steps = 0 96 97 # Instantiate SSIMLoss module 98 self.ssim_module = kornia.losses.SSIMLoss(window_size=11, reduction='mean') 99 # self.ssim_module = kornia.losses.MS_SSIMLoss(reduction='mean')
Initializes the WeightedCombinedLoss with optional weights for each component.
Parameter:
- silog_lambda (float): SILog lambda parameter.
- weight_silog (float): Weight for SILog loss.
- weight_grad (float): Weight for gradient L1 loss.
- weight_ssim (float): Weight for SSIM loss.
- weight_edge_aware (float): Weight for edge-aware loss.
- weight_l1 (float): Weight for L1 loss.
- weight_var (float): Weight for variance loss.
- weight_range (float): Weight for range loss.
- weight_blur (float): Weight for blur loss.
102 def silog_loss(self, pred, target, weight_map): 103 eps = 1e-6 104 pred = torch.clamp(pred, min=eps) 105 target = torch.clamp(target, min=eps) 106 107 diff_log = torch.log(target) - torch.log(pred) 108 diff_log = diff_log * weight_map 109 110 loss = torch.sqrt(torch.mean(diff_log ** 2) - 111 self.silog_lambda * torch.mean(diff_log) ** 2) 112 return loss
114 def gradient_l1_loss(self, pred, target, weight_map): 115 # Create Channel Dimension 116 if pred.ndim == 3: 117 pred = pred.unsqueeze(1) 118 if target.ndim == 3: 119 target = target.unsqueeze(1) 120 if weight_map.ndim == 3: 121 weight_map = weight_map.unsqueeze(1) 122 123 # Gradient in x-direction (horizontal -> dim=3) 124 pred_grad_x = pred[:, :, :, 1:] - pred[:, :, :, :-1] 125 target_grad_x = target[:, :, :, 1:] - target[:, :, :, :-1] 126 127 # Gradient in y-direction (vertical -> dim=2) 128 pred_grad_y = pred[:, :, 1:, :] - pred[:, :, :-1, :] 129 target_grad_y = target[:, :, 1:, :] - target[:, :, :-1, :] 130 131 weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1] 132 weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :] 133 134 loss_x = torch.mean(torch.abs(pred_grad_x - target_grad_x) * weight_x) 135 loss_y = torch.mean(torch.abs(pred_grad_y - target_grad_y) * weight_y) 136 137 # loss_x = F.l1_loss(pred_grad_x, target_grad_x) 138 # loss_y = F.l1_loss(pred_grad_y, target_grad_y) 139 140 return loss_x + loss_y
142 def ssim_loss(self, pred, target, weight_map): 143 # SSIM returns similarity, so we subtract from 1 144 if pred.ndim == 3: 145 pred = pred.unsqueeze(1) 146 if target.ndim == 3: 147 target = target.unsqueeze(1) 148 149 # self.ssim_module = self.ssim_module.to(pred.device) 150 return self.ssim_module(pred, target)
152 def edge_aware_loss(self, pred, target, weight_map): 153 if pred.ndim == 3: 154 pred = pred.unsqueeze(1) 155 if target.ndim == 3: 156 target = target.unsqueeze(1) 157 if weight_map.ndim == 3: 158 weight_map = weight_map.unsqueeze(1) 159 160 pred_grad_x = pred[:, :, :, :-1] - pred[:, :, :, 1:] 161 pred_grad_y = pred[:, :, :-1, :] - pred[:, :, 1:, :] 162 163 target_grad_x = torch.mean(torch.abs(target[:, :, :, :-1] - target[:, :, :, 1:]), 1, keepdim=True) 164 target_grad_y = torch.mean(torch.abs(target[:, :, :-1, :] - target[:, :, 1:, :]), 1, keepdim=True) 165 166 weight_x = weight_map[:, :, :, 1:] * weight_map[:, :, :, :-1] 167 weight_y = weight_map[:, :, 1:, :] * weight_map[:, :, :-1, :] 168 169 pred_grad_x *= torch.exp(-target_grad_x* weight_x) 170 pred_grad_y *= torch.exp(-target_grad_y* weight_y) 171 172 # return (pred_grad_y.abs().mean() + target_grad_y.abs().mean()) 173 return (pred_grad_x.abs().mean() + pred_grad_y.abs().mean())
208 def blur_loss(self, pred, target): 209 laplacian_kernel = torch.tensor([[[[0, 1, 0], 210 [1, -4, 1], 211 [0, 1, 0]]]], dtype=pred.dtype, device=pred.device) 212 213 if pred.ndim == 3: 214 pred = pred.unsqueeze(1) 215 if target.ndim == 3: 216 target = target.unsqueeze(1) 217 218 pred_lap = F.conv2d(pred, laplacian_kernel, padding=1) 219 target_lap = F.conv2d(target, laplacian_kernel, padding=1) 220 221 return F.l1_loss(pred_lap, target_lap)
223 def forward(self, pred, target, weight_map=None, should_calc_weight_map=False): 224 """ 225 Computes the weighted combined loss between prediction and target. 226 227 Parameter: 228 - pred (torch.Tensor): 229 Predicted output tensor. 230 - target (torch.Tensor): 231 Ground truth tensor. 232 - weight_map (torch.Tensor or None): 233 Optional pixel-wise weighting map. 234 - should_calc_weight_map (bool): 235 If True and weight_map is None, calculates a weight map from target. 236 237 Returns: 238 - torch.Tensor: Weighted sum of all losses. 239 """ 240 if type(weight_map) == type(None): 241 if should_calc_weight_map: 242 weight_map = calc_weight_map(target) 243 else: 244 # no mask/weight-map 245 # FIXME -> right 246 weight_map = torch.ones_like(pred) 247 248 loss_silog = self.silog_loss(pred, target, weight_map) 249 loss_grad = self.gradient_l1_loss(pred, target, weight_map) 250 loss_ssim = self.ssim_loss(pred, target, weight_map) 251 loss_l1 = self.l1_loss(pred, target, weight_map) 252 loss_edge_aware = self.edge_aware_loss(pred, target, weight_map) 253 loss_var = self.variance_loss(pred, target) 254 loss_range = self.range_loss(pred, target) 255 loss_blur = self.blur_loss(pred, target) 256 257 # reset avgs 258 if self.steps > 24: 259 self.step() 260 261 self.avg_loss_silog += loss_silog 262 self.avg_loss_grad += loss_grad 263 self.avg_loss_ssim += loss_ssim 264 self.avg_loss_l1 += loss_l1 265 self.avg_loss_edge_aware += loss_edge_aware 266 self.avg_loss_var += loss_var 267 self.avg_loss_range += loss_range 268 self.avg_loss_blur += loss_blur 269 self.steps += 1 270 271 total_loss = ( 272 self.weight_silog * loss_silog + 273 self.weight_grad * loss_grad + 274 self.weight_ssim * loss_ssim + 275 self.weight_edge_aware * loss_edge_aware + 276 self.weight_l1 * loss_l1 + 277 self.weight_var * loss_var + 278 self.weight_range * loss_range + 279 self.weight_blur * loss_blur 280 ) 281 282 return total_loss
Computes the weighted combined loss between prediction and target.
Parameter:
- pred (torch.Tensor): Predicted output tensor.
- target (torch.Tensor): Ground truth tensor.
- weight_map (torch.Tensor or None): Optional pixel-wise weighting map.
- should_calc_weight_map (bool): If True and weight_map is None, calculates a weight map from target.
Returns:
- torch.Tensor: Weighted sum of all losses.
284 def step(self, epoch=None): 285 """ 286 Resets the running averages of all tracked losses. 287 """ 288 self.avg_loss_silog = 0 289 self.avg_loss_grad = 0 290 self.avg_loss_ssim = 0 291 self.avg_loss_l1 = 0 292 self.avg_loss_edge_aware = 0 293 self.avg_loss_var = 0 294 self.avg_loss_range = 0 295 self.avg_loss_blur = 0 296 self.steps = 0
Resets the running averages of all tracked losses.
298 def get_avg_losses(self): 299 """ 300 Returns the running average of all individual losses. 301 302 Returns: 303 - tuple: (avg_loss_silog, avg_loss_grad, avg_loss_ssim, avg_loss_l1, 304 avg_loss_edge_aware, avg_loss_var, avg_loss_range, avg_loss_blur) 305 """ 306 return (self.avg_loss_silog/self.steps, 307 self.avg_loss_grad/self.steps, 308 self.avg_loss_ssim/self.steps, 309 self.avg_loss_l1/self.steps, 310 self.avg_loss_edge_aware/self.steps, 311 self.avg_loss_var/self.steps, 312 self.avg_loss_range/self.steps, 313 self.avg_loss_blur/self.steps 314 )
Returns the running average of all individual losses.
Returns:
- tuple: (avg_loss_silog, avg_loss_grad, avg_loss_ssim, avg_loss_l1, avg_loss_edge_aware, avg_loss_var, avg_loss_range, avg_loss_blur)
316 def get_dict(self): 317 """ 318 Returns a dictionary of average losses and their corresponding weights. 319 320 Returns: 321 - dict: All loss components with their weights. 322 """ 323 loss_silog, loss_grad, loss_ssim, loss_l1, loss_edge_aware, loss_var, loss_range, loss_blur = self.get_avg_losses() 324 return { 325 f"loss_silog": loss_silog, 326 f"loss_grad": loss_grad, 327 f"loss_ssim": loss_ssim, 328 f"loss_L1": loss_l1, 329 f"loss_edge aware": loss_edge_aware, 330 f"loss_var": loss_var, 331 f"loss_range": loss_range, 332 f"loss_blur": loss_blur, 333 f"weight_loss_silog": self.weight_silog, 334 f"weight_loss_grad": self.weight_grad, 335 f"_weight_loss_ssim": self.weight_ssim, 336 f"_weight_loss_L1": self.weight_l1, 337 f"weight_loss_edge_aware": self.weight_edge_aware, 338 f"weight_loss_var": self.weight_var, 339 f"weight_loss_range": self.weight_range, 340 f"weight_loss_blur": self.weight_blur 341 }
Returns a dictionary of average losses and their corresponding weights.
Returns:
- dict: All loss components with their weights.
def calc_weight_map(target, scale=255.0):
    """
    Calculates a per-pixel weighting map for a target tensor based on unique value frequencies.

    Less frequent values are given higher weights (weight = scale / count) to
    emphasize their contribution in loss computations.

    Parameter:
    - target (torch.Tensor):
        Ground truth tensor.
    - scale (float):
        Numerator of the inverse-frequency weight (default 255.0, the
        previously hard-coded constant; kept for backward compatibility).

    Returns:
    - torch.Tensor: Float weight map tensor of the same shape as target.
    """
    # return_inverse yields, for every element, the index of its unique
    # value: a single vectorized gather replaces the original O(V*N)
    # per-value boolean-mask loop.
    _, inverse, counts = torch.unique(
        target.flatten(), return_inverse=True, return_counts=True
    )
    weights = scale / counts.float()
    return weights[inverse].reshape(target.shape)
Calculates a per-pixel weighting map for a target tensor based on unique value frequencies.
Less frequent values are given higher weights to emphasize their contribution in loss computations.
Parameter:
- target (torch.Tensor): Ground truth tensor.
Returns:
- torch.Tensor: Weight map tensor of the same shape as target.