image_to_image.model_interactions.test

Module for testing image-to-image models.

Functions:

  • test

By Tobia Ippolito

 1"""
 2Module for testing image to image models.
 3
 4Functions:
 5- test
 6
 7By Tobia Ippolito
 8"""
 9# ---------------------------
10#        > Imports <
11# ---------------------------
12import torch
13from torch.utils.data import DataLoader
14from torchvision import datasets, transforms
15from torch import nn
16
17from ..utils.argument_parsing import parse_args
18from ..utils.model_io import load_and_get_model
19
20from .train import evaluate, get_loss  # reuse the evaluate function
21
22from ..data.physgen import PhysGenDataset
23from ..data.residual_physgen import PhysGenResidualDataset
24
25from ..models.resfcn import ResFCN 
26
27
28# ---------------------------
29#        > Run Test <
30# ---------------------------
def test(args=None):
    """
    Evaluate a pre-trained image-to-image model on the PhysGen test split.

    This function:
    - Parses arguments with `parse_args()` when `args` is None.
    - Builds the test dataset: `PhysGenResidualDataset` when `args.model` is
      "residual_design_model" (case-insensitive), otherwise `PhysGenDataset`.
    - Loads the model from the checkpoint at `args.model_params_path`
      via `load_and_get_model`.
    - Builds the loss criterion with `get_loss(args.loss, args)` and runs
      `evaluate` on the test set.

    Parameters:
    - args (Namespace or None): Optional argument namespace, typically from argparse.
        Fields read from `args` by this function:
        - device (str): Requested device; "cpu" is used when CUDA is unavailable.
        - model (str): Model name; selects which dataset class is used.
        - data_variation (str): Dataset variation to use for testing.
        - input_type (str): Input type (only passed to `PhysGenDataset`).
        - output_type (str): Output type (only passed to `PhysGenDataset`).
        - fake_rgb_output (bool): Flag forwarded to the dataset.
        - make_14_dividable_size (bool): Flag forwarded to the dataset.
        - reflexion_channels, reflexion_steps, reflexions_as_channels:
          Reflexion settings forwarded to the dataset.
        - batch_size (int): Batch size for the DataLoader.
        - model_params_path (str): Path to the saved model checkpoint.
        - loss (str): Loss name passed to `get_loss` (see `train.get_loss`
          for the supported values and any further fields it reads).

    Returns:
    - The test loss returned by `evaluate`. It is also printed to stdout.
    """
    if args is None:
        args = parse_args()

    # Honour the requested device only when CUDA is available; otherwise fall back to CPU.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Dataset loading: both dataset classes share these constructor arguments.
    common_dataset_kwargs = dict(
        variation=args.data_variation,
        mode="test",
        fake_rgb_output=args.fake_rgb_output,
        make_14_dividable_size=args.make_14_dividable_size,
        reflexion_channels=args.reflexion_channels,
        reflexion_steps=args.reflexion_steps,
        reflexions_as_channels=args.reflexions_as_channels,
    )
    if args.model.lower() == "residual_design_model":
        test_dataset = PhysGenResidualDataset(**common_dataset_kwargs)
    else:
        # input/output types are only passed to the plain PhysGen dataset.
        test_dataset = PhysGenDataset(
            input_type=args.input_type,
            output_type=args.output_type,
            **common_dataset_kwargs,
        )

    # Keep a fixed order so test results are reproducible.
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Model loading (the returned criterion is unused; evaluation builds its own below).
    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)

    # Criterion for evaluation
    criterion = get_loss(args.loss, args)
    print(f"Used Loss: {args.loss}")

    # Run evaluation
    test_loss = evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}")
    return test_loss
88
89
90
# Script entry point: called without arguments, test() falls back to parse_args().
if __name__ == "__main__":
    test()
def test(args=None):
def test(args=None):
    """
    Evaluate a pre-trained image-to-image model on the PhysGen test split.

    This function:
    - Parses arguments with `parse_args()` when `args` is None.
    - Builds the test dataset: `PhysGenResidualDataset` when `args.model` is
      "residual_design_model" (case-insensitive), otherwise `PhysGenDataset`.
    - Loads the model from the checkpoint at `args.model_params_path`
      via `load_and_get_model`.
    - Builds the loss criterion with `get_loss(args.loss, args)` and runs
      `evaluate` on the test set.

    Parameters:
    - args (Namespace or None): Optional argument namespace, typically from argparse.
        Fields read from `args` by this function:
        - device (str): Requested device; "cpu" is used when CUDA is unavailable.
        - model (str): Model name; selects which dataset class is used.
        - data_variation (str): Dataset variation to use for testing.
        - input_type (str): Input type (only passed to `PhysGenDataset`).
        - output_type (str): Output type (only passed to `PhysGenDataset`).
        - fake_rgb_output (bool): Flag forwarded to the dataset.
        - make_14_dividable_size (bool): Flag forwarded to the dataset.
        - reflexion_channels, reflexion_steps, reflexions_as_channels:
          Reflexion settings forwarded to the dataset.
        - batch_size (int): Batch size for the DataLoader.
        - model_params_path (str): Path to the saved model checkpoint.
        - loss (str): Loss name passed to `get_loss` (see `train.get_loss`
          for the supported values and any further fields it reads).

    Returns:
    - The test loss returned by `evaluate`. It is also printed to stdout.
    """
    if args is None:
        args = parse_args()

    # Honour the requested device only when CUDA is available; otherwise fall back to CPU.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Dataset loading: both dataset classes share these constructor arguments.
    common_dataset_kwargs = dict(
        variation=args.data_variation,
        mode="test",
        fake_rgb_output=args.fake_rgb_output,
        make_14_dividable_size=args.make_14_dividable_size,
        reflexion_channels=args.reflexion_channels,
        reflexion_steps=args.reflexion_steps,
        reflexions_as_channels=args.reflexions_as_channels,
    )
    if args.model.lower() == "residual_design_model":
        test_dataset = PhysGenResidualDataset(**common_dataset_kwargs)
    else:
        # input/output types are only passed to the plain PhysGen dataset.
        test_dataset = PhysGenDataset(
            input_type=args.input_type,
            output_type=args.output_type,
            **common_dataset_kwargs,
        )

    # Keep a fixed order so test results are reproducible.
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Model loading (the returned criterion is unused; evaluation builds its own below).
    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)

    # Criterion for evaluation
    criterion = get_loss(args.loss, args)
    print(f"Used Loss: {args.loss}")

    # Run evaluation
    test_loss = evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}")
    return test_loss

Runs evaluation of a pre-trained image-to-image model on the PhysGen test dataset.

This function:

  • Loads the test dataset using PhysGenResidualDataset when args.model is "residual_design_model", otherwise PhysGenDataset.
  • Loads the model from the checkpoint file at args.model_params_path.
  • Builds the loss criterion specified by args.loss.
  • Evaluates the model on the test dataset using that criterion.

Parameters:

  • args (Namespace or None): Optional argument namespace, typically from argparse. If None, arguments are obtained from parse_args(). Required fields in args:
    • device (str): Device for computation ("cuda" or "cpu"); falls back to "cpu" when CUDA is unavailable.
    • model (str): Model name; "residual_design_model" selects PhysGenResidualDataset, anything else PhysGenDataset.
    • data_variation (str): Dataset variation to use for testing.
    • input_type (str): Input type for the dataset (used only with PhysGenDataset).
    • output_type (str): Output type for the dataset (used only with PhysGenDataset).
    • fake_rgb_output (bool): Flag for dataset preprocessing.
    • make_14_dividable_size (bool): Flag for resizing dataset images.
    • reflexion_channels, reflexion_steps, reflexions_as_channels: Reflexion settings forwarded to the dataset.
    • batch_size (int): Batch size for DataLoader.
    • model_params_path (str): Path to the saved model checkpoint.
    • loss (str): Loss type to use for evaluation, passed to get_loss.

Returns:

  • None: Prints the test loss to stdout.