image_to_image.models.residual_design_model
Module to define a Residual-Design Model.
A model which consists of two submodels.
The data should be split into two parts (sub-problems), for example in the Physgen dataset: the base propagation and the complex propagation.
Classes:
- CombineNet
- ResidualDesignModel
By Tobia Ippolito
```python
"""
Module to define a Residual-Design Model.
A model which consists of two submodels.

The data should be split into two parts (sub-problems),
for example in the Physgen dataset: the base propagation and
the complex propagation.

Classes:
- CombineNet
- ResidualDesignModel

By Tobia Ippolito
"""
# ---------------------------
# > Imports <
# ---------------------------
import torch
import torch.nn as nn
import torch.optim as optim


# ---------------------------
# > Helper <
# ---------------------------
class CombineNet(nn.Module):
    """
    Helper network for combining two images, namely the outputs of the
    two submodels of the ResidualDesignModel.

    The CombineNet is a lightweight convolutional neural network that takes
    two input tensors, concatenates them along the channel dimension, and
    learns to predict a combined output. It can be used for post-processing,
    fusion of multiple model outputs, or blending of different feature spaces.

    Architecture Overview:
    - 3 convolutional layers (3x3, stride 1, padding 1); the first two are
      followed by batch normalization and GELU activation.
    - Final sigmoid activation to normalize outputs to the range [0, 1].
    - Optimized with Adam on the L1 loss (Mean Absolute Error).
    """
    def __init__(self, input_channels=1, output_channels=1, hidden_channels=32):
        """
        Init of the CombineNet model.

        Parameter:
        - input_channels (int):
            Total number of channels of the concatenated input, i.e. the
            channels of the first input plus those of the second
            (e.g., 2 for two grayscale images).
        - output_channels (int):
            Number of output channels of the combined result.
        - hidden_channels (int):
            Number of feature channels in the hidden layers (internal representation).
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),

            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),

            nn.Conv2d(hidden_channels, output_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        self.loss = torch.nn.L1Loss()
        self.last_loss = float("inf")
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.5, 0.999))

    def forward(self, x, y):
        """
        Forward pass of the CombineNet.

        Parameter:
        - x (torch.tensor):
            First input tensor of shape (batch_size, channels_x, height, width).
        - y (torch.tensor):
            Second input tensor of shape (batch_size, channels_y, height, width),
            where channels_x + channels_y equals input_channels.

        Returns:
        - torch.tensor:
            Combined output tensor of shape (batch_size, output_channels, height, width).
        """
        # Concatenate both inputs along the channel dimension, then fuse them
        return self.model(torch.cat([x, y], dim=1))

    def backward(self, y_base, y_complex, y):
        """
        Backward pass (training step) for the CombineNet.

        Parameter:
        - y_base (torch.tensor):
            First input tensor (e.g., base model output).
        - y_complex (torch.tensor):
            Second input tensor (e.g., complex or refined prediction).
        - y (torch.tensor):
            Ground truth tensor (target output for supervision).

        Returns:
        - float:
            The scalar loss value (L1 loss) from the current optimization step.
        """
        self.optimizer.zero_grad()
        y_pred = self.forward(y_base, y_complex)
        loss = self.loss(y_pred, y)  # L1Loss expects (input, target)
        loss.backward()
        self.last_loss = loss.item()
        self.optimizer.step()
        return self.last_loss


# ---------------------------
# > Residual Design Model <
# ---------------------------
class ResidualDesignModel(nn.Module):
    """
    Residual Design Model for combining predictions from a base and a complex model.

    The ResidualDesignModel enables two modes of combination:
    1. **Mathematical Residual (`math`)**:
        - Computes a weighted sum of the base and complex model outputs:
          y = y_base + alpha * y_complex.
        - The weight `alpha` is learnable and optimized via L1 loss.
    2. **Neural Network Fusion (`nn`)**:
        - Uses a small CNN (`CombineNet`) to learn a nonlinear combination of the outputs.
    """
    def __init__(self,
                 base_model: nn.Module,
                 complex_model: nn.Module,
                 combine_mode="math"):  # "math" or "nn"
        """
        Init of the ResidualDesignModel.

        Parameter:
        - base_model (nn.Module): Pretrained or instantiated base model.
        - complex_model (nn.Module): Pretrained or instantiated complex model.
          Both submodels must provide get_input_channels() and get_output_channels().
        - combine_mode (str, default='math'): Mode for combining outputs. Options:
            - 'math': Weighted residual combination with learnable alpha.
            - 'nn': Nonlinear fusion using a small CombineNet
              (any value other than 'math' selects this path).
        """
        super().__init__()

        self.base_model = base_model
        self.complex_model = complex_model
        self.combine_mode = combine_mode

        # One input per submodel, so the input channels are kept as a tuple
        self.input_channels = (self.base_model.get_input_channels(), self.complex_model.get_input_channels())
        self.output_channels = min(self.base_model.get_output_channels(), self.complex_model.get_output_channels())

        # The CombineNet sees both outputs concatenated, so its input channels are the sum
        self.combine_net = CombineNet(input_channels=self.base_model.get_output_channels() + self.complex_model.get_output_channels(),
                                      output_channels=self.output_channels, hidden_channels=32)

        self.alpha = nn.Parameter(torch.tensor(0.5))
        self.alpha_optimizer = optim.Adam([self.alpha], lr=1e-5)
        self.alpha_criterion = nn.L1Loss()

        self.last_base_loss = float('nan')
        self.last_complex_loss = float('nan')
        self.last_combined_loss = float('nan')
        self.last_combined_math_loss = float('nan')

    def get_input_channels(self):
        """
        Returns a tuple with the input channels of the base and complex models.

        Returns:
        - tuple: (base_model_input_channels, complex_model_input_channels)
        """
        return self.input_channels

    def get_output_channels(self):
        """
        Returns the number of output channels for the combined prediction.

        Returns:
        - int: Minimum of base_model and complex_model output channels.
        """
        return self.output_channels

    # Note: a second, single-argument forward(self, x) used to be defined here.
    # It was silently overridden by the definition below and never returned a
    # value, so it has been removed.
    def forward(self, x_base, x_complex):
        """
        Forward pass of the ResidualDesignModel.

        Parameter:
        - x_base (torch.tensor): Input tensor for the base model.
        - x_complex (torch.tensor): Input tensor for the complex model.

        Returns:
        - torch.tensor: Combined prediction, either via the weighted residual
          ('math') or the CombineNet (any other mode). In 'math' mode a 4-D
          result is squeezed along dim 1, so (B, 1, H, W) becomes (B, H, W).
        """
        y_base = self.base_model(x_base)
        y_complex = self.complex_model(x_complex)

        if self.combine_mode == 'math':
            # y = (y_complex*(-0.5)) + y_base
            y = y_base + self.alpha * y_complex
            if len(y.shape) == 4:
                # squeeze(1) only removes the channel axis if it has size 1
                y = y.squeeze(1)
            return y
        else:
            return self.combine_net(y_base, y_complex)

    def backward(self, y_base, y_complex, y):
        """
        Training step that optimizes only the alpha parameter of the
        mathematical residual combination: one Adam step on the L1 loss
        between y_base + alpha * y_complex and y. The loss is stored in
        last_combined_math_loss; nothing is returned.

        Parameter:
        - y_base (torch.tensor):
            Output of the base model.
        - y_complex (torch.tensor):
            Output of the complex model.
        - y (torch.tensor):
            Ground truth tensor.
        """
        self.alpha_optimizer.zero_grad()
        y_pred = y_base + self.alpha * y_complex
        combine_loss = self.alpha_criterion(y_pred, y)
        combine_loss.backward()
        self.last_combined_math_loss = combine_loss.item()
        self.alpha_optimizer.step()

    def get_dict(self):
        """
        Returns a dictionary with the most recent loss values.

        Returns:
        - dict: Most recent losses under the keys loss_base, loss_complex,
          loss_combined_net (CombineNet) and loss_combined_math (alpha step).
          loss_base and loss_complex are initialized to NaN and are not
          updated by this class itself.

        Notes:
        - Useful for logging or monitoring training progress.
        """
        return {
            "loss_base": self.last_base_loss,
            "loss_complex": self.last_complex_loss,
            "loss_combined_net": self.combine_net.last_loss,
            "loss_combined_math": self.last_combined_math_loss
        }
```
class CombineNet(nn.Module)
Helper network for combining two images, namely the outputs of the two submodels of the ResidualDesignModel.
The CombineNet is a lightweight convolutional neural network that takes two input tensors, concatenates them along the channel dimension, and learns to predict a combined output. It can be used for post-processing, fusion of multiple model outputs, or blending of different feature spaces.
Architecture Overview:
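- 3 convolutional layers (3x3 kernels, stride 1, padding 1); the first two are followed by batch normalization and GELU activation.
- Final sigmoid activation to normalize outputs to the range [0, 1].
- Optimized with Adam (lr=0.0001, betas=(0.5, 0.999)) on the L1 loss (Mean Absolute Error).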
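The layer stack described above can be rebuilt in a few lines to check the shapes it produces. This is a minimal sketch that reproduces the same layers as a standalone nn.Sequential; the helper name build_combine_layers is hypothetical (not part of the module), and the inputs are assumed to be two single-channel 64×64 images.

```python
import torch
import torch.nn as nn


def build_combine_layers(input_channels: int, output_channels: int, hidden_channels: int = 32) -> nn.Sequential:
    """Hypothetical helper mirroring the CombineNet layer stack."""
    return nn.Sequential(
        # 3x3 convs with stride=1 and padding=1 keep height and width unchanged
        nn.Conv2d(input_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(hidden_channels),
        nn.GELU(),
        nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(hidden_channels),
        nn.GELU(),
        # Last conv has no batch norm; the sigmoid maps every value into [0, 1]
        nn.Conv2d(hidden_channels, output_channels, kernel_size=3, stride=1, padding=1),
        nn.Sigmoid(),
    )


torch.manual_seed(0)
y_base = torch.randn(4, 1, 64, 64)     # assumed base-model output (B, 1, H, W)
y_complex = torch.randn(4, 1, 64, 64)  # assumed complex-model output (B, 1, H, W)

layers = build_combine_layers(input_channels=2, output_channels=1)
# Channel-wise concatenation as in CombineNet.forward: (4, 1 + 1, 64, 64)
fused = layers(torch.cat([y_base, y_complex], dim=1))
print(fused.shape)  # torch.Size([4, 1, 64, 64])
```

Because every convolution preserves the spatial size, only the channel count changes between input and output.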
def __init__(self, input_channels=1, output_channels=1, hidden_channels=32)
Init of the CombineNet model.
Parameter:
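- input_channels (int): Total number of channels of the concatenated input, i.e. the channels of the first input plus those of the second (e.g., 2 for two grayscale images). ResidualDesignModel passes this sum explicitly.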
- output_channels (int): Number of output channels of the combined result.
- hidden_channels (int): Number of feature channels in the hidden layers (internal representation).
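Because the two inputs are concatenated before the first convolution, input_channels has to equal the sum of both inputs' channel counts. The sketch below assumes a single-channel first input and a three-channel second input (illustrative shapes, not taken from the source) and shows that a first layer sized for one input only is rejected by PyTorch, while the summed value works.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 1, 16, 16)      # first input: 1 channel
y = torch.randn(2, 3, 16, 16)      # second input: 3 channels
merged = torch.cat([x, y], dim=1)  # channels add up: 1 + 3 = 4

# First layer sized for the summed channel count: works
first_conv_ok = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=3, padding=1)
out = first_conv_ok(merged)

# First layer sized for a single input: PyTorch raises a RuntimeError
first_conv_bad = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
try:
    first_conv_bad(merged)
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True

print(merged.shape[1], out.shape, mismatch_detected)
```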
def forward(self, x, y)
Forward pass of the CombineNet.
Parameter:
- x (torch.tensor): First input tensor of shape (batch_size, channels_x, height, width).
- y (torch.tensor): Second input tensor of shape (batch_size, channels_y, height, width), where channels_x + channels_y equals input_channels.
Returns:
- torch.tensor: Combined output tensor of shape (batch_size, output_channels, height, width).
def backward(self, y_base, y_complex, y)
Backward pass (training step) for the CombineNet.
Parameter:
- y_base (torch.tensor): First input tensor (e.g., base model output).
- y_complex (torch.tensor): Second input tensor (e.g., complex or refined prediction).
- y (torch.tensor): Ground truth tensor (target output for supervision).
Returns:
- float: The scalar loss value (L1 loss) from the current optimization step.
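The backward method bundles one complete optimization step: zero the gradients, run the forward pass, compute the L1 loss, backpropagate, and take an Adam step. The sketch below reproduces that pattern on a tiny stand-in network with random tensors. The one-layer model and the data are placeholders (assumptions, not the module's code); only the optimizer settings mirror the source (Adam, lr=1e-4, betas=(0.5, 0.999)).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder fusion network: one 3x3 conv over the 2 concatenated channels
net = nn.Sequential(nn.Conv2d(2, 1, kernel_size=3, padding=1), nn.Sigmoid())
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.5, 0.999))

y_base = torch.rand(4, 1, 32, 32)     # stand-in for the base-model output
y_complex = torch.rand(4, 1, 32, 32)  # stand-in for the complex-model output
target = torch.rand(4, 1, 32, 32)     # stand-in ground truth


def train_step() -> float:
    optimizer.zero_grad()                                # clear old gradients
    y_pred = net(torch.cat([y_base, y_complex], dim=1))  # fuse channel-wise
    loss = criterion(y_pred, target)                     # L1Loss(input, target)
    loss.backward()                                      # inputs need no grad, so only net gets gradients
    optimizer.step()                                     # one Adam update
    return loss.item()


weights_before = net[0].weight.detach().clone()
first_loss = train_step()
weights_changed = not torch.equal(weights_before, net[0].weight)
print(first_loss, weights_changed)
```

Returning loss.item() (a Python float) rather than the tensor avoids keeping the autograd graph alive between steps, which matches how the source stores last_loss.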
class ResidualDesignModel(nn.Module)
Residual Design Model for combining predictions from a base and a complex model.
The ResidualDesignModel enables two modes of combination:
1. Mathematical Residual (`math`):
   - Computes a weighted sum of the base and complex model outputs: y = y_base + alpha * y_complex.
   - The weight `alpha` is learnable and optimized via L1 loss.
2. Neural Network Fusion (`nn`):
   - Uses a small CNN (`CombineNet`) to learn a nonlinear combination of the outputs.
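In `math` mode the combination is a single learnable scalar blend, y = y_base + alpha * y_complex, with alpha initialized to 0.5 in the source. A small worked example with constant tensors (values chosen purely for illustration) makes the arithmetic, the broadcasting of the scalar, and the gradient that drives alpha explicit.

```python
import torch
import torch.nn as nn

alpha = nn.Parameter(torch.tensor(0.5))    # same initial value as in the source
y_base = torch.full((1, 1, 2, 2), 1.0)     # illustrative base prediction
y_complex = torch.full((1, 1, 2, 2), 2.0)  # illustrative complex prediction

# The scalar alpha broadcasts over every pixel: 1.0 + 0.5 * 2.0 = 2.0
y = y_base + alpha * y_complex
print(y.squeeze(1))  # every entry equals 2.0; squeeze(1) drops the size-1 channel axis

# d(sum y)/d(alpha) is the sum of y_complex: 4 pixels * 2.0 = 8.0
y.sum().backward()
print(alpha.grad)  # tensor(8.)
```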
def __init__(self, base_model: nn.Module, complex_model: nn.Module, combine_mode="math")
Init of the ResidualDesignModel.
Parameter:
- base_model (nn.Module): Pretrained or instantiated base model. Must provide get_input_channels() and get_output_channels().
- complex_model (nn.Module): Pretrained or instantiated complex model, with the same two methods.
- combine_mode (str, default='math'): Mode for combining outputs. Options:
- 'math': Weighted residual combination with learnable alpha.
- 'nn': Nonlinear fusion using a small CombineNet (any value other than 'math' selects this path).
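The constructor calls get_input_channels() and get_output_channels() on both submodels and derives its own channel bookkeeping from them: the model's output channel count is the minimum of the two submodel outputs, while the CombineNet receives their sum. The sketch below uses a hypothetical TinySubmodel class (not part of the source) that satisfies this interface, and recomputes the derived values the same way the constructor does.

```python
import torch.nn as nn


class TinySubmodel(nn.Module):
    """Hypothetical submodel exposing the channel getters ResidualDesignModel expects."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

    def get_input_channels(self) -> int:
        return self.in_ch

    def get_output_channels(self) -> int:
        return self.out_ch


base = TinySubmodel(in_ch=3, out_ch=1)
complex_ = TinySubmodel(in_ch=5, out_ch=1)  # trailing underscore avoids shadowing the builtin complex

# Same derivations as in ResidualDesignModel.__init__
input_channels = (base.get_input_channels(), complex_.get_input_channels())
output_channels = min(base.get_output_channels(), complex_.get_output_channels())
combine_in = base.get_output_channels() + complex_.get_output_channels()
print(input_channels, output_channels, combine_in)  # (3, 5) 1 2
```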
def get_input_channels(self)
Returns a tuple with the input channels of the base and complex models.
Returns:
- tuple: (base_model_input_channels, complex_model_input_channels)
def get_output_channels(self)
Returns the number of output channels for the combined prediction.
Returns:
- int: Minimum of base_model and complex_model output channels.
def forward(self, x_base, x_complex)
Forward pass of the ResidualDesignModel.
Parameter:
- x_base (torch.tensor): Input tensor for the base model.
- x_complex (torch.tensor): Input tensor for the complex model.
Returns:
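- torch.tensor: Combined prediction, either via the weighted residual ('math') or the CombineNet (any other mode). In 'math' mode a 4-D result is squeezed along dim 1, so a single-channel output of shape (B, 1, H, W) is returned as (B, H, W); the CombineNet path returns (B, output_channels, H, W).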
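The two modes do not return the same shape, so a caller that switches combine_mode may need to reshape the ground truth accordingly. The sketch reproduces only the branch at the end of forward, using random stand-in tensors and a one-layer stand-in for the CombineNet; the function name combine is hypothetical.

```python
import torch
import torch.nn as nn

# Stand-in for CombineNet: fuses 2 concatenated channels into 1
fusion = nn.Sequential(nn.Conv2d(2, 1, kernel_size=3, padding=1), nn.Sigmoid())
alpha = torch.tensor(0.5)


def combine(y_base: torch.Tensor, y_complex: torch.Tensor, mode: str) -> torch.Tensor:
    """Hypothetical re-implementation of the combination branch of ResidualDesignModel.forward."""
    if mode == "math":
        y = y_base + alpha * y_complex
        if len(y.shape) == 4:
            y = y.squeeze(1)  # removes dim 1 only when it has size 1
        return y
    # Any other mode falls through to the CNN fusion
    return fusion(torch.cat([y_base, y_complex], dim=1))


y_base = torch.rand(2, 1, 8, 8)
y_complex = torch.rand(2, 1, 8, 8)
math_out = combine(y_base, y_complex, "math")
nn_out = combine(y_base, y_complex, "nn")
print(math_out.shape, nn_out.shape)  # torch.Size([2, 8, 8]) torch.Size([2, 1, 8, 8])
```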
def backward(self, y_base, y_complex, y)
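Training step that optimizes only the alpha parameter of the mathematical residual combination: one Adam step on the L1 loss between y_base + alpha * y_complex and y. The loss is stored in last_combined_math_loss; nothing is returned.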
Parameter:
- y_base (torch.tensor): Output of the base model.
- y_complex (torch.tensor): Output of the complex model.
- y (torch.tensor): Ground truth tensor.
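Repeating this alpha-only update shows how the blend weight is fitted. The sketch below assumes synthetic data whose ideal weight is 0.8, and it uses lr=1e-2 instead of the source's 1e-5 purely so that alpha visibly moves within a few hundred steps. In a real training loop the submodel outputs would carry an autograd graph; calling .detach() on them (a no-op here on plain random tensors) keeps this loss from backpropagating into the submodels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
y_base = torch.rand(8, 1, 16, 16).detach()     # fixed stand-in base output
y_complex = torch.rand(8, 1, 16, 16).detach()  # fixed stand-in complex output
y_true = y_base + 0.8 * y_complex              # synthetic target: ideal alpha is 0.8

alpha = nn.Parameter(torch.tensor(0.5))               # initial value from the source
alpha_optimizer = torch.optim.Adam([alpha], lr=1e-2)  # larger lr than the source, for illustration
alpha_criterion = nn.L1Loss()

for _ in range(300):
    alpha_optimizer.zero_grad()
    y_pred = y_base + alpha * y_complex
    loss = alpha_criterion(y_pred, y_true)
    loss.backward()
    alpha_optimizer.step()

print(round(alpha.item(), 2))
```

With the source's learning rate of 1e-5, each Adam step moves alpha by roughly that amount, so alpha changes only slowly over training.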
def get_dict(self)
Returns a dictionary with the most recent loss values.
Returns:
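- dict: Most recent losses under the keys loss_base, loss_complex, loss_combined_net (CombineNet) and loss_combined_math (alpha step). loss_base and loss_complex are initialized to NaN and are not updated by this class itself.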
Notes:
- Useful for logging or monitoring training progress.
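Until the corresponding steps run, several entries are not finite: the base, complex and math losses start as NaN and the CombineNet loss starts as infinity. A logger may therefore want to skip non-finite values. This is a minimal sketch with a hypothetical format_losses helper and a hand-written dictionary that uses the same keys as get_dict (the values are illustrative).

```python
import math


def format_losses(losses: dict) -> str:
    """Hypothetical helper: render only the finite loss values, sorted by key."""
    parts = [
        f"{name}={value:.4f}"
        for name, value in sorted(losses.items())
        if math.isfinite(value)  # skips both NaN and +/-inf
    ]
    return " | ".join(parts)


# Same keys as ResidualDesignModel.get_dict(); values are illustrative
losses = {
    "loss_base": float("nan"),
    "loss_complex": float("nan"),
    "loss_combined_net": float("inf"),
    "loss_combined_math": 0.0312,
}
print(format_losses(losses))  # loss_combined_math=0.0312
```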