image_to_image.model_interactions.inference
Module for running inference with image-to-image models.
Functions:
- load_image
- save_output
- inference
By Tobia Ippolito
"""
Module for running inference with image-to-image models.

Functions:
- load_image
- save_output
- inference

By Tobia Ippolito
"""
# ---------------------------
# > Imports <
# ---------------------------
import os

import cv2
from PIL import Image

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from ..utils.argument_parsing import parse_args
from ..utils.model_io import load_and_get_model

from ..data.physgen import PhysGenDataset
from ..data.residual_physgen import PhysGenResidualDataset

from ..models.resfcn import ResFCN



# ---------------------------
# > Inference Helper <
# ---------------------------
def load_image(path, width=256, height=256, grayscale=True):
    """
    Loads an image from disk, resizes it, converts to tensor, and adds a batch dimension.

    Parameters:
    - path (str):
        Path to the image file.
    - width (int):
        Desired image width after resizing (default=256).
    - height (int):
        Desired image height after resizing (default=256).
    - grayscale (bool):
        Whether to load the image as grayscale (default=True).
        Color images are converted from OpenCV's BGR order to RGB.

    Returns:
    - torch.Tensor: Image tensor with shape [1, C, H, W], ready for model input.

    Raises:
    - FileNotFoundError: If the file is missing or cannot be decoded as an image.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    if not grayscale:
        # OpenCV decodes color images as BGR; PIL/torchvision expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    transform = transforms.Compose([
        # torchvision's Resize expects the size as (height, width).
        transforms.Resize((height, width)),
        transforms.ToTensor(),
    ])
    img = Image.fromarray(img)
    return transform(img).unsqueeze(0)  # Add batch dimension


def save_output(tensor, path):
    """
    Saves a tensor as an image file to disk.

    Parameters:
    - tensor (torch.Tensor):
        Image tensor to save. Expected shape [B, C, H, W] or [C, H, W].
    - path (str):
        Path where the image will be saved.
    """
    from torchvision.utils import save_image
    save_image(tensor, path)



# ---------------------------
# > Inference <
# ---------------------------
def inference(args=None):
    """
    Runs inference with a pre-trained image-to-image model on the PhysGen test
    dataset or on a directory of custom images.

    The function:
    - Builds the PhysGen test dataset when no valid custom image directory is given.
    - Loads the model checkpoint via `load_and_get_model` and switches it to eval mode.
    - Runs inference without gradient tracking.
    - Saves the outputs to `args.output_dir`.

    Parameters:
    - args (Namespace or None): Optional argument namespace, typically from argparse.
      If None, `parse_args()` is called. Fields read from `args`:
        - device (str): Requested device ("cuda" or "cpu"); "cpu" is used whenever
          CUDA is unavailable.
        - image_dir_path (str or None): Directory of custom .jpg/.jpeg/.png images.
          If None or non-existent, the test dataset is used instead.
        - output_dir (str): Output directory (created if missing).
        - model (str): Model name; "residual_design_model" (case-insensitive) selects
          PhysGenResidualDataset, anything else selects PhysGenDataset.
        - model_params_path (str): Path to the saved model checkpoint.
        - batch_size (int): Batch size of the test DataLoader.
        - data_variation (str), fake_rgb_output (bool), make_14_dividable_size (bool),
          reflexion_channels, reflexion_steps, reflexions_as_channels:
          Forwarded to the dataset constructor.
        - input_type (str), output_type (str): Forwarded to PhysGenDataset only.
      Architecture fields such as resfcn_* are not read here; presumably the
      checkpoint loader restores them (confirm in `load_and_get_model`).

    Outputs:
    - Custom images: "pred_<filename>" for every input image.
    - Dataset: for batch index `idx`, "buildings_<idx>_real_B.png" (target),
      "buildings_<idx>_fake_B.png" (prediction) and "buildings_<idx>_real_A.png" (input).

    Returns:
    - None: The function saves predicted images to disk.
    """
    if args is None:
        args = parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    custom_images = args.image_dir_path is not None and os.path.exists(args.image_dir_path)

    # Data Loading (only needed when no custom image directory is used)
    if not custom_images:
        dataset_kwargs = dict(
            variation=args.data_variation,
            mode="test",
            fake_rgb_output=args.fake_rgb_output,
            make_14_dividable_size=args.make_14_dividable_size,
            reflexion_channels=args.reflexion_channels,
            reflexion_steps=args.reflexion_steps,
            reflexions_as_channels=args.reflexions_as_channels,
        )
        if args.model.lower() == "residual_design_model":
            test_dataset = PhysGenResidualDataset(**dataset_kwargs)
        else:
            test_dataset = PhysGenDataset(input_type=args.input_type,
                                          output_type=args.output_type,
                                          **dataset_kwargs)

        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Model Loading
    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)
    # Disable dropout and use running BatchNorm statistics during inference.
    model.eval()

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Run inference; gradients are never needed here.
    with torch.no_grad():
        if custom_images:
            # os.listdir order is arbitrary; sort for a reproducible order.
            filenames = sorted(os.listdir(args.image_dir_path))
            for filename in tqdm(filenames, desc="Inference", leave=False):
                if not filename.lower().endswith(('.jpg', '.png', '.jpeg')):
                    continue
                x = load_image(os.path.join(args.image_dir_path, filename)).to(device)
                y_pred = model(x)

                pred_path = os.path.join(output_dir, "pred_" + filename)
                save_output(y_pred, pred_path)
                # Path is precomputed: reusing the same quote type inside an
                # f-string replacement field is a SyntaxError before Python 3.12.
                tqdm.write(f"Saved predictions to {pred_path}")
        else:
            for idx, (x, y) in enumerate(tqdm(test_loader, desc="Inference", leave=False)):
                x, y = x.to(device), y.to(device)
                y_pred = model(x)

                real_b_path = os.path.join(output_dir, f"buildings_{idx}_real_B.png")
                fake_b_path = os.path.join(output_dir, f"buildings_{idx}_fake_B.png")
                real_a_path = os.path.join(output_dir, f"buildings_{idx}_real_A.png")

                save_output(y, real_b_path)
                save_output(y_pred, fake_b_path)
                # NOTE(review): x was already moved with `.to(device)` above, so a
                # list/tuple input would fail before reaching this check.
                if isinstance(x, (list, tuple)):
                    x = x[0]
                save_output(x, real_a_path)

                tqdm.write(f"[{idx}] Saved predictions to {fake_b_path}")



if __name__ == "__main__":
    inference()
def load_image(path, width=256, height=256, grayscale=True):
def load_image(path, width=256, height=256, grayscale=True):
    """
    Loads an image from disk, resizes it, converts to tensor, and adds a batch dimension.

    Parameters:
    - path (str):
        Path to the image file.
    - width (int):
        Desired image width after resizing (default=256).
    - height (int):
        Desired image height after resizing (default=256).
    - grayscale (bool):
        Whether to load the image as grayscale (default=True).
        Color images are converted from OpenCV's BGR order to RGB.

    Returns:
    - torch.Tensor: Image tensor with shape [1, C, H, W], ready for model input.

    Raises:
    - FileNotFoundError: If the file is missing or cannot be decoded as an image.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    if not grayscale:
        # OpenCV decodes color images as BGR; PIL/torchvision expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    transform = transforms.Compose([
        # torchvision's Resize expects the size as (height, width).
        transforms.Resize((height, width)),
        transforms.ToTensor(),
    ])
    img = Image.fromarray(img)
    return transform(img).unsqueeze(0)  # Add batch dimension
Loads an image from disk, resizes it, converts to tensor, and adds a batch dimension.
Parameters:
- path (str): Path to the image file.
- width (int): Desired image width after resizing (default=256).
- height (int): Desired image height after resizing (default=256).
- grayscale (bool): Whether to load the image as grayscale (default=True).
Returns:
- torch.Tensor: Image tensor with shape [1, C, H, W], ready for model input.
def save_output(tensor, path):
def save_output(tensor, path):
    """
    Write an image tensor to disk using torchvision's `save_image`.

    Parameters:
    - tensor (torch.Tensor):
        Image tensor to write. Expected shape [B, C, H, W] or [C, H, W].
    - path (str):
        Destination file path for the image.
    """
    # Imported lazily so torchvision.utils is only loaded when something is saved.
    from torchvision import utils as tv_utils

    tv_utils.save_image(tensor, path)
Saves a tensor as an image file to disk.
Parameters:
- tensor (torch.Tensor): Image tensor to save. Expected shape [B, C, H, W] or [C, H, W].
- path (str): Path where the image will be saved.
def inference(args=None):
def inference(args=None):
    """
    Runs inference with a pre-trained image-to-image model on the PhysGen test
    dataset or on a directory of custom images.

    The function:
    - Builds the PhysGen test dataset when no valid custom image directory is given.
    - Loads the model checkpoint via `load_and_get_model` and switches it to eval mode.
    - Runs inference without gradient tracking.
    - Saves the outputs to `args.output_dir`.

    Parameters:
    - args (Namespace or None): Optional argument namespace, typically from argparse.
      If None, `parse_args()` is called. Fields read from `args`:
        - device (str): Requested device ("cuda" or "cpu"); "cpu" is used whenever
          CUDA is unavailable.
        - image_dir_path (str or None): Directory of custom .jpg/.jpeg/.png images.
          If None or non-existent, the test dataset is used instead.
        - output_dir (str): Output directory (created if missing).
        - model (str): Model name; "residual_design_model" (case-insensitive) selects
          PhysGenResidualDataset, anything else selects PhysGenDataset.
        - model_params_path (str): Path to the saved model checkpoint.
        - batch_size (int): Batch size of the test DataLoader.
        - data_variation (str), fake_rgb_output (bool), make_14_dividable_size (bool),
          reflexion_channels, reflexion_steps, reflexions_as_channels:
          Forwarded to the dataset constructor.
        - input_type (str), output_type (str): Forwarded to PhysGenDataset only.
      Architecture fields such as resfcn_* are not read here; presumably the
      checkpoint loader restores them (confirm in `load_and_get_model`).

    Outputs:
    - Custom images: "pred_<filename>" for every input image.
    - Dataset: for batch index `idx`, "buildings_<idx>_real_B.png" (target),
      "buildings_<idx>_fake_B.png" (prediction) and "buildings_<idx>_real_A.png" (input).

    Returns:
    - None: The function saves predicted images to disk.
    """
    if args is None:
        args = parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    custom_images = args.image_dir_path is not None and os.path.exists(args.image_dir_path)

    # Data Loading (only needed when no custom image directory is used)
    if not custom_images:
        dataset_kwargs = dict(
            variation=args.data_variation,
            mode="test",
            fake_rgb_output=args.fake_rgb_output,
            make_14_dividable_size=args.make_14_dividable_size,
            reflexion_channels=args.reflexion_channels,
            reflexion_steps=args.reflexion_steps,
            reflexions_as_channels=args.reflexions_as_channels,
        )
        if args.model.lower() == "residual_design_model":
            test_dataset = PhysGenResidualDataset(**dataset_kwargs)
        else:
            test_dataset = PhysGenDataset(input_type=args.input_type,
                                          output_type=args.output_type,
                                          **dataset_kwargs)

        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Model Loading
    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)
    # Disable dropout and use running BatchNorm statistics during inference.
    model.eval()

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Run inference; gradients are never needed here.
    with torch.no_grad():
        if custom_images:
            # os.listdir order is arbitrary; sort for a reproducible order.
            filenames = sorted(os.listdir(args.image_dir_path))
            for filename in tqdm(filenames, desc="Inference", leave=False):
                if not filename.lower().endswith(('.jpg', '.png', '.jpeg')):
                    continue
                x = load_image(os.path.join(args.image_dir_path, filename)).to(device)
                y_pred = model(x)

                pred_path = os.path.join(output_dir, "pred_" + filename)
                save_output(y_pred, pred_path)
                # Path is precomputed: reusing the same quote type inside an
                # f-string replacement field is a SyntaxError before Python 3.12.
                tqdm.write(f"Saved predictions to {pred_path}")
        else:
            for idx, (x, y) in enumerate(tqdm(test_loader, desc="Inference", leave=False)):
                x, y = x.to(device), y.to(device)
                y_pred = model(x)

                real_b_path = os.path.join(output_dir, f"buildings_{idx}_real_B.png")
                fake_b_path = os.path.join(output_dir, f"buildings_{idx}_fake_B.png")
                real_a_path = os.path.join(output_dir, f"buildings_{idx}_real_A.png")

                save_output(y, real_b_path)
                save_output(y_pred, fake_b_path)
                # NOTE(review): x was already moved with `.to(device)` above, so a
                # list/tuple input would fail before reaching this check.
                if isinstance(x, (list, tuple)):
                    x = x[0]
                save_output(x, real_a_path)

                tqdm.write(f"[{idx}] Saved predictions to {fake_b_path}")
Runs inference using a pre-trained image-to-image model (loaded via `load_and_get_model`, e.g. a ResFCN) on the PhysGen test dataset or a custom image directory.
The function supports:
- Loading a test dataset if no custom images are provided.
- Loading a pre-trained model checkpoint.
- Running inference on either dataset or custom images.
- Saving predicted outputs to the specified output directory.
Parameters:
- args (Namespace or None): Optional argument namespace, typically from argparse.
Required fields in `args`:
- device (str): Device for computation ("cuda" or "cpu").
- image_dir_path (str or None): Path to custom images (optional).
- output_dir (str): Directory where outputs will be saved.
- data_variation (str): Dataset variation if using dataset.
- input_type (str): Input type for dataset.
- output_type (str): Output type for dataset.
- fake_rgb_output (bool): Flag for dataset processing.
- make_14_dividable_size (bool): Flag for resizing dataset images.
- resfcn_in_channels (int): Number of input channels for the ResFCN model.
- resfcn_hidden_channels (int): Number of hidden channels for ResFCN.
- resfcn_out_channels (int): Number of output channels for ResFCN.
- resfcn_num_blocks (int): Number of residual blocks in ResFCN.
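- model_params_path (str): Path to the saved model checkpoint.
- model (str): Model name; "residual_design_model" selects the residual PhysGen dataset.
- batch_size (int): Batch size of the test data loader.
- reflexion_channels, reflexion_steps, reflexions_as_channels: Forwarded to the dataset constructor.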
Returns:
- None: The function saves predicted images to disk.