image_to_image.model_interactions.inference

Module for running inference with image-to-image models.

Functions:

  • load_image
  • save_output
  • inference

By Tobia Ippolito

  1"""
  2Module for inferencing image to image models.
  3
  4Functions:
  5- load_image
  6- save_output
  7- inference
  8
  9By Tobia Ippolito
 10"""
 11# ---------------------------
 12#        > Imports <
 13# ---------------------------
 14import os
 15
 16import cv2
 17from PIL import Image
 18
 19from tqdm import tqdm
 20
 21import torch
 22from torch.utils.data import DataLoader
 23from torchvision import transforms
 24
 25from ..utils.argument_parsing import parse_args
 26from ..utils.model_io import load_and_get_model
 27
 28from ..data.physgen import PhysGenDataset
 29from ..data.residual_physgen import PhysGenResidualDataset
 30
 31from ..models.resfcn import ResFCN 
 32
 33
 34
 35# ---------------------------
 36#    > Inference Helper <
 37# ---------------------------
 38def load_image(path, width=256, height=256, grayscale=True):
 39    """
 40    Loads an image from disk, resizes it, converts to tensor, and adds a batch dimension.
 41
 42    Parameters:
 43    - path (str): 
 44        Path to the image file.
 45    - width (int): 
 46        Desired image width after resizing (default=256).
 47    - height (int): 
 48        Desired image height after resizing (default=256).
 49    - grayscale (bool): 
 50        Whether to load the image as grayscale (default=True).
 51
 52    Returns:
 53    - torch.Tensor: Image tensor with shape [1, C, H, W], ready for model input.
 54    """
 55    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR)
 56    transform = transforms.Compose([
 57        transforms.Resize((width, height)),
 58        transforms.ToTensor(),
 59    ])
 60    img = Image.fromarray(img)
 61    return transform(img).unsqueeze(0)  # Add batch dimension
 62
 63
 64def save_output(tensor, path):
 65    """
 66    Saves a tensor as an image file to disk.
 67
 68    Parameters:
 69    - tensor (torch.Tensor): 
 70        Image tensor to save. Expected shape [B, C, H, W] or [C, H, W].
 71    - path (str): 
 72        Path where the image will be saved.
 73    """
 74    from torchvision.utils import save_image
 75    save_image(tensor, path)
 76
 77
 78
 79# ---------------------------
 80#       > Inference <
 81# ---------------------------
 82def inference(args=None):
 83    """
 84    Runs inference using a pre-trained ResFCN model on a dataset or a custom image directory.
 85
 86    The function supports:
 87    - Loading a test dataset if no custom images are provided.
 88    - Loading a pre-trained model checkpoint.
 89    - Running inference on either dataset or custom images.
 90    - Saving predicted outputs to the specified output directory.
 91
 92    Parameters:
 93    - args (Namespace or None): Optional argument namespace, typically from argparse.
 94        Required fields in `args`:
 95        - device (str): Device for computation ("cuda" or "cpu").
 96        - image_dir_path (str or None): Path to custom images (optional).
 97        - output_dir (str): Directory where outputs will be saved.
 98        - data_variation (str): Dataset variation if using dataset.
 99        - input_type (str): Input type for dataset.
100        - output_type (str): Output type for dataset.
101        - fake_rgb_output (bool): Flag for dataset processing.
102        - make_14_dividable_size (bool): Flag for resizing dataset images.
103        - resfcn_in_channels (int): Number of input channels for the ResFCN model.
104        - resfcn_hidden_channels (int): Number of hidden channels for ResFCN.
105        - resfcn_out_channels (int): Number of output channels for ResFCN.
106        - resfcn_num_blocks (int): Number of residual blocks in ResFCN.
107        - model_params_path (str): Path to the saved model checkpoint.
108
109    Returns:
110    - None: The function saves predicted images to disk.
111    """
112    if args is None:
113        args = parse_args()
114
115    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
116
117    custom_images = not type(args.image_dir_path) == type(None) and os.path.exists(args.image_dir_path)
118
119    # Data Loading
120    if not custom_images:
121        if args.model.lower() == "residual_design_model":
122            test_dataset = PhysGenResidualDataset(variation=args.data_variation, mode="test", 
123                                                  fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
124                                                  reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
125        else:
126            test_dataset = PhysGenDataset(variation=args.data_variation, mode="test", input_type=args.input_type, output_type=args.output_type, 
127                                           fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
128                                           reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
129
130        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
131    
132    # Model Loading
133    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)
134
135    # input_dir = args.input_dir
136    output_dir = args.output_dir
137    os.makedirs(output_dir, exist_ok=True)
138
139    # run inference
140    if custom_images:
141        for filename in tqdm(os.listdir(args.image_dir_path), desc="Inference", leave=False):
142            if not filename.lower().endswith(('.jpg', '.png', '.jpeg')):
143                continue
144            path = os.path.join(args.image_dir_path, filename)
145            x = load_image(path).to(device)
146            with torch.no_grad():
147                y_pred = model(x)
148                save_output(y_pred, os.path.join(output_dir, "pred_"+filename))
149
150                tqdm.write(f"Saved predictions to {os.path.join(output_dir, "pred_"+filename)}")
151    else:
152        idx = 0
153        for x, y in tqdm(test_loader, desc="Inference", leave=False):
154            x, y = x.to(device), y.to(device)
155            cur_file_name = f"buildings_{idx}_real_B.png"
156            with torch.no_grad():
157                y_pred = model(x)
158                save_output(y, os.path.join(output_dir, cur_file_name))
159                save_output(y_pred, os.path.join(output_dir, cur_file_name.replace("real", "fake")))
160                if isinstance(x, (list, tuple)):
161                    x = x[0]
162                save_output(x, os.path.join(output_dir, cur_file_name.replace("real_B", "real_A")))
163                
164                tqdm.write(f"[{idx}] Saved predictions to {os.path.join(output_dir, cur_file_name.replace("real", "fake"))}")
165
166                idx += 1
167
168
169
if __name__ == "__main__":
    # Script entry point: with args=None, inference() parses its arguments from the command line.
    inference(args=None)
def load_image(path, width=256, height=256, grayscale=True):
def load_image(path, width=256, height=256, grayscale=True):
    """
    Loads an image from disk, resizes it, converts it to a tensor, and adds a batch dimension.

    Parameters:
    - path (str): 
        Path to the image file.
    - width (int): 
        Desired image width after resizing (default=256).
    - height (int): 
        Desired image height after resizing (default=256).
    - grayscale (bool): 
        Whether to load the image as grayscale (default=True).

    Returns:
    - torch.Tensor: Image tensor with shape [1, C, H, W], ready for model input.

    Raises:
    - OSError: If the file cannot be read as an image.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image from '{path}'")
    # NOTE(review): with grayscale=False OpenCV returns channels in BGR order;
    # confirm this matches the channel order the model was trained with.
    transform = transforms.Compose([
        # torchvision's Resize expects the size as (height, width).
        transforms.Resize((height, width)),
        transforms.ToTensor(),
    ])
    img = Image.fromarray(img)
    return transform(img).unsqueeze(0)  # Add batch dimension

Loads an image from disk, resizes it, converts it to a tensor, and adds a batch dimension.

Parameters:

  • path (str): Path to the image file.
  • width (int): Desired image width after resizing (default=256).
  • height (int): Desired image height after resizing (default=256).
  • grayscale (bool): Whether to load the image as grayscale (default=True).

Returns:

  • torch.Tensor: Image tensor with shape [1, C, H, W], ready for model input.
def save_output(tensor, path):
def save_output(tensor, path):
    """
    Writes a tensor to disk as an image file.

    Parameters:
    - tensor (torch.Tensor): 
        Image tensor to write. Expected shape [B, C, H, W] or [C, H, W].
        torchvision tiles a multi-image batch into a single grid image.
    - path (str): 
        Destination file path of the image.
    """
    # torchvision is imported lazily, only when an image is actually written.
    from torchvision.utils import save_image as _write_image

    _write_image(tensor, path)

Saves a tensor as an image file to disk.

Parameters:

  • tensor (torch.Tensor): Image tensor to save. Expected shape [B, C, H, W] or [C, H, W].
  • path (str): Path where the image will be saved.
def inference(args=None):
 83def inference(args=None):
 84    """
 85    Runs inference using a pre-trained ResFCN model on a dataset or a custom image directory.
 86
 87    The function supports:
 88    - Loading a test dataset if no custom images are provided.
 89    - Loading a pre-trained model checkpoint.
 90    - Running inference on either dataset or custom images.
 91    - Saving predicted outputs to the specified output directory.
 92
 93    Parameters:
 94    - args (Namespace or None): Optional argument namespace, typically from argparse.
 95        Required fields in `args`:
 96        - device (str): Device for computation ("cuda" or "cpu").
 97        - image_dir_path (str or None): Path to custom images (optional).
 98        - output_dir (str): Directory where outputs will be saved.
 99        - data_variation (str): Dataset variation if using dataset.
100        - input_type (str): Input type for dataset.
101        - output_type (str): Output type for dataset.
102        - fake_rgb_output (bool): Flag for dataset processing.
103        - make_14_dividable_size (bool): Flag for resizing dataset images.
104        - resfcn_in_channels (int): Number of input channels for the ResFCN model.
105        - resfcn_hidden_channels (int): Number of hidden channels for ResFCN.
106        - resfcn_out_channels (int): Number of output channels for ResFCN.
107        - resfcn_num_blocks (int): Number of residual blocks in ResFCN.
108        - model_params_path (str): Path to the saved model checkpoint.
109
110    Returns:
111    - None: The function saves predicted images to disk.
112    """
113    if args is None:
114        args = parse_args()
115
116    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
117
118    custom_images = not type(args.image_dir_path) == type(None) and os.path.exists(args.image_dir_path)
119
120    # Data Loading
121    if not custom_images:
122        if args.model.lower() == "residual_design_model":
123            test_dataset = PhysGenResidualDataset(variation=args.data_variation, mode="test", 
124                                                  fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
125                                                  reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
126        else:
127            test_dataset = PhysGenDataset(variation=args.data_variation, mode="test", input_type=args.input_type, output_type=args.output_type, 
128                                           fake_rgb_output=args.fake_rgb_output, make_14_dividable_size=args.make_14_dividable_size,
129                                           reflexion_channels=args.reflexion_channels, reflexion_steps=args.reflexion_steps, reflexions_as_channels=args.reflexions_as_channels)
130
131        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
132    
133    # Model Loading
134    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)
135
136    # input_dir = args.input_dir
137    output_dir = args.output_dir
138    os.makedirs(output_dir, exist_ok=True)
139
140    # run inference
141    if custom_images:
142        for filename in tqdm(os.listdir(args.image_dir_path), desc="Inference", leave=False):
143            if not filename.lower().endswith(('.jpg', '.png', '.jpeg')):
144                continue
145            path = os.path.join(args.image_dir_path, filename)
146            x = load_image(path).to(device)
147            with torch.no_grad():
148                y_pred = model(x)
149                save_output(y_pred, os.path.join(output_dir, "pred_"+filename))
150
151                tqdm.write(f"Saved predictions to {os.path.join(output_dir, "pred_"+filename)}")
152    else:
153        idx = 0
154        for x, y in tqdm(test_loader, desc="Inference", leave=False):
155            x, y = x.to(device), y.to(device)
156            cur_file_name = f"buildings_{idx}_real_B.png"
157            with torch.no_grad():
158                y_pred = model(x)
159                save_output(y, os.path.join(output_dir, cur_file_name))
160                save_output(y_pred, os.path.join(output_dir, cur_file_name.replace("real", "fake")))
161                if isinstance(x, (list, tuple)):
162                    x = x[0]
163                save_output(x, os.path.join(output_dir, cur_file_name.replace("real_B", "real_A")))
164                
165                tqdm.write(f"[{idx}] Saved predictions to {os.path.join(output_dir, cur_file_name.replace("real", "fake"))}")
166
167                idx += 1

Runs inference using a pre-trained model (e.g. ResFCN) loaded from a checkpoint, on either a dataset or a custom image directory.

The function supports:

  • Loading a test dataset if no custom images are provided.
  • Loading a pre-trained model checkpoint.
  • Running inference on either dataset or custom images.
  • Saving predicted outputs to the specified output directory.

Parameters:

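  • args (Namespace or None): Optional argument namespace, typically from argparse; if None, arguments are parsed from the command line. Fields read from args:
    • device (str): Device for computation ("cuda" or "cpu"); "cpu" is used when CUDA is unavailable.
    • image_dir_path (str or None): Path to custom images (optional); if None or the path does not exist, the test dataset is used.
    • output_dir (str): Directory where outputs will be saved.
    • model (str): Model name; "residual_design_model" selects the residual dataset (dataset mode only).
    • data_variation (str): Dataset variation if using dataset.
    • input_type (str): Input type for dataset.
    • output_type (str): Output type for dataset.
    • fake_rgb_output (bool): Flag for dataset processing.
    • make_14_dividable_size (bool): Flag for resizing dataset images.
    • reflexion_channels, reflexion_steps, reflexions_as_channels: Reflexion settings passed to the dataset.
    • batch_size (int): Batch size of the test DataLoader.
    • resfcn_in_channels (int): Number of input channels for the ResFCN model (not read by this function).
    • resfcn_hidden_channels (int): Number of hidden channels for ResFCN (not read by this function).
    • resfcn_out_channels (int): Number of output channels for ResFCN (not read by this function).
    • resfcn_num_blocks (int): Number of residual blocks in ResFCN (not read by this function).
    • model_params_path (str): Path to the saved model checkpoint.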

Returns:

  • None: The function saves predicted images to disk.