image_to_image.data.residual_physgen
A PhysGen Dataset Wrapper to get base-propagation and complex-propagation in one dataloader.
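See:
- https://huggingface.co/datasets/mspitzna/physicsgen
- https://arxiv.org/abs/2503.05333
- https://github.com/physicsgen/physicsgen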
"""
A PhysGen Dataset Wrapper to get base-propagation and
complex-propagation in one dataloader.

See:
- https://huggingface.co/datasets/mspitzna/physicsgen
- https://arxiv.org/abs/2503.05333
- https://github.com/physicsgen/physicsgen
"""
# ---------------------------
#        > Imports <
# ---------------------------
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from .physgen import PhysGenDataset



# ---------------------------
#        > Helper <
# ---------------------------
def to_device(dataset):
    """
    Move the two tensors of a dataset item onto the preferred torch device.

    The item is expected to look like [input_tensor, target_tensor, index].
    Both tensors are transferred with ``.to(device)``, where the device is
    CUDA when ``torch.cuda.is_available()`` and CPU otherwise; the index
    is passed through unchanged.

    Parameter:
    - dataset (list):
        A sequence of exactly three elements:
        [input_tensor (torch.Tensor), target_tensor (torch.Tensor), index (int)].

    Returns:
    - list:
        A new list [input_tensor_on_device, target_tensor_on_device, index].

    Raises:
    - ValueError:
        If the provided dataset item does not have exactly 3 elements.
    """
    # Input: [Tensor(), Tensor(), int]
    if len(dataset) != 3:
        raise ValueError("Expected dataset to be a list of 3 values")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return [dataset[0].to(device), dataset[1].to(device), dataset[2]]



# ---------------------------
#        > Dataset <
# ---------------------------
class PhysGenResidualDataset(Dataset):
    """
    Dataset wrapper combining multiple PhysGen dataset variations
    for residual learning experiments.

    This dataset constructs three related PhysGen datasets, each for both
    the training and the validation split:
        1. Baseline dataset (sound_baseline → 'standard' output)
        2. Complex dataset (user-selected variation → 'complex_only' output)
        3. Fusion dataset (user-selected variation → 'standard' output)

    The ``mode`` passed to the constructor selects which split ``__len__``
    and ``__getitem__`` use by default. This matters because a DataLoader
    only ever calls ``dataset[idx]`` and cannot pass ``is_validation``.

    It is designed for residual or multi-source learning setups
    where the model uses both baseline and complex physics simulations.
    """
    def __init__(self, variation="sound_baseline", mode="train",
                 fake_rgb_output=False, make_14_dividable_size=False,
                 reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False):
        """
        Initialize the PhysGenResidualDataset with multiple data sources.

        Parameter:
        - variation (str, default='sound_baseline'):
            Specifies which physics variation to use for the complex and fusion datasets.
            Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
        - mode (str, default='train'):
            Split served by default by ``__len__`` and ``__getitem__``.
            Options: {'train', 'validation'}. Both splits are always built, so
            ``__getitem__(idx, is_validation=...)`` can still override it explicitly.
        - fake_rgb_output (bool, default=False):
            Forwarded to every underlying PhysGenDataset.
        - make_14_dividable_size (bool, default=False):
            Forwarded to every underlying PhysGenDataset; intended to make the
            image height and width divisible by 14.
        - reflexion_channels (bool, default=False):
            If True, ray-traces are added to the input. Applies only to the complex datasets.
        - reflexion_steps (int, default=36):
            Number of ray-traces to create (complex datasets only).
        - reflexions_as_channels (bool, default=False):
            If True, each trace gets its own channel; otherwise all traces share one channel
            (complex datasets only).

        Raises:
        - ValueError:
            If ``mode`` is not 'train' or 'validation'.
        """
        # Validate before any (potentially expensive) dataset construction.
        if mode not in ("train", "validation"):
            raise ValueError(f"mode must be 'train' or 'validation', got {mode!r}")
        self.mode = mode

        # Arguments shared by every underlying PhysGenDataset.
        shared = dict(input_type="osm",
                      fake_rgb_output=fake_rgb_output,
                      make_14_dividable_size=make_14_dividable_size)
        # Ray-trace options are only forwarded to the complex datasets.
        reflexion = dict(reflexion_channels=reflexion_channels,
                         reflexion_steps=reflexion_steps,
                         reflexions_as_channels=reflexions_as_channels)

        self.train_dataset_base = PhysGenDataset(mode='train', variation="sound_baseline",
                                                 output_type="standard", **shared)
        self.val_dataset_base = PhysGenDataset(mode='validation', variation="sound_baseline",
                                               output_type="standard", **shared)

        self.train_dataset_complex = PhysGenDataset(mode='train', variation=variation,
                                                    output_type="complex_only", **shared, **reflexion)
        self.val_dataset_complex = PhysGenDataset(mode='validation', variation=variation,
                                                  output_type="complex_only", **shared, **reflexion)

        self.train_dataset_fusion = PhysGenDataset(mode='train', variation=variation,
                                                   output_type="standard", **shared)
        self.val_dataset_fusion = PhysGenDataset(mode='validation', variation=variation,
                                                 output_type="standard", **shared)

        # Each entry is a (train, validation) pair: base, complex, fusion.
        self.datasets = [(self.train_dataset_base, self.val_dataset_base),
                         (self.train_dataset_complex, self.val_dataset_complex),
                         (self.train_dataset_fusion, self.val_dataset_fusion)]

    def _split_index(self, is_validation=None):
        """Return 0 (train) or 1 (validation); ``None`` means 'follow self.mode'."""
        if is_validation is None:
            is_validation = self.mode == "validation"
        return 1 if is_validation else 0

    def __len__(self):
        """
        Return the number of samples in the split selected by ``mode``.

        Returns:
        - int: Smallest length among the base, complex, and fusion datasets of
          that split. The three are indexed with the same ``idx``, so any index
          beyond the shortest one could not be served.
        """
        split = self._split_index()
        return min(len(pair[split]) for pair in self.datasets)

    def __getitem__(self, idx, is_validation=None):
        """
        Retrieve a combined sample across baseline, complex, and fusion datasets.

        This method returns a tuple containing:
            - inputs: (base_input, complex_input)
            - targets: (base_target, complex_target, full_target)

        Parameter:
        - idx (int):
            Index of the sample to retrieve (same index in all three datasets).
        - is_validation (bool or None, default=None):
            True forces the validation split and False forces the training split.
            None uses the split chosen by ``mode`` at construction time.

        Returns:
        - tuple:
            ((base_input, complex_input), (base_target, complex_target, full_target))
        """
        data_idx = self._split_index(is_validation)

        base_input, base_target = self.datasets[0][data_idx][idx]
        complex_input, complex_target = self.datasets[1][data_idx][idx]
        # Only the target of the fusion dataset is needed.
        _, target_ = self.datasets[2][data_idx][idx]

        return (base_input, complex_input), (base_target, complex_target, target_)
def to_device(dataset):
def to_device(dataset):
    """
    Move the two tensors of a dataset item onto the preferred torch device.

    The item is expected to look like [input_tensor, target_tensor, index].
    The target device is CUDA when ``torch.cuda.is_available()`` reports True
    and CPU otherwise. The third element is handed back unchanged.

    Parameter:
    - dataset (list):
        A sequence of exactly three elements:
        [input_tensor (torch.Tensor), target_tensor (torch.Tensor), index (int)].

    Returns:
    - list:
        A new list [input_tensor_on_device, target_tensor_on_device, index].

    Raises:
    - ValueError:
        If the provided dataset item does not have exactly 3 elements.
    """
    expected_len = 3
    if len(dataset) == expected_len:
        use_cuda = torch.cuda.is_available()
        target = torch.device("cuda") if use_cuda else torch.device("cpu")
        source_tensor, label_tensor, index = dataset[0], dataset[1], dataset[2]
        return [source_tensor.to(target), label_tensor.to(target), index]
    raise ValueError("Expected dataset to be a list of 3 values")
Move dataset tensors to the appropriate device (CPU or GPU).
This helper function expects a dataset item formatted as [input_tensor, target_tensor, index]. It moves both tensors to CUDA when available (otherwise to CPU) and returns the index unchanged.
Parameter:
- dataset (list): A list of three elements: [input_tensor (torch.Tensor), target_tensor (torch.Tensor), index (int)].
Returns:
- list: A list [input_tensor_on_device, target_tensor_on_device, index].
Raises:
- ValueError: If the provided dataset item does not have exactly 3 elements.
class PhysGenResidualDataset(torch.utils.data.Dataset):
class PhysGenResidualDataset(Dataset):
    """
    Dataset wrapper combining multiple PhysGen dataset variations
    for residual learning experiments.

    This dataset constructs three related PhysGen datasets, each for both
    the training and the validation split:
        1. Baseline dataset (sound_baseline → 'standard' output)
        2. Complex dataset (user-selected variation → 'complex_only' output)
        3. Fusion dataset (user-selected variation → 'standard' output)

    The ``mode`` passed to the constructor selects which split ``__len__``
    and ``__getitem__`` use by default. This matters because a DataLoader
    only ever calls ``dataset[idx]`` and cannot pass ``is_validation``.

    It is designed for residual or multi-source learning setups
    where the model uses both baseline and complex physics simulations.
    """
    def __init__(self, variation="sound_baseline", mode="train",
                 fake_rgb_output=False, make_14_dividable_size=False,
                 reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False):
        """
        Initialize the PhysGenResidualDataset with multiple data sources.

        Parameter:
        - variation (str, default='sound_baseline'):
            Specifies which physics variation to use for the complex and fusion datasets.
            Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
        - mode (str, default='train'):
            Split served by default by ``__len__`` and ``__getitem__``.
            Options: {'train', 'validation'}. Both splits are always built, so
            ``__getitem__(idx, is_validation=...)`` can still override it explicitly.
        - fake_rgb_output (bool, default=False):
            Forwarded to every underlying PhysGenDataset.
        - make_14_dividable_size (bool, default=False):
            Forwarded to every underlying PhysGenDataset; intended to make the
            image height and width divisible by 14.
        - reflexion_channels (bool, default=False):
            If True, ray-traces are added to the input. Applies only to the complex datasets.
        - reflexion_steps (int, default=36):
            Number of ray-traces to create (complex datasets only).
        - reflexions_as_channels (bool, default=False):
            If True, each trace gets its own channel; otherwise all traces share one channel
            (complex datasets only).

        Raises:
        - ValueError:
            If ``mode`` is not 'train' or 'validation'.
        """
        # Validate before any (potentially expensive) dataset construction.
        if mode not in ("train", "validation"):
            raise ValueError(f"mode must be 'train' or 'validation', got {mode!r}")
        self.mode = mode

        # Arguments shared by every underlying PhysGenDataset.
        shared = dict(input_type="osm",
                      fake_rgb_output=fake_rgb_output,
                      make_14_dividable_size=make_14_dividable_size)
        # Ray-trace options are only forwarded to the complex datasets.
        reflexion = dict(reflexion_channels=reflexion_channels,
                         reflexion_steps=reflexion_steps,
                         reflexions_as_channels=reflexions_as_channels)

        self.train_dataset_base = PhysGenDataset(mode='train', variation="sound_baseline",
                                                 output_type="standard", **shared)
        self.val_dataset_base = PhysGenDataset(mode='validation', variation="sound_baseline",
                                               output_type="standard", **shared)

        self.train_dataset_complex = PhysGenDataset(mode='train', variation=variation,
                                                    output_type="complex_only", **shared, **reflexion)
        self.val_dataset_complex = PhysGenDataset(mode='validation', variation=variation,
                                                  output_type="complex_only", **shared, **reflexion)

        self.train_dataset_fusion = PhysGenDataset(mode='train', variation=variation,
                                                   output_type="standard", **shared)
        self.val_dataset_fusion = PhysGenDataset(mode='validation', variation=variation,
                                                 output_type="standard", **shared)

        # Each entry is a (train, validation) pair: base, complex, fusion.
        self.datasets = [(self.train_dataset_base, self.val_dataset_base),
                         (self.train_dataset_complex, self.val_dataset_complex),
                         (self.train_dataset_fusion, self.val_dataset_fusion)]

    def _split_index(self, is_validation=None):
        """Return 0 (train) or 1 (validation); ``None`` means 'follow self.mode'."""
        if is_validation is None:
            is_validation = self.mode == "validation"
        return 1 if is_validation else 0

    def __len__(self):
        """
        Return the number of samples in the split selected by ``mode``.

        Returns:
        - int: Smallest length among the base, complex, and fusion datasets of
          that split. The three are indexed with the same ``idx``, so any index
          beyond the shortest one could not be served.
        """
        split = self._split_index()
        return min(len(pair[split]) for pair in self.datasets)

    def __getitem__(self, idx, is_validation=None):
        """
        Retrieve a combined sample across baseline, complex, and fusion datasets.

        This method returns a tuple containing:
            - inputs: (base_input, complex_input)
            - targets: (base_target, complex_target, full_target)

        Parameter:
        - idx (int):
            Index of the sample to retrieve (same index in all three datasets).
        - is_validation (bool or None, default=None):
            True forces the validation split and False forces the training split.
            None uses the split chosen by ``mode`` at construction time.

        Returns:
        - tuple:
            ((base_input, complex_input), (base_target, complex_target, full_target))
        """
        data_idx = self._split_index(is_validation)

        base_input, base_target = self.datasets[0][data_idx][idx]
        complex_input, complex_target = self.datasets[1][data_idx][idx]
        # Only the target of the fusion dataset is needed.
        _, target_ = self.datasets[2][data_idx][idx]

        return (base_input, complex_input), (base_target, complex_target, target_)
Dataset wrapper combining multiple PhysGen dataset variations for residual learning experiments.
This dataset constructs three related PhysGen datasets:
- Baseline dataset (sound_baseline → 'standard' output)
- Complex dataset (user-selected variation → 'complex_only' output)
- Fusion dataset (user-selected variation → 'standard' output)
It is designed for residual or multi-source learning setups where the model uses both baseline and complex physics simulations.
PhysGenResidualDataset( variation='sound_baseline', mode='train', fake_rgb_output=False, make_14_dividable_size=False, reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False)
def __init__(self, variation="sound_baseline", mode="train",
             fake_rgb_output=False, make_14_dividable_size=False,
             reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False):
    """
    Initialize the PhysGenResidualDataset with multiple data sources.

    Parameter:
    - variation (str, default='sound_baseline'):
        Specifies which physics variation to use for the complex and fusion datasets.
        Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
    - mode (str, default='train'):
        Split the dataset serves by default; stored as ``self.mode``.
        Options: {'train', 'validation'}. Both splits are always built.
    - fake_rgb_output (bool, default=False):
        Forwarded to every underlying PhysGenDataset.
    - make_14_dividable_size (bool, default=False):
        Forwarded to every underlying PhysGenDataset; intended to make the
        image height and width divisible by 14.
    - reflexion_channels (bool, default=False):
        If True, ray-traces are added to the input. Applies only to the complex datasets.
    - reflexion_steps (int, default=36):
        Number of ray-traces to create (complex datasets only).
    - reflexions_as_channels (bool, default=False):
        If True, each trace gets its own channel; otherwise all traces share one channel
        (complex datasets only).

    Raises:
    - ValueError:
        If ``mode`` is not 'train' or 'validation'.
    """
    # Validate before any (potentially expensive) dataset construction.
    if mode not in ("train", "validation"):
        raise ValueError(f"mode must be 'train' or 'validation', got {mode!r}")
    # Kept so that split selection can follow the requested mode.
    self.mode = mode

    # Arguments shared by every underlying PhysGenDataset.
    shared = dict(input_type="osm",
                  fake_rgb_output=fake_rgb_output,
                  make_14_dividable_size=make_14_dividable_size)
    # Ray-trace options are only forwarded to the complex datasets.
    reflexion = dict(reflexion_channels=reflexion_channels,
                     reflexion_steps=reflexion_steps,
                     reflexions_as_channels=reflexions_as_channels)

    self.train_dataset_base = PhysGenDataset(mode='train', variation="sound_baseline",
                                             output_type="standard", **shared)
    self.val_dataset_base = PhysGenDataset(mode='validation', variation="sound_baseline",
                                           output_type="standard", **shared)

    self.train_dataset_complex = PhysGenDataset(mode='train', variation=variation,
                                                output_type="complex_only", **shared, **reflexion)
    self.val_dataset_complex = PhysGenDataset(mode='validation', variation=variation,
                                              output_type="complex_only", **shared, **reflexion)

    self.train_dataset_fusion = PhysGenDataset(mode='train', variation=variation,
                                               output_type="standard", **shared)
    self.val_dataset_fusion = PhysGenDataset(mode='validation', variation=variation,
                                             output_type="standard", **shared)

    # Each entry is a (train, validation) pair: base, complex, fusion.
    self.datasets = [(self.train_dataset_base, self.val_dataset_base),
                     (self.train_dataset_complex, self.val_dataset_complex),
                     (self.train_dataset_fusion, self.val_dataset_fusion)]
Initialize the PhysGenResidualDataset with multiple data sources.
Parameter:
- variation (str, default='sound_baseline'): Specifies which physics variation to use for the complex and fusion datasets. Common options: {'sound_reflection', 'sound_diffraction', 'sound_combined'}.
- mode (str, default='train'): Specifies dataset mode. Options: {'train', 'validation'}.
- fake_rgb_output (bool, default=False): If True, single-channel inputs are expanded to fake RGB format.
- make_14_dividable_size (bool, default=False): If True, ensures images are resized so that their height and width are divisible by 14.
- reflexion_channels (bool, default=False): If True, ray-traces are added to the input. Applies only to the complex dataset.
- reflexion_steps (int, default=36): Number of ray-traces to create.
- reflexions_as_channels (bool, default=False): If True, each trace gets its own channel; otherwise all traces are combined into a single channel.