image_to_image.model_interactions.test
Module for testing image-to-image models.
Functions:
- test
By Tobia Ippolito
"""
Module for testing image to image models.

Functions:
- test

By Tobia Ippolito
"""
# ---------------------------
#        > Imports <
# ---------------------------
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn

from ..utils.argument_parsing import parse_args
from ..utils.model_io import load_and_get_model

from .train import evaluate, get_loss  # reuse the evaluate function

from ..data.physgen import PhysGenDataset
from ..data.residual_physgen import PhysGenResidualDataset

from ..models.resfcn import ResFCN


# ---------------------------
#        > Run Test <
# ---------------------------
def test(args=None):
    """
    Evaluate a pre-trained image-to-image model on the PhysGen test split.

    This function:
    - Parses command-line arguments when `args` is not given.
    - Builds the test dataset: `PhysGenResidualDataset` when `args.model`
      is "residual_design_model" (case-insensitive), otherwise
      `PhysGenDataset`.
    - Loads the model from the checkpoint at `args.model_params_path`
      via `load_and_get_model`.
    - Builds the evaluation criterion with `get_loss` and runs `evaluate`
      over the test loader.

    Parameters:
    - args (Namespace or None): Argument namespace, typically from argparse.
      If None, arguments are parsed with `parse_args()`.
      Fields read from `args`:
        - device (str): Requested device; honored only when CUDA is
          available, otherwise evaluation runs on "cpu".
        - model (str): Model name; selects which dataset class is used.
        - data_variation (str): Dataset variation to use for testing.
        - input_type (str): Input type (only used for `PhysGenDataset`).
        - output_type (str): Output type (only used for `PhysGenDataset`).
        - fake_rgb_output (bool): Passed through to the dataset.
        - make_14_dividable_size (bool): Passed through to the dataset.
        - reflexion_channels, reflexion_steps, reflexions_as_channels:
          Passed through to the dataset unchanged.
        - batch_size (int): Batch size for the test DataLoader.
        - model_params_path (str): Path to the saved model checkpoint.
        - loss (str): Loss name passed to `get_loss`.

    Returns:
    - The test loss returned by `evaluate` (also printed to stdout), so
      callers can use the result programmatically.
    """
    if args is None:
        args = parse_args()

    # Fall back to CPU whenever CUDA is unavailable, whatever was requested.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Dataset options shared by both dataset classes.
    dataset_kwargs = {
        "variation": args.data_variation,
        "mode": "test",
        "fake_rgb_output": args.fake_rgb_output,
        "make_14_dividable_size": args.make_14_dividable_size,
        "reflexion_channels": args.reflexion_channels,
        "reflexion_steps": args.reflexion_steps,
        "reflexions_as_channels": args.reflexions_as_channels,
    }
    if args.model.lower() == "residual_design_model":
        # The residual dataset is constructed without input/output types.
        test_dataset = PhysGenResidualDataset(**dataset_kwargs)
    else:
        test_dataset = PhysGenDataset(
            input_type=args.input_type,
            output_type=args.output_type,
            **dataset_kwargs,
        )

    # Evaluation needs no shuffling; keep a deterministic order.
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Only the model is needed here; the second returned value is discarded.
    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)

    # Criterion for evaluation
    criterion = get_loss(args.loss, args)
    print(f"Used Loss: {args.loss}")

    # Run evaluation
    test_loss = evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}")
    return test_loss


if __name__ == "__main__":
    test()
def test(args=None):
def test(args=None):
    """
    Evaluate a pre-trained image-to-image model on the PhysGen test split.

    This function:
    - Parses command-line arguments when `args` is not given.
    - Builds the test dataset: `PhysGenResidualDataset` when `args.model`
      is "residual_design_model" (case-insensitive), otherwise
      `PhysGenDataset`.
    - Loads the model from the checkpoint at `args.model_params_path`
      via `load_and_get_model`.
    - Builds the evaluation criterion with `get_loss` and runs `evaluate`
      over the test loader.

    Parameters:
    - args (Namespace or None): Argument namespace, typically from argparse.
      If None, arguments are parsed with `parse_args()`.
      Fields read from `args`:
        - device (str): Requested device; honored only when CUDA is
          available, otherwise evaluation runs on "cpu".
        - model (str): Model name; selects which dataset class is used.
        - data_variation (str): Dataset variation to use for testing.
        - input_type (str): Input type (only used for `PhysGenDataset`).
        - output_type (str): Output type (only used for `PhysGenDataset`).
        - fake_rgb_output (bool): Passed through to the dataset.
        - make_14_dividable_size (bool): Passed through to the dataset.
        - reflexion_channels, reflexion_steps, reflexions_as_channels:
          Passed through to the dataset unchanged.
        - batch_size (int): Batch size for the test DataLoader.
        - model_params_path (str): Path to the saved model checkpoint.
        - loss (str): Loss name passed to `get_loss`.

    Returns:
    - The test loss returned by `evaluate` (also printed to stdout), so
      callers can use the result programmatically.
    """
    if args is None:
        args = parse_args()

    # Fall back to CPU whenever CUDA is unavailable, whatever was requested.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Dataset options shared by both dataset classes.
    dataset_kwargs = {
        "variation": args.data_variation,
        "mode": "test",
        "fake_rgb_output": args.fake_rgb_output,
        "make_14_dividable_size": args.make_14_dividable_size,
        "reflexion_channels": args.reflexion_channels,
        "reflexion_steps": args.reflexion_steps,
        "reflexions_as_channels": args.reflexions_as_channels,
    }
    if args.model.lower() == "residual_design_model":
        # The residual dataset is constructed without input/output types.
        test_dataset = PhysGenResidualDataset(**dataset_kwargs)
    else:
        test_dataset = PhysGenDataset(
            input_type=args.input_type,
            output_type=args.output_type,
            **dataset_kwargs,
        )

    # Evaluation needs no shuffling; keep a deterministic order.
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Only the model is needed here; the second returned value is discarded.
    model, _ = load_and_get_model(args.model_params_path, device, criterion=None)

    # Criterion for evaluation
    criterion = get_loss(args.loss, args)
    print(f"Used Loss: {args.loss}")

    # Run evaluation
    test_loss = evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}")
    return test_loss
Runs evaluation of a pre-trained model (such as ResFCN) on a test dataset.
This function:
- Loads the test dataset using `PhysGenDataset` (or `PhysGenResidualDataset` when `args.model` is "residual_design_model").
- Loads the model and its weights from the checkpoint file at `args.model_params_path`.
- Evaluates the model on the test dataset using a specified loss criterion.
Parameters:
- args (Namespace or None): Optional argument namespace, typically from argparse.
Required fields in `args`:
- device (str): Device for computation ("cuda" or "cpu").
- data_variation (str): Dataset variation to use for testing.
- input_type (str): Input type for dataset.
- output_type (str): Output type for dataset.
- fake_rgb_output (bool): Flag for dataset preprocessing.
- make_14_dividable_size (bool): Flag for resizing dataset images.
- batch_size (int): Batch size for DataLoader.
- resfcn_in_channels (int): Number of input channels for the ResFCN model.
- resfcn_hidden_channels (int): Number of hidden channels for ResFCN.
- resfcn_out_channels (int): Number of output channels for ResFCN.
- resfcn_num_blocks (int): Number of residual blocks in ResFCN.
- model_params_path (str): Path to the saved model checkpoint.
- loss (str): Loss type to use for evaluation ("l1" or "crossentropy").
Returns:
- None: Prints the test loss to stdout.