image_to_image.data.physgen

PhysGen Dataset Loader

PyTorch Dataset and DataLoader utilities for the PhysicsGen dataset.

Also provides helper functions for downloading and saving the dataset.

See:

  • https://huggingface.co/datasets/mspitzna/physicsgen
  • https://arxiv.org/abs/2503.05333
  • https://github.com/physicsgen/physicsgen
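
A minimal usage sketch (assuming the package is installed so that this module is importable as image_to_image.data.physgen; the printed shapes are illustrative):

    from image_to_image.data.physgen import get_dataloader

    # Build a train-split loader for the reflection variation (batch size is fixed to 1).
    loader = get_dataloader(mode="train", variation="sound_reflection")
    input_img, target_img = next(iter(loader))
    print(input_img.shape, target_img.shape)  # e.g. torch.Size([1, 1, 256, 256])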

  1"""
  2PhysGen Dataset Loader
  3
  4PyTorch DataLoader.
  5
  6Also provide some functions for downloading the dataset.
  7
  8See:
  9- https://huggingface.co/datasets/mspitzna/physicsgen
 10- https://arxiv.org/abs/2503.05333
 11- https://github.com/physicsgen/physicsgen
 12"""
# ---------------------------
#        > Imports <
# ---------------------------
import os
import shutil
from PIL import Image

from datasets import load_dataset

import numpy as np
import cv2

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import img_phy_sim as ips
import prime_printer as prime



# ---------------------------
#        > Helper <
# ---------------------------
def resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor:
    """
    Resize a tensor so that its height and width are divisible by 14.

    This function ensures the spatial dimensions (H, W) of a given tensor
    are compatible with architectures that require sizes divisible by 14
    (e.g., ResNet, ResFCN). It resizes using bilinear interpolation.

    Parameter:
    - tensor (torch.Tensor):
        Input tensor of shape (C, H, W) or (B, C, H, W).

    Returns:
    - torch.Tensor:
        Resized tensor with dimensions divisible by 14.

    Raises:
    - ValueError:
        If the tensor has neither 3 nor 4 dimensions.
    """
    if tensor.dim() == 3:
        c, h, w = tensor.shape
        new_h = h - (h % 14)
        new_w = w - (w % 14)
        return F.interpolate(tensor.unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False).squeeze(0)

    elif tensor.dim() == 4:
        b, c, h, w = tensor.shape
        new_h = h - (h % 14)
        new_w = w - (w % 14)
        return F.interpolate(tensor, size=(new_h, new_w), mode='bilinear', align_corners=False)

    else:
        raise ValueError("Tensor must be 3D (C, H, W) or 4D (B, C, H, W)")



# ---------------------------
#        > Dataset <
# ---------------------------
class PhysGenDataset(Dataset):
    """
    PyTorch Dataset wrapper for the PhysicsGen dataset.

    Loads the PhysGen dataset from Hugging Face and provides configurable
    input/output modes for physics-based generative learning tasks.

    The dataset contains Open Sound Maps (OSM) and simulated soundmaps
    for tasks involving sound propagation modeling.
    """
    def __init__(self, variation="sound_baseline", mode="train", input_type="osm", output_type="standard",
                 fake_rgb_output=False, make_14_dividable_size=False,
                 reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False):
        """
        Loads the PhysGen dataset.

        Parameter:
        - variation (str, default='sound_baseline'):
            Dataset variation to load. Options include:
            {'sound_baseline', 'sound_reflection', 'sound_diffraction', 'sound_combined'}.
        - mode (str, default='train'):
            Dataset split to use. Options: {'train', 'test', 'validation'}.
        - input_type (str, default='osm'):
            Defines the input image type:
            - 'osm': open sound map input.
            - 'base_simulation': uses the baseline sound simulation as input.
        - output_type (str, default='standard'):
            Defines the output image type:
            - 'standard': full soundmap prediction.
            - 'complex_only': difference from the baseline soundmap.
        - fake_rgb_output (bool, default=False):
            If True, replicates single-channel inputs to fake RGB (3-channel).
        - make_14_dividable_size (bool, default=False):
            If True, resizes tensors so that height and width are divisible by 14.
        - reflexion_channels (bool, default=False):
            If True, adds ray-traced reflection channels to the input.
        - reflexion_steps (int, default=36):
            Number of ray traces to create.
        - reflexions_as_channels (bool, default=False):
            If True, every trace gets its own channel; otherwise all traces are
            drawn into a single channel.
        """
        self.fake_rgb_output = fake_rgb_output
        self.make_14_dividable_size = make_14_dividable_size
        self.reflexion_channels = reflexion_channels
        self.reflexion_steps = reflexion_steps
        self.reflexions_as_channels = reflexions_as_channels

        self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        # get data
        self.dataset = load_dataset("mspitzna/physicsgen", name=variation, trust_remote_code=True)
        self.dataset = self.dataset[mode]
        self.mode = mode

        self.input_type = input_type
        self.output_type = output_type
        if self.input_type == "base_simulation" or self.output_type == "complex_only":
            self.basesimulation_dataset = load_dataset("mspitzna/physicsgen", name="sound_baseline", trust_remote_code=True)
            self.basesimulation_dataset = self.basesimulation_dataset[mode]

        self.transform = transforms.Compose([
            transforms.ToTensor(),  # converts a [0, 255] PIL image to a [0, 1] FloatTensor
        ])
        print(f"Created PhysGen ({variation}) dataset for the {mode} split")

    def __len__(self):
        """
        Returns the number of available samples.

        Returns:
        - int:
            Number of samples in the dataset split.
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Retrieve an input-target pair from the dataset.

        This function loads the input image and corresponding target image,
        applies transformations (resizing, fake RGB, etc.), and returns
        them as PyTorch tensors.

        Parameter:
        - idx (int):
            Index of the data sample.

        Returns:
        - tuple[torch.Tensor, torch.Tensor]:
            A tuple containing:
            - input_img (torch.Tensor):
                Input tensor (shape: [C, H, W]).
            - target_img (torch.Tensor):
                Target tensor (shape: [C, H, W]).
        """
        sample = self.dataset[idx]
        if self.input_type == "base_simulation":
            input_img = self.basesimulation_dataset[idx]["soundmap"]
        else:
            input_img = sample["osm"]  # PIL Image
        target_img = sample["soundmap"]  # PIL Image

        input_img = self.transform(input_img)
        target_img = self.transform(target_img)

        # Fix real image size: 512x512 -> 256x256
        input_img = F.interpolate(input_img.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
        input_img = input_img.squeeze(0)

        # change size
        if self.make_14_dividable_size:
            input_img = resize_tensor_to_divisible_by_14(input_img)
            target_img = resize_tensor_to_divisible_by_14(target_img)

        # add fake rgb
        if self.fake_rgb_output and input_img.shape[0] == 1:  # shape (1, H, W)
            input_img = input_img.repeat(3, 1, 1)  # make it (3, H, W)

        if self.output_type == "complex_only":
            base_simulation_img = self.transform(self.basesimulation_dataset[idx]["soundmap"])
            # Target is the signed difference of baseline and full simulation.
            target_img = (base_simulation_img[0] - target_img[0]).unsqueeze(0)

        # add raytracing
        if self.reflexion_channels:
            # Precomputed rays are only cached for the train split.
            ray_path = os.path.join("./rays", "train", str(self.reflexion_steps), f"rays_[{int(idx)}].txt")
            if self.mode == "train" and os.path.exists(ray_path):
                rays = ips.ray_tracing.open(path=ray_path)
            else:
                rays = ips.ray_tracing.trace_beams(rel_position=(0.5, 0.5),
                                                   img_src=np.squeeze(input_img.cpu().numpy(), axis=0),
                                                   directions_in_degree=ips.math.get_linear_degree_range(step_size=(self.reflexion_steps/360)*100),
                                                   wall_values=[0],
                                                   wall_thickness=0,
                                                   img_border_also_collide=False,
                                                   reflexion_order=3,
                                                   should_scale_rays=True,
                                                   should_scale_img=False)
            ray_img = ips.ray_tracing.draw_rays(rays,
                                                detail_draw=False,
                                                output_format='channels' if self.reflexions_as_channels else 'single_image',
                                                img_background=None,
                                                ray_value=[50, 100, 255],
                                                ray_thickness=1,
                                                img_shape=(256, 256),
                                                should_scale_rays_to_image=True,
                                                show_only_reflections=True)
            ray_img = self.transform(ray_img)
            ray_img = ray_img.float()
            if ray_img.ndim == 2:
                ray_img = ray_img.unsqueeze(0)  # (1, H, W)

            # Merge the ray channels with the input image.
            if ray_img.shape[1:] == input_img.shape[1:]:
                input_img = torch.cat((input_img, ray_img), dim=0)
            else:
                raise ValueError(f"Ray image shape {ray_img.shape} does not match input image shape {input_img.shape}.")

        return input_img, target_img



# ---------------------------
#   > Helpful Functions <
# ---------------------------
# Convenience functions intended for external use.

def get_dataloader(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True):
    """
    Create a PyTorch DataLoader for the PhysGen dataset.

    This helper simplifies loading the PhysGen dataset for training,
    validation, or testing.

    Parameter:
    - mode (str, default='train'):
        Dataset split to use ('train', 'test', 'validation').
    - variation (str, default='sound_reflection'):
        Dataset variation to load.
    - input_type (str, default='osm'):
        Defines the input type ('osm' or 'base_simulation').
    - output_type (str, default='complex_only'):
        Defines the output type ('standard' or 'complex_only').
    - shuffle (bool, default=True):
        Whether to shuffle the dataset between epochs.

    Returns:
    - torch.utils.data.DataLoader:
        DataLoader that provides batches of PhysGen samples.
    """
    dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type)
    return DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1)



def get_image(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True,
              return_output=False, as_numpy_array=True):
    """
    Retrieve one image (input and optionally output) from the PhysGen dataset.

    Provides an easy way to visualize or inspect a single PhysGen sample
    without manually instantiating a DataLoader.

    Parameter:
    - mode (str, default='train'):
        Dataset split ('train', 'test', 'validation').
    - variation (str, default='sound_reflection'):
        Dataset variation.
    - input_type (str, default='osm'):
        Defines the input type ('osm' or 'base_simulation').
    - output_type (str, default='complex_only'):
        Defines the output type ('standard' or 'complex_only').
    - shuffle (bool, default=True):
        Randomly select the sample.
    - return_output (bool, default=False):
        If True, returns both input and target.
    - as_numpy_array (bool, default=True):
        If True, converts tensors to NumPy arrays for easier visualization.

    Returns:
    - numpy.ndarray or list[numpy.ndarray]:
        Input image as NumPy array, or a list [input, target] if `return_output` is True.
    """
    dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type)
    loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1)
    cur_data = next(iter(loader))
    input_ = cur_data[0]
    output_ = cur_data[1]

    if as_numpy_array:
        input_ = input_.detach().cpu().numpy()
        output_ = output_.detach().cpu().numpy()

        # remove the batch dimension
        input_ = np.squeeze(input_, axis=0)
        output_ = np.squeeze(output_, axis=0)

        # remove the channel dimension (single-channel images)
        if len(input_.shape) == 3:
            input_ = np.squeeze(input_, axis=0)
            output_ = np.squeeze(output_, axis=0)

        # transpose (H, W) -> (W, H)
        input_ = np.transpose(input_, (1, 0))
        output_ = np.transpose(output_, (1, 0))

    result = input_
    if return_output:
        result = [input_, output_]

    return result



def save_dataset(output_real_path, output_osm_path,
                 variation, input_type, output_type,
                 data_mode,
                 info_print=False, progress_print=True):
    """
    Save PhysGen dataset samples as images to disk.

    This function loads the specified PhysGen dataset, converts input and
    target tensors to images, and saves them as `.png` files for inspection,
    debugging, or model-agnostic data use.

    Parameter:
    - output_real_path (str):
        Directory to save target (real) soundmaps.
    - output_osm_path (str):
        Directory to save input (OSM) maps.
    - variation (str):
        Dataset variation (e.g. 'sound_reflection').
    - input_type (str):
        Input type ('osm' or 'base_simulation').
    - output_type (str):
        Output type ('standard' or 'complex_only').
    - data_mode (str):
        Dataset split ('train', 'test', 'validation').
    - info_print (bool, default=False):
        If True, prints detailed information for each saved sample.
    - progress_print (bool, default=True):
        If True, shows progress updates in the console.

    Raises:
    - ValueError:
        If image data falls outside the valid range [0, 255].
    """
    # Clear (or create) both output directories.
    for path in (output_osm_path, output_real_path):
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            os.makedirs(path)
            print(f"Cleared {path}.")
        else:
            os.makedirs(path)
            print(f"Created {path}.")

    # Load Dataset
    dataset = PhysGenDataset(mode=data_mode, variation=variation, input_type=input_type, output_type=output_type)
    data_len = len(dataset)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    # Save Dataset
    for i, data in enumerate(dataloader):
        if progress_print:
            prime.get_progress_bar(total=data_len, progress=i+1,
                                   should_clear=True, left_bar_char='|', right_bar_char='|',
                                   progress_char='#', empty_char=' ',
                                   front_message='Physgen Data Loading', back_message='', size=15)

        # The dataset yields (input, target) pairs; the loader is not shuffled,
        # so the running index doubles as the sample index.
        input_img, target_img = data
        idx = i

        if info_print:
            print(f"Input shape [osm]: {input_img.shape}")
            print(f"Target shape: {target_img.shape}")

        real_img = target_img.squeeze(0).cpu().squeeze(0).detach().numpy()
        if not (0 <= real_img.min() <= 255 and 0 <= real_img.max() <= 255):
            raise ValueError(f"Real target has values outside the [0, 255] range => min:{real_img.min()}, max:{real_img.max()}")
        if real_img.max() <= 1.0:
            real_img *= 255
        real_img = real_img.astype(np.uint8)
        if info_print:
            print(f"Real target value range => min:{real_img.min()}, max:{real_img.max()}")

        if len(input_img.shape) == 4:
            osm_img = input_img[0, 0].cpu().detach().numpy()
        else:
            osm_img = input_img[0].cpu().detach().numpy()
        if not (0 <= osm_img.min() <= 255 and 0 <= osm_img.max() <= 255):
            raise ValueError(f"OSM input has values outside the [0, 255] range => min:{osm_img.min()}, max:{osm_img.max()}")
        if osm_img.max() <= 1.0:
            osm_img *= 255
        osm_img = osm_img.astype(np.uint8)

        if info_print:
            print(f"OSM Info:\n    -> shape: {osm_img.shape}\n    -> min: {osm_img.min()}, max: {osm_img.max()}")

        # Save Results
        file_name = f"physgen_{idx}.png"

        # save real image
        save_img = os.path.join(output_real_path, "target_"+file_name)
        cv2.imwrite(save_img, real_img)
        if info_print:
            print(f"    -> saved real at {save_img}")

        # save osm image
        save_img = os.path.join(output_osm_path, "input_"+file_name)
        cv2.imwrite(save_img, osm_img)
        if info_print:
            print(f"    -> saved osm at {save_img}")
    print(f"\nSuccessfully saved {data_len} datapoints into {os.path.abspath(output_real_path)} & {os.path.abspath(output_osm_path)}")



# ---------------------------
#     > Dataset Saving <
# ---------------------------
if __name__ == "__main__":
    """
    Command-line interface for saving PhysGen dataset samples.

    Allows users to export the PhysGen dataset as image pairs for a given
    variation, input/output configuration, and mode.

    Example
    -------
    >>> python -m image_to_image.data.physgen \
            --output_real_path ./real \
            --output_osm_path ./osm \
            --variation sound_reflection \
            --input_type osm \
            --output_type standard \
            --data_mode train
    """
    import argparse

    parser = argparse.ArgumentParser(description="Save OSM and real PhysGen dataset images.")

    parser.add_argument("--output_real_path", type=str, required=True, help="Path to save real target images")
    parser.add_argument("--output_osm_path", type=str, required=True, help="Path to save OSM input images")
    parser.add_argument("--variation", type=str, required=True, help="PhysGen variation (e.g. sound_baseline, sound_reflection, sound_diffraction, sound_combined)")
    parser.add_argument("--input_type", type=str, required=True, help="Input type ('osm' or 'base_simulation')")
    parser.add_argument("--output_type", type=str, required=True, help="Output type ('standard' or 'complex_only')")
    parser.add_argument("--data_mode", type=str, required=True, help="Data mode: train, test, validation")
    parser.add_argument("--info_print", action="store_true", help="Print additional info")
    parser.add_argument("--no_progress", action="store_true", help="Disable progress printing")

    args = parser.parse_args()

    save_dataset(
        output_real_path=args.output_real_path,
        output_osm_path=args.output_osm_path,
        variation=args.variation,
        input_type=args.input_type,
        output_type=args.output_type,
        data_mode=args.data_mode,
        info_print=args.info_print,
        progress_print=not args.no_progress
    )

def resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor:

Resize a tensor so that its height and width are divisible by 14.

This function ensures the spatial dimensions (H, W) of a given tensor are compatible with architectures that require sizes divisible by 14 (e.g., ResNet, ResFCN). It resizes using bilinear interpolation.

Parameter:

  • tensor (torch.Tensor): Input tensor of shape (C, H, W) or (B, C, H, W).

Returns:

  • torch.Tensor: Resized tensor with dimensions divisible by 14.

Raises:

  • ValueError: If the tensor has neither 3 nor 4 dimensions.
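
A quick sketch of the resize behavior (the input size below is arbitrary; the function is assumed to be imported from this module):

    import torch

    x = torch.rand(3, 250, 250)             # (C, H, W); 250 is not divisible by 14
    y = resize_tensor_to_divisible_by_14(x)
    print(y.shape)                          # torch.Size([3, 238, 238]); 238 = 250 - (250 % 14)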
class PhysGenDataset(Dataset):

PyTorch Dataset wrapper for the PhysicsGen dataset.

Loads the PhysGen dataset from Hugging Face and provides configurable input/output modes for physics-based generative learning tasks.

The dataset contains Open Sound Maps (OSM) and simulated soundmaps for tasks involving sound propagation modeling.

PhysGenDataset(variation='sound_baseline', mode='train', input_type='osm', output_type='standard', fake_rgb_output=False, make_14_dividable_size=False, reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False)

Loads the PhysGen dataset.

Parameter:

  • variation (str, default='sound_baseline'): Dataset variation to load. Options include: {'sound_baseline', 'sound_reflection', 'sound_diffraction', 'sound_combined'}.
  • mode (str, default='train'): Dataset split to use. Options: {'train', 'test', 'validation'}.
  • input_type (str, default='osm'): Defines the input image type:
    • 'osm': open sound map input.
    • 'base_simulation': uses the baseline sound simulation as input.
  • output_type (str, default='standard'): Defines the output image type:
    • 'standard': full soundmap prediction.
    • 'complex_only': difference from baseline soundmap.
  • fake_rgb_output (bool, default=False): If True, replicates single-channel inputs to fake RGB (3-channel).
  • make_14_dividable_size (bool, default=False): If True, resizes tensors so that height and width are divisible by 14.
  • reflexion_channels (bool, default=False): If True, adds ray-traced reflection channels to the input.
  • reflexion_steps (int, default=36): Number of ray traces to create.
  • reflexions_as_channels (bool, default=False): If True, every trace gets its own channel; otherwise all traces are drawn into a single channel.
Attributes:

  • fake_rgb_output
  • make_14_dividable_size
  • reflexion_channels
  • reflexion_steps
  • reflexions_as_channels
  • device
  • dataset
  • mode
  • input_type
  • output_type
  • transform
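
A minimal construction sketch (the first call downloads the dataset from Hugging Face; the printed shapes depend on the chosen configuration):

    dataset = PhysGenDataset(variation="sound_reflection", mode="train",
                             input_type="osm", output_type="complex_only")
    input_img, target_img = dataset[0]
    print(len(dataset), input_img.shape, target_img.shape)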
def get_dataloader(mode='train', variation='sound_reflection', input_type='osm', output_type='complex_only', shuffle=True):

Create a PyTorch DataLoader for the PhysGen dataset.

This helper simplifies loading the PhysGen dataset for training, validation, or testing.

Parameter:

  • mode (str, default='train'): Dataset split to use ('train', 'test', 'validation').
  • variation (str, default='sound_reflection'): Dataset variation to load.
  • input_type (str, default='osm'): Defines the input type ('osm' or 'base_simulation').
  • output_type (str, default='complex_only'): Defines the output type ('standard' or 'complex_only').
  • shuffle (bool, default=True): Whether to shuffle the dataset between epochs.

Returns:

  • torch.utils.data.DataLoader: DataLoader that provides batches of PhysGen samples.
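
For example (this helper fixes batch_size=1, so batches arrive as (1, C, H, W) tensors):

    loader = get_dataloader(mode="validation", variation="sound_reflection", shuffle=False)
    for input_batch, target_batch in loader:
        print(input_batch.shape, target_batch.shape)
        break  # inspect only the first batch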
def get_image(mode='train', variation='sound_reflection', input_type='osm', output_type='complex_only', shuffle=True, return_output=False, as_numpy_array=True):

Retrieve one image (input and optionally output) from the PhysGen dataset.

Provides an easy way to visualize or inspect a single PhysGen sample without manually instantiating a DataLoader.

Parameter:

  • mode (str, default='train'): Dataset split ('train', 'test', 'validation').
  • variation (str, default='sound_reflection'): Dataset variation.
  • input_type (str, default='osm'): Defines the input type ('osm' or 'base_simulation').
  • output_type (str, default='complex_only'): Defines the output type ('standard' or 'complex_only').
  • shuffle (bool, default=True): Randomly select the sample.
  • return_output (bool, default=False): If True, returns both input and target.
  • as_numpy_array (bool, default=True): If True, converts tensors to NumPy arrays for easier visualization.

Returns:

  • numpy.ndarray or list[numpy.ndarray]: Input image as NumPy array, or a list [input, target] if return_output is True.
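
A visualization sketch; matplotlib is not a dependency of this module and is only assumed here for display:

    import matplotlib.pyplot as plt

    input_img, target_img = get_image(mode="train", variation="sound_reflection",
                                      return_output=True, as_numpy_array=True)

    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(input_img, cmap="gray")
    ax1.set_title("OSM input")
    ax2.imshow(target_img, cmap="gray")
    ax2.set_title("Target")
    plt.show()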
def save_dataset(output_real_path, output_osm_path, variation, input_type, output_type, data_mode, info_print=False, progress_print=True):

Save PhysGen dataset samples as images to disk.

This function loads the specified PhysGen dataset, converts input and target tensors to images, and saves them as .png files for inspection, debugging, or model-agnostic data use.

Parameter:

  • output_real_path (str): Directory to save target (real) soundmaps.
  • output_osm_path (str): Directory to save input (OSM) maps.
  • variation (str): Dataset variation (e.g. 'sound_reflection').
  • input_type (str): Input type ('osm' or 'base_simulation').
  • output_type (str): Output type ('standard' or 'complex_only').
  • data_mode (str): Dataset split ('train', 'test', 'validation').
  • info_print (bool, default=False): If True, prints detailed information for each saved sample.
  • progress_print (bool, default=True): If True, shows progress updates in the console.

Raises:

  • ValueError: If image data falls outside the valid range [0, 255].
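
Typical invocation from Python (note that both output directories are cleared if they already exist):

    save_dataset(output_real_path="./real", output_osm_path="./osm",
                 variation="sound_reflection", input_type="osm",
                 output_type="standard", data_mode="train")

The same export is available from the command line via the module's __main__ block, e.g.:

    python -m image_to_image.data.physgen --output_real_path ./real --output_osm_path ./osm \
        --variation sound_reflection --input_type osm --output_type standard --data_mode train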