image_to_image.data.residual_physgen

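A PhysGen dataset wrapper that provides base-propagation and complex-propagation samples in one dataloader.

See:

  • https://huggingface.co/datasets/mspitzna/physicsgen
  • https://arxiv.org/abs/2503.05333
  • https://github.com/physicsgen/physicsgen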

  1"""
  2A PhysGen Dataset Wrapper to get base-propagation and 
  3complex-propagation in one dataloader.
  4
  5See:
  6- https://huggingface.co/datasets/mspitzna/physicsgen
  7- https://arxiv.org/abs/2503.05333
  8- https://github.com/physicsgen/physicsgen
  9"""
 10# ---------------------------
 11#        > Imports <
 12# ---------------------------
 13import torch
 14import torch.nn.functional as F
 15from torch.utils.data import DataLoader, Dataset
 16
 17from .physgen import PhysGenDataset
 18
 19
 20
 21# ---------------------------
 22#         > Helper <
 23# ---------------------------
 24def to_device(dataset):
 25    """
 26    Move dataset tensors to the appropriate device (CPU or GPU).
 27
 28    This helper function expects a dataset item formatted as 
 29    [input_tensor, target_tensor, index]. It automatically moves 
 30    all tensor elements to the available device.
 31
 32    Parameter:
 33    - dataset (list): 
 34        A list of three elements: 
 35        [input_tensor (torch.Tensor), target_tensor (torch.Tensor), index (int)].
 36
 37    Returns: 
 38    - list: 
 39        A list [input_tensor_on_device, target_tensor_on_device, index].
 40
 41    Raises:
 42    - ValueError: 
 43        If the provided dataset item does not have exactly 3 elements.
 44    """
 45    # Input: [Tensor(), Tensor(), int]
 46    if len(dataset) != 3:
 47        raise ValueError("Expected dataset to be a list of 3 values")
 48    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 49    return [dataset[0].to(device), dataset[1].to(device), dataset[2]]
 50
 51
 52
 53# ---------------------------
 54#        > Dataset <
 55# ---------------------------
 56class PhysGenResidualDataset(Dataset):
 57    """
 58    Dataset wrapper combining multiple PhysGen dataset variations 
 59    for residual learning experiments.
 60
 61    This dataset constructs three related PhysGen datasets:
 62    1. Baseline dataset (sound_baseline → 'standard' output)
 63    2. Complex dataset (user-selected variation → 'complex_only' output)
 64    3. Fusion dataset (user-selected variation → 'standard' output)
 65
 66    It is designed for residual or multi-source learning setups 
 67    where the model uses both baseline and complex physics simulations.
 68    """
 69    def __init__(self, variation="sound_baseline", mode="train", 
 70                 fake_rgb_output=False, make_14_dividable_size=False,
 71                 reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False):
 72        """
 73        Initialize the PhysGenResidualDataset with multiple data sources.
 74
 75        Parameter:
 76        - variation (str, default='sound_baseline'): 
 77            Specifies which physics variation to use for the complex and fusion datasets.
 78            Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
 79        - mode (str, default='train'): 
 80            Specifies dataset mode. Options: {'train', 'validation'}.
 81        - fake_rgb_output (bool, default=False): 
 82            If True, single-channel inputs are expanded to fake RGB format.
 83        - make_14_dividable_size (bool, default=False): 
 84            If True, ensures images are resized so that their height and width are divisible by 14.
 85        - reflexion_channels (bool, default=False): 
 86            If ray-traces should add to the input. Only for the complex part.
 87        - reflexion_steps (int, default=36): 
 88            Defines how many traces should get created.
 89        - reflexions_as_channels (bool, default=False): 
 90            If True, every trace gets its own channel, else every trace in one channel.
 91        """
 92        self.train_dataset_base = PhysGenDataset(mode='train', variation="sound_baseline", input_type="osm", output_type="standard", 
 93                                                 fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
 94        self.val_dataset_base = PhysGenDataset(mode='validation', variation="sound_baseline", input_type="osm", output_type="standard", 
 95                                               fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
 96
 97        self.train_dataset_complex = PhysGenDataset(mode='train', variation=variation, input_type="osm", output_type="complex_only", 
 98                                                    fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size, 
 99                                                    reflexion_channels=reflexion_channels, reflexion_steps=reflexion_steps, reflexions_as_channels=reflexions_as_channels)
100        self.val_dataset_complex = PhysGenDataset(mode='validation', variation=variation, input_type="osm", output_type="complex_only", 
101                                                  fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size,
102                                                  reflexion_channels=reflexion_channels, reflexion_steps=reflexion_steps, reflexions_as_channels=reflexions_as_channels)
103
104        self.train_dataset_fusion = PhysGenDataset(mode='train', variation=variation, input_type="osm", output_type="standard", 
105                                                   fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
106        self.val_dataset_fusion = PhysGenDataset(mode='validation', variation=variation, input_type="osm", output_type="standard", 
107                                                 fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
108        
109        self.datasets = [(self.train_dataset_base, self.val_dataset_base), (self.train_dataset_complex, self.val_dataset_complex), (self.train_dataset_fusion, self.val_dataset_fusion)]
110
111    def __len__(self):
112        """
113        Return the number of samples in the baseline training dataset.
114
115        Returns:
116        - int: Number of samples in the training split of the baseline dataset.
117        """
118        return len(self.train_dataset_base)
119
120    def __getitem__(self, idx, is_validation=False):
121        """
122        Retrieve a combined sample across baseline, complex, and fusion datasets.
123
124        This method returns a tuple containing:
125        - inputs: (base_input, complex_input)
126        - targets: (base_target, complex_target, full_target)
127
128        Parameter:
129        - idx (int): 
130            Index of the sample to retrieve.
131        - is_validation (bool, default=False): 
132            If True, samples are drawn from the validation split; 
133            otherwise, from the training split.
134
135        Returns:
136        - tuple: 
137            ((base_input, complex_input), (base_target, complex_target, full_target))
138        """
139        data_idx = 1 if is_validation else 0
140
141        base_input, base_target = self.datasets[0][data_idx][idx]
142        complex_input, complex_target = self.datasets[1][data_idx][idx]
143        _, target_ = self.datasets[2][data_idx][idx]
144
145        return (base_input, complex_input), (base_target, complex_target, target_)
def to_device(dataset):
def to_device(dataset):
    """
    Move the input and target tensors of a dataset item to the active device.

    The target device is CUDA when ``torch.cuda.is_available()`` is True and
    the CPU otherwise. Only the first two elements are moved; the third
    element (the index) is passed through untouched.

    Parameter:
    - dataset (list): 
        Exactly three elements:
        [input_tensor (torch.Tensor), target_tensor (torch.Tensor), index (int)].

    Returns: 
    - list: 
        [input_tensor_on_device, target_tensor_on_device, index].

    Raises:
    - ValueError: 
        If ``dataset`` does not contain exactly 3 elements.
    """
    # Expected layout: [Tensor, Tensor, int]
    expected_items = 3
    if len(dataset) != expected_items:
        raise ValueError("Expected dataset to be a list of 3 values")

    use_cuda = torch.cuda.is_available()
    target = torch.device("cuda") if use_cuda else torch.device("cpu")

    inputs, targets, index = dataset
    return [inputs.to(target), targets.to(target), index]

Move dataset tensors to the appropriate device (CPU or GPU).

This helper function expects a dataset item formatted as [input_tensor, target_tensor, index]. It moves the input and target tensors to the available device (CUDA if available, otherwise CPU); the index is returned unchanged.

Parameter:

  • dataset (list): A list of three elements: [input_tensor (torch.Tensor), target_tensor (torch.Tensor), index (int)].

Returns:

  • list: A list [input_tensor_on_device, target_tensor_on_device, index].

Raises:

  • ValueError: If the provided dataset item does not have exactly 3 elements.
class PhysGenResidualDataset(torch.utils.data.Dataset):
 57class PhysGenResidualDataset(Dataset):
 58    """
 59    Dataset wrapper combining multiple PhysGen dataset variations 
 60    for residual learning experiments.
 61
 62    This dataset constructs three related PhysGen datasets:
 63    1. Baseline dataset (sound_baseline → 'standard' output)
 64    2. Complex dataset (user-selected variation → 'complex_only' output)
 65    3. Fusion dataset (user-selected variation → 'standard' output)
 66
 67    It is designed for residual or multi-source learning setups 
 68    where the model uses both baseline and complex physics simulations.
 69    """
 70    def __init__(self, variation="sound_baseline", mode="train", 
 71                 fake_rgb_output=False, make_14_dividable_size=False,
 72                 reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False):
 73        """
 74        Initialize the PhysGenResidualDataset with multiple data sources.
 75
 76        Parameter:
 77        - variation (str, default='sound_baseline'): 
 78            Specifies which physics variation to use for the complex and fusion datasets.
 79            Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
 80        - mode (str, default='train'): 
 81            Specifies dataset mode. Options: {'train', 'validation'}.
 82        - fake_rgb_output (bool, default=False): 
 83            If True, single-channel inputs are expanded to fake RGB format.
 84        - make_14_dividable_size (bool, default=False): 
 85            If True, ensures images are resized so that their height and width are divisible by 14.
 86        - reflexion_channels (bool, default=False): 
 87            If ray-traces should add to the input. Only for the complex part.
 88        - reflexion_steps (int, default=36): 
 89            Defines how many traces should get created.
 90        - reflexions_as_channels (bool, default=False): 
 91            If True, every trace gets its own channel, else every trace in one channel.
 92        """
 93        self.train_dataset_base = PhysGenDataset(mode='train', variation="sound_baseline", input_type="osm", output_type="standard", 
 94                                                 fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
 95        self.val_dataset_base = PhysGenDataset(mode='validation', variation="sound_baseline", input_type="osm", output_type="standard", 
 96                                               fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
 97
 98        self.train_dataset_complex = PhysGenDataset(mode='train', variation=variation, input_type="osm", output_type="complex_only", 
 99                                                    fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size, 
100                                                    reflexion_channels=reflexion_channels, reflexion_steps=reflexion_steps, reflexions_as_channels=reflexions_as_channels)
101        self.val_dataset_complex = PhysGenDataset(mode='validation', variation=variation, input_type="osm", output_type="complex_only", 
102                                                  fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size,
103                                                  reflexion_channels=reflexion_channels, reflexion_steps=reflexion_steps, reflexions_as_channels=reflexions_as_channels)
104
105        self.train_dataset_fusion = PhysGenDataset(mode='train', variation=variation, input_type="osm", output_type="standard", 
106                                                   fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
107        self.val_dataset_fusion = PhysGenDataset(mode='validation', variation=variation, input_type="osm", output_type="standard", 
108                                                 fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
109        
110        self.datasets = [(self.train_dataset_base, self.val_dataset_base), (self.train_dataset_complex, self.val_dataset_complex), (self.train_dataset_fusion, self.val_dataset_fusion)]
111
112    def __len__(self):
113        """
114        Return the number of samples in the baseline training dataset.
115
116        Returns:
117        - int: Number of samples in the training split of the baseline dataset.
118        """
119        return len(self.train_dataset_base)
120
121    def __getitem__(self, idx, is_validation=False):
122        """
123        Retrieve a combined sample across baseline, complex, and fusion datasets.
124
125        This method returns a tuple containing:
126        - inputs: (base_input, complex_input)
127        - targets: (base_target, complex_target, full_target)
128
129        Parameter:
130        - idx (int): 
131            Index of the sample to retrieve.
132        - is_validation (bool, default=False): 
133            If True, samples are drawn from the validation split; 
134            otherwise, from the training split.
135
136        Returns:
137        - tuple: 
138            ((base_input, complex_input), (base_target, complex_target, full_target))
139        """
140        data_idx = 1 if is_validation else 0
141
142        base_input, base_target = self.datasets[0][data_idx][idx]
143        complex_input, complex_target = self.datasets[1][data_idx][idx]
144        _, target_ = self.datasets[2][data_idx][idx]
145
146        return (base_input, complex_input), (base_target, complex_target, target_)

Dataset wrapper combining multiple PhysGen dataset variations for residual learning experiments.

This dataset constructs three related PhysGen datasets:

  1. Baseline dataset (sound_baseline → 'standard' output)
  2. Complex dataset (user-selected variation → 'complex_only' output)
  3. Fusion dataset (user-selected variation → 'standard' output)

It is designed for residual or multi-source learning setups where the model uses both baseline and complex physics simulations.

PhysGenResidualDataset( variation='sound_baseline', mode='train', fake_rgb_output=False, make_14_dividable_size=False, reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False)
 70    def __init__(self, variation="sound_baseline", mode="train", 
 71                 fake_rgb_output=False, make_14_dividable_size=False,
 72                 reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False):
 73        """
 74        Initialize the PhysGenResidualDataset with multiple data sources.
 75
 76        Parameter:
 77        - variation (str, default='sound_baseline'): 
 78            Specifies which physics variation to use for the complex and fusion datasets.
 79            Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
 80        - mode (str, default='train'): 
 81            Specifies dataset mode. Options: {'train', 'validation'}.
 82        - fake_rgb_output (bool, default=False): 
 83            If True, single-channel inputs are expanded to fake RGB format.
 84        - make_14_dividable_size (bool, default=False): 
 85            If True, ensures images are resized so that their height and width are divisible by 14.
 86        - reflexion_channels (bool, default=False): 
 87            If ray-traces should add to the input. Only for the complex part.
 88        - reflexion_steps (int, default=36): 
 89            Defines how many traces should get created.
 90        - reflexions_as_channels (bool, default=False): 
 91            If True, every trace gets its own channel, else every trace in one channel.
 92        """
 93        self.train_dataset_base = PhysGenDataset(mode='train', variation="sound_baseline", input_type="osm", output_type="standard", 
 94                                                 fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
 95        self.val_dataset_base = PhysGenDataset(mode='validation', variation="sound_baseline", input_type="osm", output_type="standard", 
 96                                               fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
 97
 98        self.train_dataset_complex = PhysGenDataset(mode='train', variation=variation, input_type="osm", output_type="complex_only", 
 99                                                    fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size, 
100                                                    reflexion_channels=reflexion_channels, reflexion_steps=reflexion_steps, reflexions_as_channels=reflexions_as_channels)
101        self.val_dataset_complex = PhysGenDataset(mode='validation', variation=variation, input_type="osm", output_type="complex_only", 
102                                                  fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size,
103                                                  reflexion_channels=reflexion_channels, reflexion_steps=reflexion_steps, reflexions_as_channels=reflexions_as_channels)
104
105        self.train_dataset_fusion = PhysGenDataset(mode='train', variation=variation, input_type="osm", output_type="standard", 
106                                                   fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
107        self.val_dataset_fusion = PhysGenDataset(mode='validation', variation=variation, input_type="osm", output_type="standard", 
108                                                 fake_rgb_output=fake_rgb_output, make_14_dividable_size=make_14_dividable_size)
109        
110        self.datasets = [(self.train_dataset_base, self.val_dataset_base), (self.train_dataset_complex, self.val_dataset_complex), (self.train_dataset_fusion, self.val_dataset_fusion)]

Initialize the PhysGenResidualDataset with multiple data sources.

Parameter:

  • variation (str, default='sound_baseline'): Specifies which physics variation to use for the complex and fusion datasets. Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
  • mode (str, default='train'): Specifies dataset mode. Options: {'train', 'validation'}.
  • fake_rgb_output (bool, default=False): If True, single-channel inputs are expanded to fake RGB format.
  • make_14_dividable_size (bool, default=False): If True, ensures images are resized so that their height and width are divisible by 14.
  • reflexion_channels (bool, default=False): If True, ray traces are added to the input. Applied to the complex datasets only.
  • reflexion_steps (int, default=36): Number of traces to create.
  • reflexions_as_channels (bool, default=False): If True, each trace gets its own channel; otherwise all traces share one channel.
train_dataset_base
val_dataset_base
train_dataset_complex
val_dataset_complex
train_dataset_fusion
val_dataset_fusion
datasets