image_to_image.models.residual_design_model

Module to define a Residual-Design Model: a model that consists of two submodels.

The data is (or should be) split into two parts (sub-problems); for example, in the Physgen dataset, the base propagation and the complex propagation.

Classes:

  • CombineNet
  • ResidualDesignModel

By Tobia Ippolito
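
A minimal usage sketch (the TinyModel submodel below is hypothetical; any nn.Module that provides get_input_channels() and get_output_channels(), as ResidualDesignModel requires, works in its place):

    import torch
    import torch.nn as nn

    from image_to_image.models.residual_design_model import ResidualDesignModel

    # Hypothetical stand-in for a real base/complex submodel.
    class TinyModel(nn.Module):
        def __init__(self, in_ch=1, out_ch=1):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            self.in_ch, self.out_ch = in_ch, out_ch

        def forward(self, x):
            return self.conv(x)

        def get_input_channels(self):
            return self.in_ch

        def get_output_channels(self):
            return self.out_ch

    model = ResidualDesignModel(TinyModel(), TinyModel(), combine_mode="math")
    x_base = torch.randn(4, 1, 64, 64)     # input for the base sub-problem
    x_complex = torch.randn(4, 1, 64, 64)  # input for the complex sub-problem
    y = model(x_base, x_complex)           # (4, 64, 64): math mode squeezes the channel dim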

"""
Module to define a Residual-Design Model:
a model that consists of two submodels.

The data is (or should be) split into two parts (sub-problems);
for example, in the Physgen dataset, the base propagation and
the complex propagation.

Classes:
- CombineNet
- ResidualDesignModel

By Tobia Ippolito
"""
# ---------------------------
#        > Imports <
# ---------------------------
import torch
import torch.nn as nn
import torch.optim as optim



# ---------------------------
#         > Helper <
# ---------------------------
class CombineNet(nn.Module):
    """
    Helper network for combining two images, namely the two outputs
    of the submodels of the ResidualDesignModel.

    The CombineNet is a lightweight convolutional neural network that takes
    two input tensors, merges them channel-wise, and learns to predict a
    combined output representation. It can be used for post-processing,
    fusion of multiple model outputs, or blending of different feature spaces.

    Architecture Overview:
    - 3 convolutional layers with batch normalization and GELU activation.
    - Final sigmoid activation to normalize outputs to [0, 1].
    - Optimized using L1 loss (Mean Absolute Error).
    """
    def __init__(self, input_channels=1, output_channels=1, hidden_channels=32):
        """
        Init of the CombineNet model.

        Parameter:
        - input_channels (int):
            Total number of channels of the two concatenated input tensors
            (e.g., 2 for two grayscale inputs, 6 for two RGB inputs).
        - output_channels (int):
            Number of output channels of the combined result.
        - hidden_channels (int):
            Number of feature channels in the hidden layers (internal representation).
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),

            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),

            nn.Conv2d(hidden_channels, output_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        self.loss = nn.L1Loss()
        self.last_loss = float("inf")
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.5, 0.999))


    def forward(self, x, y):
        """
        Forward pass of the CombineNet.

        Parameter:
        - x (torch.tensor):
            First input tensor of shape (batch_size, channels_x, height, width).
        - y (torch.tensor):
            Second input tensor of shape (batch_size, channels_y, height, width),
            where channels_x + channels_y must equal input_channels.

        Returns:
        - torch.tensor:
            Combined output tensor of shape (batch_size, output_channels, height, width).
        """
        return self.model(torch.cat([x, y], dim=1))


    def backward(self, y_base, y_complex, y):
        """
        Backward pass (training step) for the CombineNet.

        Parameter:
        - y_base (torch.tensor):
            First input tensor (e.g., base model output).
        - y_complex (torch.tensor):
            Second input tensor (e.g., complex or refined prediction).
        - y (torch.tensor):
            Ground truth tensor (target output for supervision).

        Returns:
        - float:
            The scalar loss value (L1 loss) from the current optimization step.
        """
        self.optimizer.zero_grad()
        y_pred = self.forward(y_base, y_complex)
        loss = self.loss(y_pred, y)
        loss.backward()
        self.last_loss = loss.item()
        self.optimizer.step()
        return self.last_loss


# ---------------------------
#  > Residual Design Model <
# ---------------------------
class ResidualDesignModel(nn.Module):
    """
    Residual Design Model for combining predictions from a base and a complex model.

    The ResidualDesignModel enables two modes of combination:
    1. **Mathematical Residual (`math`)**:
       - Computes a weighted sum of the base and complex model outputs.
       - The weight `alpha` is learnable and optimized via L1 loss.
    2. **Neural Network Fusion (`nn`)**:
       - Uses a small CNN (`CombineNet`) to learn a nonlinear combination of the outputs.
    """
    def __init__(self,
                 base_model: nn.Module,
                 complex_model: nn.Module,
                 combine_mode="math"):  # math or nn
        """
        Init of the ResidualDesignModel.

        Parameter:
        - base_model (nn.Module): Pretrained or instantiated base model.
        - complex_model (nn.Module): Pretrained or instantiated complex model.
        - combine_mode (str, default='math'): Mode for combining outputs. Options:
            - 'math': Weighted residual combination with learnable alpha.
            - 'nn': Nonlinear fusion using a small CombineNet.
        """
        super().__init__()

        self.base_model = base_model
        self.complex_model = complex_model
        self.combine_mode = combine_mode

        self.input_channels = (self.base_model.get_input_channels(), self.complex_model.get_input_channels())
        self.output_channels = min(self.base_model.get_output_channels(), self.complex_model.get_output_channels())

        self.combine_net = CombineNet(input_channels=self.base_model.get_output_channels() + self.complex_model.get_output_channels(),
                                      output_channels=self.output_channels, hidden_channels=32)

        self.alpha = nn.Parameter(torch.tensor(0.5))
        self.alpha_optimizer = optim.Adam([self.alpha], lr=1e-5)
        self.alpha_criterion = nn.L1Loss()

        self.last_base_loss = float('nan')
        self.last_complex_loss = float('nan')
        self.last_combined_loss = float('nan')
        self.last_combined_math_loss = float('nan')


    def get_input_channels(self):
        """
        Returns a tuple with the input channels of the base and complex models.

        Returns:
        - tuple: (base_model_input_channels, complex_model_input_channels)
        """
        return self.input_channels


    def get_output_channels(self):
        """
        Returns the number of output channels for the combined prediction.

        Returns:
        - int: Minimum of base_model and complex_model output channels.
        """
        return self.output_channels

    def forward(self, x_base, x_complex=None):
        """
        Forward pass of the ResidualDesignModel.

        Parameter:
        - x_base (torch.tensor): Input tensor for the base model,
            or a (x_base, x_complex) pair if x_complex is not given.
        - x_complex (torch.tensor): Input tensor for the complex model.

        Returns:
        - torch.tensor: Combined prediction, either via weighted residual or CombineNet.
        """
        if x_complex is None:
            # Allow a single (x_base, x_complex) pair as input.
            x_base, x_complex = x_base

        y_base = self.base_model(x_base)
        y_complex = self.complex_model(x_complex)

        if self.combine_mode == 'math':
            y = y_base + self.alpha * y_complex
            if len(y.shape) == 4:
                y = y.squeeze(1)
            return y
        else:
            return self.combine_net(y_base, y_complex)

    def backward(self, y_base, y_complex, y):
        """
        Backward pass to optimize the alpha parameter for the mathematical residual combination.

        Parameter:
        - y_base (torch.tensor):
            Output of the base model.
        - y_complex (torch.tensor):
            Output of the complex model.
        - y (torch.tensor):
            Ground truth tensor.
        """
        self.alpha_optimizer.zero_grad()
        y_pred = y_base + self.alpha * y_complex
        combine_loss = self.alpha_criterion(y_pred, y)
        combine_loss.backward()
        self.last_combined_math_loss = combine_loss.item()
        self.alpha_optimizer.step()

    def get_dict(self):
        """
        Returns a dictionary with the most recent loss values.

        Returns:
        - dict: Loss components (base, complex, combined net, combined math).

        Notes:
        - Useful for logging or monitoring training progress.
        """
        return {
                "loss_base": self.last_base_loss,
                "loss_complex": self.last_complex_loss,
                "loss_combined_net": self.combine_net.last_loss,
                "loss_combined_math": self.last_combined_math_loss
               }
class CombineNet(torch.nn.modules.module.Module):

Helper network for combining two images, namely the two outputs of the submodels of the ResidualDesignModel.

The CombineNet is a lightweight convolutional neural network that takes two input tensors, merges them channel-wise, and learns to predict a combined output representation. It can be used for post-processing, fusion of multiple model outputs, or blending of different feature spaces.

Architecture Overview:

  • 3 convolutional layers with batch normalization and GELU activation.
  • Final sigmoid activation to normalize outputs to [0, 1].
  • Optimized using L1 loss (Mean Absolute Error).
CombineNet(input_channels=1, output_channels=1, hidden_channels=32)

Init of the CombineNet model.

Parameter:

  • input_channels (int): Total number of channels of the two concatenated input tensors (e.g., 2 for two grayscale inputs, 6 for two RGB inputs).
  • output_channels (int): Number of output channels of the combined result.
  • hidden_channels (int): Number of feature channels in the hidden layers (internal representation).
Attributes:

  • model
  • loss
  • last_loss
  • optimizer
def forward(self, x, y):

Forward pass of the CombineNet.

Parameter:

  • x (torch.tensor): First input tensor of shape (batch_size, channels_x, height, width).
  • y (torch.tensor): Second input tensor of shape (batch_size, channels_y, height, width), where channels_x + channels_y must equal input_channels.

Returns:

  • torch.tensor: Combined output tensor of shape (batch_size, output_channels, height, width).
def backward(self, y_base, y_complex, y):

Backward pass (training step) for the CombineNet.

Parameter:

  • y_base (torch.tensor): First input tensor (e.g., base model output).
  • y_complex (torch.tensor): Second input tensor (e.g., complex or refined prediction).
  • y (torch.tensor): Ground truth tensor (target output for supervision).

Returns:

  • float: The scalar loss value (L1 loss) from the current optimization step.
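
A minimal standalone training-step sketch (with made-up tensors; note that input_channels must equal the sum of the two inputs' channel counts, and the sigmoid output assumes targets in [0, 1]):

    import torch
    from image_to_image.models.residual_design_model import CombineNet

    net = CombineNet(input_channels=2, output_channels=1, hidden_channels=32)

    y_base = torch.rand(4, 1, 64, 64)     # first submodel output
    y_complex = torch.rand(4, 1, 64, 64)  # second submodel output
    y = torch.rand(4, 1, 64, 64)          # ground truth in [0, 1]

    fused = net(y_base, y_complex)                   # forward: channel-concat, then 3 convs
    loss_value = net.backward(y_base, y_complex, y)  # one Adam step on L1 loss
    print(fused.shape, loss_value)                   # torch.Size([4, 1, 64, 64]), float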
class ResidualDesignModel(torch.nn.modules.module.Module):

Residual Design Model for combining predictions from a base and a complex model.

The ResidualDesignModel enables two modes of combination:

  1. Mathematical Residual (math):
    • Computes a weighted sum of the base and complex model outputs.
    • The weight alpha is learnable and optimized via L1 loss.
  2. Neural Network Fusion (nn):
    • Uses a small CNN (CombineNet) to learn a nonlinear combination of the outputs.
ResidualDesignModel(base_model: torch.nn.modules.module.Module, complex_model: torch.nn.modules.module.Module, combine_mode='math')

Init of the ResidualDesignModel.

Parameter:

  • base_model (nn.Module): Pretrained or instantiated base model.
  • complex_model (nn.Module): Pretrained or instantiated complex model.
  • combine_mode (str, default='math'): Mode for combining outputs. Options:
    • 'math': Weighted residual combination with learnable alpha.
    • 'nn': Nonlinear fusion using a small CombineNet.
Attributes:

  • base_model
  • complex_model
  • combine_mode
  • input_channels
  • output_channels
  • combine_net
  • alpha
  • alpha_optimizer
  • alpha_criterion
  • last_base_loss
  • last_complex_loss
  • last_combined_loss
  • last_combined_math_loss
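
A construction sketch for the two modes (reusing the hypothetical TinyModel from the module example above):

    # 'math' mode: prediction is y_base + alpha * y_complex with learnable alpha.
    model_math = ResidualDesignModel(TinyModel(), TinyModel(), combine_mode="math")

    # 'nn' mode: submodel outputs are fused by the internal CombineNet.
    model_nn = ResidualDesignModel(TinyModel(in_ch=1, out_ch=1),
                                   TinyModel(in_ch=3, out_ch=1),
                                   combine_mode="nn")
    print(model_nn.get_input_channels())   # (1, 3): per-submodel input channels
    print(model_nn.get_output_channels())  # 1: min of the submodel output channels
    # model_nn.combine_net was built with input_channels = 1 + 1
    # (the sum of the submodel output channels).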
def get_input_channels(self):

Returns a tuple with the input channels of the base and complex models.

Returns:

  • tuple: (base_model_input_channels, complex_model_input_channels)
def get_output_channels(self):

Returns the number of output channels for the combined prediction.

Returns:

  • int: Minimum of base_model and complex_model output channels.
def forward(self, x_base, x_complex=None):

Forward pass of the ResidualDesignModel.

Parameter:

  • x_base (torch.tensor): Input tensor for the base model, or a (x_base, x_complex) pair if x_complex is not given.
  • x_complex (torch.tensor): Input tensor for the complex model.

Returns:

  • torch.tensor: Combined prediction, either via weighted residual or CombineNet.
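
A forward-pass sketch for both combine modes (continuing the hypothetical TinyModel setup from the module example):

    x_base = torch.randn(4, 1, 64, 64)
    x_complex = torch.randn(4, 1, 64, 64)

    # 'math' mode: y_base + alpha * y_complex; 4D results are squeezed to (4, 64, 64).
    y_math = ResidualDesignModel(TinyModel(), TinyModel(), combine_mode="math")(x_base, x_complex)

    # 'nn' mode: CombineNet fusion; output is (4, 1, 64, 64) with values in [0, 1].
    y_nn = ResidualDesignModel(TinyModel(), TinyModel(), combine_mode="nn")(x_base, x_complex)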
def backward(self, y_base, y_complex, y):

Backward pass to optimize the alpha parameter for mathematical residual combination.

Parameter:

  • y_base (torch.tensor): Output of the base model.
  • y_complex (torch.tensor): Output of the complex model.
  • y (torch.tensor): Ground truth tensor.
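
A sketch of one alpha-optimization step in 'math' mode; the submodels are run first (here under no_grad, assuming they are trained separately), then backward() fits alpha via L1 loss:

    model = ResidualDesignModel(TinyModel(), TinyModel(), combine_mode="math")
    x_base = torch.randn(4, 1, 64, 64)
    x_complex = torch.randn(4, 1, 64, 64)
    y = torch.randn(4, 1, 64, 64)  # ground truth

    with torch.no_grad():  # keep the submodels frozen; only alpha gets gradients
        y_base = model.base_model(x_base)
        y_complex = model.complex_model(x_complex)

    model.backward(y_base, y_complex, y)  # one Adam step on alpha (lr=1e-5)
    print(float(model.alpha), model.last_combined_math_loss)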
def get_dict(self):

Returns a dictionary with the most recent loss values.

Returns:

  • dict: Loss components (base, complex, combined net, combined math).

Notes:

  • Useful for logging or monitoring training progress.
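
A minimal logging sketch (model as in the examples above):

    for name, value in model.get_dict().items():
        print(f"{name}: {value}")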