image_to_image.data.physgen
PhysGen Dataset Loader
PyTorch DataLoader for the PhysicsGen dataset.
Also provides helper functions for downloading and exporting the dataset.
See:
1""" 2PhysGen Dataset Loader 3 4PyTorch DataLoader. 5 6Also provide some functions for downloading the dataset. 7 8See: 9- https://huggingface.co/datasets/mspitzna/physicsgen 10- https://arxiv.org/abs/2503.05333 11- https://github.com/physicsgen/physicsgen 12""" 13# --------------------------- 14# > Imports < 15# --------------------------- 16import os 17import shutil 18from PIL import Image 19 20from datasets import load_dataset 21 22import numpy as np 23import cv2 24 25import torch 26import torch.nn.functional as F 27from torch.utils.data import DataLoader, Dataset 28# import torchvision.transforms as transforms 29from torchvision import transforms 30 31import img_phy_sim as ips 32import prime_printer as prime 33 34 35 36# --------------------------- 37# > Helper < 38# --------------------------- 39def resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor: 40 """ 41 Resize a tensor so that its height and width are divisible by 14. 42 43 This function ensures the spatial dimensions (H, W) of a given tensor 44 are compatible with architectures that require sizes divisible by 14 45 (e.g., ResNet, ResFCN). It resizes using bilinear interpolation. 46 47 Parameter: 48 - tensor (torch.Tensor): 49 Input tensor of shape (C, H, W) or (B, C, H, W). 50 51 Returns: 52 - torch.Tensor: 53 Resized tensor with dimensions divisible by 14. 54 55 Raises: 56 - ValueError: 57 If the tensor has neither 3 nor 4 dimensions. 58 """ 59 if tensor.dim() == 3: 60 c, h, w = tensor.shape 61 new_h = h - (h % 14) 62 new_w = w - (w % 14) 63 return F.interpolate(tensor.unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False).squeeze(0) 64 65 elif tensor.dim() == 4: 66 b, c, h, w = tensor.shape 67 new_h = h - (h % 14) 68 new_w = w - (w % 14) 69 return F.interpolate(tensor, size=(new_h, new_w), mode='bilinear', align_corners=False) 70 71 else: 72 raise ValueError("Tensor must be 3D (C, H, W) or 4D (B, C, H, W)") 73 74 75 76# --------------------------- 77# > Dataset < 78# --------------------------- 79class PhysGenDataset(Dataset): 80 """ 81 PyTorch Dataset wrapper for the PhysicsGen dataset. 82 83 Loads the PhysGen dataset from Hugging Face and provides configurable 84 input/output modes for physics-based generative learning tasks. 85 86 The dataset contains Open Sound Maps (OSM) and simulated soundmaps 87 for tasks involving sound propagation modeling. 88 """ 89 def __init__(self, variation="sound_baseline", mode="train", input_type="osm", output_type="standard", 90 fake_rgb_output=False, make_14_dividable_size=False, 91 reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False): 92 """ 93 Loads PhysGen Dataset. 94 95 Parameter: 96 - variation (str, default='sound_baseline'): 97 Dataset variation to load. Options include: 98 {'sound_baseline', 'sound_reflection', 'sound_diffraction', 'sound_combined'}. 99 - mode (str, default='train'): 100 Dataset split to use. Options: {'train', 'test', 'validation'}. 101 - input_type (str, default='osm'): 102 Defines the input image type: 103 - 'osm': open sound map input. 104 - 'base_simulation': uses the baseline sound simulation as input. 105 - output_type (str, default='standard'): 106 Defines the output image type: 107 - 'standard': full soundmap prediction. 108 - 'complex_only': difference from baseline soundmap. 109 - fake_rgb_output (bool, default=False): 110 If True, replicates single-channel inputs to fake RGB (3-channel). 
111 - make_14_dividable_size (bool, default=False): 112 If True, resizes tensors so that height and width are divisible by 14. 113 - reflexion_channels (bool, default=False): 114 If ray-traces should add to the input. 115 - reflexion_steps (int, default=36): 116 Defines how many traces should get created. 117 - reflexions_as_channels (bool, default=False): 118 If True, every trace gets its own channel, else every trace in one channel. 119 """ 120 self.fake_rgb_output = fake_rgb_output 121 self.make_14_dividable_size = make_14_dividable_size 122 self.reflexion_channels = reflexion_channels 123 self.reflexion_steps = reflexion_steps 124 self.reflexions_as_channels = reflexions_as_channels 125 126 self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 127 # get data 128 self.dataset = load_dataset("mspitzna/physicsgen", name=variation, trust_remote_code=True) 129 # print("Keys:", self.dataset.keys()) 130 self.dataset = self.dataset[mode] 131 self.mode = mode 132 133 self.input_type = input_type 134 self.output_type = output_type 135 if self.input_type == "base_simulation" or self.output_type == "complex_only": 136 self.basesimulation_dataset = load_dataset("mspitzna/physicsgen", name="sound_baseline", trust_remote_code=True) 137 self.basesimulation_dataset = self.basesimulation_dataset[mode] 138 139 self.transform = transforms.Compose([ 140 transforms.ToTensor(), # Converts [0,255] PIL image to [0,1] FloatTensor 141 ]) 142 print(f"PhysGen ({variation}) Dataset for {mode} got created") 143 144 def __len__(self): 145 """ 146 Returns the number of available samples. 147 148 Returns: 149 - int: 150 Number of samples in the dataset split. 151 """ 152 return len(self.dataset) 153 154 def __getitem__(self, idx): 155 """ 156 Retrieve an input-target pair from the dataset. 157 158 This function loads the input image and corresponding target image, 159 applies transformations (resizing, fake RGB, etc.), and returns 160 them as PyTorch tensors. 161 162 Parameter: 163 - idx (int): 164 Index of the data sample. 165 166 Returns: 167 - tuple[torch.Tensor, torch.Tensor]<br> 168 A tuple containing: 169 - input_img : torch.Tensor<br> 170 Input tensor (shape: [C, H, W]). 171 - target_img : torch.Tensor<br> 172 Target tensor (shape: [C, H, W]). 
173 """ 174 sample = self.dataset[idx] 175 # print(sample) 176 # print(sample.keys()) 177 if self.input_type == "base_simulation": 178 input_img = self.basesimulation_dataset[idx]["soundmap"] 179 else: 180 input_img = sample["osm"] # PIL Image 181 target_img = sample["soundmap"] # PIL Image 182 183 input_img = self.transform(input_img) 184 target_img = self.transform(target_img) 185 186 # Fix real image size 512x512 > 256x256 187 input_img = F.interpolate(input_img.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False) 188 input_img = input_img.squeeze(0) 189 # target_img = target_img.unsqueeze(0) 190 191 # change size 192 if self.make_14_dividable_size: 193 input_img = resize_tensor_to_divisible_by_14(input_img) 194 target_img = resize_tensor_to_divisible_by_14(target_img) 195 196 # add fake rgb 197 if self.fake_rgb_output and input_img.shape[0] == 1: # shape (1, H, W) 198 input_img = input_img.repeat(3, 1, 1) # make it (B, 3, H, W) 199 200 if self.output_type == "complex_only": 201 base_simulation_img = self.transform(self.basesimulation_dataset[idx]["soundmap"]) 202 # base_simulation_img = resize_tensor_to_divisible_by_14(self.transform(self.basesimulation_dataset[idx]["soundmap"])) 203 # target_img = torch.abs(target_img[0] - base_simulation_img[0]) 204 target_img = target_img[0] - base_simulation_img[0] 205 target_img = target_img.unsqueeze(0) 206 target_img *= -1 207 208 # add raytracing 209 if self.reflexion_channels: 210 ray_path = os.path.join("./rays", "train", str(self.reflexion_steps), f"rays_[{str(idx.item())}].txt") 211 if self.mode == "train" and os.path.exists(ray_path): 212 rays = ips.ray_tracing.open(path=ray_path) 213 else: 214 rays = ips.ray_tracing.trace_beams(rel_position=(0.5, 0.5), 215 img_src=np.squeeze(input_img.cpu().numpy(), axis=0), 216 directions_in_degree=ips.math.get_linear_degree_range(step_size=(self.reflexion_steps/360)*100), 217 wall_values=[0], 218 wall_thickness=0, 219 img_border_also_collide=False, 220 reflexion_order=3, 221 should_scale_rays=True, 222 should_scale_img=False) 223 ray_img = ips.ray_tracing.draw_rays(rays, 224 detail_draw=False, 225 output_format='channels' if self.reflexions_as_channels else 'single_image', 226 img_background=None, 227 ray_value=[50, 100, 255], 228 ray_thickness=1, 229 img_shape=(256, 256), 230 should_scale_rays_to_image=True, 231 show_only_reflections=True) 232 # (256, 256) 233 # print("CHECKPOINT") 234 # print(ray_img.shape) 235 ray_img = self.transform(ray_img) 236 ray_img = ray_img.float() 237 if ray_img.ndim == 2: 238 ray_img = ray_img.unsqueeze(0) # (1, H, W) 239 240 # print(ray_img.shape) 241 # print(input_img.shape) 242 # Merging with input image 243 if ray_img.shape[1:] == input_img.shape[1:]: 244 input_img = torch.cat((input_img, ray_img), dim=0) 245 else: 246 raise ValueError(f"Ray image shape {ray_img.shape} does not match input image shape {input_img.shape}.") 247 248 return input_img, target_img 249 250 251 252# --------------------------- 253# > Helpful Functions < 254# --------------------------- 255# For external not internal 256 257def get_dataloader(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True): 258 """ 259 Create a PyTorch DataLoader for the PhysGen dataset. 260 261 This helper simplifies loading the PhysGen dataset for training, 262 validation, or testing. 263 264 Parameter: 265 - mode (str, default='train'): 266 Dataset split to use ('train', 'test', 'validation'). 
267 - variation (str, default='sound_reflection'): 268 Dataset variation to load. 269 - input_type (str, default='osm'): 270 Defines the input type ('osm' or 'base_simulation'). 271 - output_type (str, default='complex_only'): 272 Defines the output type ('standard' or 'complex_only'). 273 - shuffle (bool, default=True): 274 Whether to shuffle the dataset between epochs. 275 276 Returns: 277 - torch.utils.data.DataLoader: 278 DataLoader that provides batches of PhysGen samples. 279 """ 280 dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type) 281 return DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1) 282 283 284 285def get_image(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True, 286 return_output=False, as_numpy_array=True): 287 """ 288 Retrieve one image (input and optionally output) from the PhysGen dataset. 289 290 Provides an easy way to visualize or inspect a single PhysGen sample 291 without manually instantiating a DataLoader. 292 293 Parameter: 294 - mode (str, default='train'): 295 Dataset split ('train', 'test', 'validation'). 296 variation (str, default='sound_reflection'): 297 Dataset variation. 298 input_type (str, default='osm'): 299 Defines the input type ('osm' or 'base_simulation'). 300 output_type (str, default='complex_only'): 301 Defines the output type ('standard' or 'complex_only'). 302 shuffle (bool, default=True): 303 Randomly select the sample. 304 return_output (bool, default=False): 305 If True, returns both input and target tensors. 306 as_numpy_array (bool, default=True): 307 If True, converts tensors to NumPy arrays for easier visualization. 308 309 Returns: 310 - numpy.ndarray or list[numpy.ndarray]: 311 Input image as NumPy array, or a list [input, target] if `return_output` is True. 312 """ 313 dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type) 314 loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1) 315 cur_data = next(iter(loader)) 316 input_ = cur_data[0] 317 output_ = cur_data[1] 318 319 if as_numpy_array: 320 input_ = input_.detach().cpu().numpy() 321 output_ = output_.detach().cpu().numpy() 322 323 # remove batch channel 324 input_ = np.squeeze(input_, axis=0) 325 output_ = np.squeeze(output_, axis=0) 326 327 if len(input_.shape) == 3: 328 input_ = np.squeeze(input_, axis=0) 329 output_ = np.squeeze(output_, axis=0) 330 331 input_ = np.transpose(input_, (1, 0)) 332 output_ = np.transpose(output_, (1, 0)) 333 334 335 result = input_ 336 if return_output: 337 result = [input_, output_] 338 339 return result 340 341 342 343def save_dataset(output_real_path, output_osm_path, 344 variation, input_type, output_type, 345 data_mode, 346 info_print=False, progress_print=True): 347 """ 348 Save PhysGen dataset samples as images to disk. 349 350 This function loads the specified PhysGen dataset, converts input and 351 target tensors to images, and saves them as `.png` files for inspection, 352 debugging, or model-agnostic data use. 353 354 Parameter: 355 - output_real_path (str): 356 Directory to save target (real) soundmaps. 357 - output_osm_path (str): 358 Directory to save input (OSM) maps. 359 - variation (str): 360 Dataset variation (e.g. 'sound_reflection'). 361 - input_type (str): 362 Input type ('osm' or 'base_simulation'). 363 - output_type (str): 364 Output type ('standard' or 'complex_only'). 
365 - data_mode (str): 366 Dataset split ('train', 'test', 'validation'). 367 - info_print (bool, default=False): 368 If True, prints detailed information for each saved sample. 369 - progress_print (bool, default=True): 370 If True, shows progress updates in the console. 371 372 Raises: 373 - ValueError: 374 If image data falls outside the valid range [0, 255]. 375 376 """ 377 # Clearing 378 if os.path.exists(output_osm_path) and os.path.isdir(output_osm_path): 379 shutil.rmtree(output_osm_path) 380 os.makedirs(output_osm_path) 381 print(f"Cleared {output_osm_path}.") 382 else: 383 os.makedirs(output_osm_path) 384 print(f"Created {output_osm_path}.") 385 386 if os.path.exists(output_real_path) and os.path.isdir(output_real_path): 387 shutil.rmtree(output_real_path) 388 os.makedirs(output_real_path) 389 print(f"Cleared {output_real_path}.") 390 else: 391 os.makedirs(output_real_path) 392 print(f"Created {output_real_path}.") 393 394 # Load Dataset 395 dataset = PhysGenDataset(mode=data_mode, variation=variation, input_type=input_type, output_type=output_type) 396 data_len = len(dataset) 397 dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) 398 399 # Save Dataset 400 for i, data in enumerate(dataloader): 401 if progress_print: 402 prime.get_progress_bar(total=data_len, progress=i+1, 403 should_clear=True, left_bar_char='|', right_bar_char='|', 404 progress_char='#', empty_char=' ', 405 front_message='Physgen Data Loading', back_message='', size=15) 406 407 input_img, target_img, idx = data 408 idx = idx[0].item() if isinstance(idx, torch.Tensor) else idx 409 410 if info_print: 411 print(f"Prediction shape [osm]: {input_img.shape}") 412 print(f"Prediction shape [target]: {target_img.shape}") 413 414 print(f"OSM Info:\n -> shape: {input_img.shape}\n -> min: {input_img.min()}, max: {input_img.max()}") 415 416 real_img = target_img.squeeze(0).cpu().squeeze(0).detach().numpy() 417 if not (0 <= real_img.min() <= 255 and 0 <= real_img.max() <=255): 418 raise ValueError(f"Real target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 419 if info_print: 420 print( f"\nReal target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 421 if real_img.max() <= 1.0: 422 real_img *= 255 423 if info_print: 424 print( f"Real target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 425 real_img = real_img.astype(np.uint8) 426 if info_print: 427 print( f"Real target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 428 429 if len(input_img.shape) == 4: 430 osm_img = input_img[0, 0].cpu().detach().numpy() 431 else: 432 osm_img = input_img[0].cpu().detach().numpy() 433 if not (0 <= osm_img.min() <= 255 and 0 <= osm_img.max() <=255): 434 raise ValueError(f"Real target has values out of 0-256 range => min:{osm_img.min()}, max:{osm_img.max()}") 435 if osm_img.max() <= 1.0: 436 osm_img *= 255 437 osm_img = osm_img.astype(np.uint8) 438 439 if info_print: 440 print(f"OSM Info:\n -> shape: {osm_img.shape}\n -> min: {osm_img.min()}, max: {osm_img.max()}") 441 442 # Save Results 443 file_name = f"physgen_{idx}.png" 444 445 # save pred image 446 # save_img = os.path.join(output_pred_path, file_name) 447 # cv2.imwrite(save_img, pred_img) 448 # print(f" -> saved pred at {save_img}") 449 450 # save real image 451 save_img = os.path.join(output_real_path, "target_"+file_name) 452 cv2.imwrite(save_img, real_img) 453 if info_print: 454 print(f" -> saved real at {save_img}") 455 456 # save osm 
image 457 save_img = os.path.join(output_osm_path, "input_"+file_name) 458 cv2.imwrite(save_img, osm_img) 459 if info_print: 460 print(f" -> saved osm at {save_img}") 461 print(f"\nSuccessfull saved {data_len} datapoints into {os.path.abspath(output_real_path)} & {os.path.abspath(output_osm_path)}") 462 463 464 465# --------------------------- 466# > Dataset Saving < 467# --------------------------- 468if __name__ == "__main__": 469 """ 470 Command-line interface for saving PhysGen dataset samples. 471 472 Allows users to export the PhysGen dataset as image pairs for a given 473 variation, input/output configuration, and mode. 474 475 Example 476 ------- 477 >>> python physgen_dataset_loader.py \ 478 --output_real_path ./real \ 479 --output_osm_path ./osm \ 480 --variation sound_reflection \ 481 --input_type osm \ 482 --output_type standard \ 483 --data_mode train 484 """ 485 import argparse 486 487 parser = argparse.ArgumentParser(description="Save OSM and real PhysGen dataset images.") 488 489 parser.add_argument("--output_real_path", type=str, required=True, help="Path to save real target images") 490 parser.add_argument("--output_osm_path", type=str, required=True, help="Path to save OSM input images") 491 parser.add_argument("--variation", type=str, required=True, help="PhysGen variation (e.g. box_texture, box_position, etc.)") 492 parser.add_argument("--input_type", type=str, required=True, help="Input type (e.g. osm_depth)") 493 parser.add_argument("--output_type", type=str, required=True, help="Output type (e.g. real_depth)") 494 parser.add_argument("--data_mode", type=str, required=True, help="Data Mode: train, test, val") 495 parser.add_argument("--info_print", action="store_true", help="Print additional info") 496 parser.add_argument("--no_progress", action="store_true", help="Disable progress printing") 497 498 args = parser.parse_args() 499 500 save_dataset( 501 output_real_path=args.output_real_path, 502 output_osm_path=args.output_osm_path, 503 variation=args.variation, 504 input_type=args.input_type, 505 output_type=args.output_type, 506 data_mode=args.data_mode, 507 info_print=args.info_print, 508 progress_print=not args.no_progress 509 ) 510 511
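A minimal usage sketch, assuming the module is importable as `image_to_image.data.physgen` and the Hugging Face dataset can be downloaded:

```python
# Minimal sketch: load one sample from the reflection variation.
from image_to_image.data.physgen import PhysGenDataset

dataset = PhysGenDataset(variation="sound_reflection", mode="train",
                         input_type="osm", output_type="standard")
input_img, target_img = dataset[0]
print(input_img.shape, target_img.shape)  # e.g. torch.Size([1, 256, 256]) for the input
```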
resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor
Resize a tensor so that its height and width are divisible by 14.
This function ensures the spatial dimensions (H, W) of a given tensor are compatible with architectures that require sizes divisible by 14 (e.g., ResNet, ResFCN). It resizes using bilinear interpolation.
Parameter:
- tensor (torch.Tensor): Input tensor of shape (C, H, W) or (B, C, H, W).
Returns:
- torch.Tensor: Resized tensor with dimensions divisible by 14.
Raises:
- ValueError: If the tensor has neither 3 nor 4 dimensions.
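For instance, a 256×256 input comes back as 252×252, since 256 − (256 mod 14) = 252. A quick shape check, assuming the import path from the module header:

```python
import torch
from image_to_image.data.physgen import resize_tensor_to_divisible_by_14

x3 = torch.randn(1, 256, 256)     # (C, H, W)
print(resize_tensor_to_divisible_by_14(x3).shape)  # torch.Size([1, 252, 252])

x4 = torch.randn(2, 3, 300, 200)  # (B, C, H, W)
print(resize_tensor_to_divisible_by_14(x4).shape)  # torch.Size([2, 3, 294, 196])
```

Note that this downsamples via bilinear interpolation to the nearest smaller multiple of 14 rather than cropping.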
class PhysGenDataset(Dataset)
PyTorch Dataset wrapper for the PhysicsGen dataset.
Loads the PhysGen dataset from Hugging Face and provides configurable input/output modes for physics-based generative learning tasks.
The dataset contains OpenStreetMap (OSM) building maps and simulated soundmaps for tasks involving sound propagation modeling.
PhysGenDataset.__init__(variation="sound_baseline", mode="train", input_type="osm", output_type="standard", fake_rgb_output=False, make_14_dividable_size=False, reflexion_channels=False, reflexion_steps=36, reflexions_as_channels=False)
Loads the PhysGen dataset.
Parameter:
- variation (str, default='sound_baseline'): Dataset variation to load. Options include: {'sound_baseline', 'sound_reflection', 'sound_diffraction', 'sound_combined'}.
- mode (str, default='train'): Dataset split to use. Options: {'train', 'test', 'validation'}.
- input_type (str, default='osm'):
Defines the input image type:
- 'osm': OpenStreetMap input.
- 'base_simulation': uses the baseline sound simulation as input.
- output_type (str, default='standard'):
Defines the output image type:
- 'standard': full soundmap prediction.
- 'complex_only': difference from baseline soundmap.
- fake_rgb_output (bool, default=False): If True, replicates single-channel inputs to fake RGB (3-channel).
- make_14_dividable_size (bool, default=False): If True, resizes tensors so that height and width are divisible by 14.
- reflexion_channels (bool, default=False): If True, ray-traced reflection channels are added to the input.
- reflexion_steps (int, default=36): Number of ray traces to create.
- reflexions_as_channels (bool, default=False): If True, each trace gets its own channel; otherwise all traces are drawn into a single channel.
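A hedged construction sketch for the reflection variation with `complex_only` targets and ray channels; the channel counts are indicative, and the target keeps the soundmap's native resolution:

```python
from image_to_image.data.physgen import PhysGenDataset

dataset = PhysGenDataset(variation="sound_reflection",
                         mode="train",
                         input_type="osm",
                         output_type="complex_only",
                         reflexion_channels=True,
                         reflexion_steps=36,
                         reflexions_as_channels=False)

input_img, target_img = dataset[0]
# input_img: OSM channel + one merged ray channel -> roughly (2, 256, 256)
# target_img: baseline soundmap minus the variation's soundmap -> (1, H, W)
```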
get_dataloader(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True)
Create a PyTorch DataLoader for the PhysGen dataset.
This helper simplifies loading the PhysGen dataset for training, validation, or testing.
Parameter:
- mode (str, default='train'): Dataset split to use ('train', 'test', 'validation').
- variation (str, default='sound_reflection'): Dataset variation to load.
- input_type (str, default='osm'): Defines the input type ('osm' or 'base_simulation').
- output_type (str, default='complex_only'): Defines the output type ('standard' or 'complex_only').
- shuffle (bool, default=True): Whether to shuffle the dataset between epochs.
Returns:
- torch.utils.data.DataLoader: DataLoader that provides batches of PhysGen samples.
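A minimal iteration sketch; note that this helper fixes `batch_size=1`, so batches carry a leading batch dimension:

```python
from image_to_image.data.physgen import get_dataloader

loader = get_dataloader(mode="train", variation="sound_reflection",
                        input_type="osm", output_type="complex_only")
for input_img, target_img in loader:
    print(input_img.shape, target_img.shape)  # (1, C, H, W) each
    break
```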
get_image(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True, return_output=False, as_numpy_array=True)
Retrieve one image (input and optionally output) from the PhysGen dataset.
Provides an easy way to visualize or inspect a single PhysGen sample without manually instantiating a DataLoader.
Parameter:
- mode (str, default='train'): Dataset split ('train', 'test', 'validation').
- variation (str, default='sound_reflection'): Dataset variation.
- input_type (str, default='osm'): Defines the input type ('osm' or 'base_simulation').
- output_type (str, default='complex_only'): Defines the output type ('standard' or 'complex_only').
- shuffle (bool, default=True): If True, a random sample is selected instead of the first one.
- return_output (bool, default=False): If True, returns both input and target.
- as_numpy_array (bool, default=True): If True, converts tensors to NumPy arrays for easier visualization.
Returns:
- numpy.ndarray or list[numpy.ndarray]: Input image as a NumPy array, or a list [input, target] if `return_output` is True.
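For quick visual inspection, e.g. with matplotlib (assumed to be available); with `return_output=True` the returned `[input, target]` list unpacks directly:

```python
import matplotlib.pyplot as plt
from image_to_image.data.physgen import get_image

input_img, target_img = get_image(mode="train", variation="sound_reflection",
                                  return_output=True, as_numpy_array=True)

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(input_img, cmap="gray")
ax1.set_title("OSM input")
ax2.imshow(target_img, cmap="gray")
ax2.set_title("Target")
plt.show()
```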
save_dataset(output_real_path, output_osm_path, variation, input_type, output_type, data_mode, info_print=False, progress_print=True)
Save PhysGen dataset samples as images to disk.
This function loads the specified PhysGen dataset, converts input and target tensors to images, and saves them as `.png` files for inspection, debugging, or model-agnostic data use.
Parameter:
- output_real_path (str): Directory to save target (real) soundmaps.
- output_osm_path (str): Directory to save input (OSM) maps.
- variation (str): Dataset variation (e.g. 'sound_reflection').
- input_type (str): Input type ('osm' or 'base_simulation').
- output_type (str): Output type ('standard' or 'complex_only').
- data_mode (str): Dataset split ('train', 'test', 'validation').
- info_print (bool, default=False): If True, prints detailed information for each saved sample.
- progress_print (bool, default=True): If True, shows progress updates in the console.
Raises:
- ValueError: If image data falls outside the valid range [0, 255].
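An example call mirroring the CLI entry point at the bottom of the module; the output paths are placeholders:

```python
from image_to_image.data.physgen import save_dataset

save_dataset(output_real_path="./real",
             output_osm_path="./osm",
             variation="sound_reflection",
             input_type="osm",
             output_type="standard",
             data_mode="train",
             info_print=False,
             progress_print=True)
# writes ./real/target_physgen_<i>.png and ./osm/input_physgen_<i>.png pairs
```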