# diffcr-models / UnCRtainTS / model / test_reconstruct.py
# Uploaded by XavierJiezou via huggingface_hub (commit 3c8ff2e, verified)
"""
Script for image reconstruction inference with pre-trained models
Author: Patrick Ebel (github/PatrickTUM), based on the scripts of
Vivien Sainte Fare Garnot (github/VSainteuf)
License: MIT
"""
import os
import sys
import json
import pprint
import argparse
from parse_args import create_parser
import torch
dirname = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(dirname))
from src import utils
from src.model_utils import get_model, load_checkpoint
from train_reconstruct import (
iterate,
save_results,
prepare_output,
import_from_path,
seed_packages,
)
from data.dataLoader import SEN12MSCR, get_pairedS1
from torch.utils.tensorboard import SummaryWriter
parser = create_parser(mode="test")
test_config = parser.parse_args()

# record the PID so server-side process management can look it up in the logged config
test_config.pid = os.getpid()

# related to flag --use_custom:
# hard-coded list of target S2 patches (mosaiced into a single sample); the
# matching target S1 patches and the input time series are fetched from these
# (TODO: keeping this hard-coded until a more convenient way to pass it as an argument comes about ...)
targ_s2 = [
    f"ROIs1868/73/S2/14/s2_ROIs1868_73_ImgNo_14_2018-06-21_patch_{idx}.tif"
    for idx in [171, 172, 173, 187, 188, 189, 203, 204, 205]
]

# load previous config from the training directories:
# prefer an explicitly passed config file, otherwise try the default location
if test_config.load_config:
    conf_path = test_config.load_config
else:
    conf_path = os.path.join(
        dirname, test_config.weight_folder, test_config.experiment_name, "conf.json"
    )
# if a training-time config exists at conf_path, load it and merge it with the
# test-time flags; otherwise fall back to the flags passed on the command line
if os.path.isfile(conf_path):
    with open(conf_path) as file:
        # json.load reads straight from the file handle (was json.loads(file.read()))
        model_config = json.load(file)
    t_args = argparse.Namespace()
    # do not overwrite the following flags by their respective values in the config file
    no_overwrite = [
        "pid",
        "device",
        "resume_at",
        "trained_checkp",
        "res_dir",
        "weight_folder",
        "root1",
        "root2",
        "root3",
        "max_samples_count",
        "batch_size",
        "display_step",
        "plot_every",
        "export_every",
        "input_t",
        "region",
        "min_cov",
        "max_cov",
    ]
    # take every value from the training config except the protected flags ...
    conf_dict = {
        key: val for key, val in model_config.items() if key not in no_overwrite
    }
    # ... which keep their current (test-time) values
    for key, val in vars(test_config).items():
        if key in no_overwrite:
            conf_dict[key] = val
    t_args.__dict__.update(conf_dict)
    # re-parse so CLI flags still take precedence over the merged namespace
    config = parser.parse_args(namespace=t_args)
else:
    config = test_config  # otherwise, keep passed flags without any overwriting

# convert stringified list flags back into actual lists
config = utils.str2list(config, ["encoder_widths", "decoder_widths", "out_conv"])

if config.pretrain:
    config.batch_size = 32

# persist the effective config next to the results
# (exist_ok=True: do not fail if the experiment directory already exists)
experime_dir = os.path.join(config.res_dir, config.experiment_name)
os.makedirs(experime_dir, exist_ok=True)
with open(os.path.join(experime_dir, "conf.json"), "w") as file:
    file.write(json.dumps(vars(config), indent=4))

# seed everything
seed_packages(config.rdm_seed)
if __name__ == "__main__":
    pprint.pprint(config)

    # instantiate tensorboard logger
    writer = SummaryWriter(os.path.join(config.res_dir, config.experiment_name))

    if config.use_custom:
        print("Testing on custom data samples")
        # define a dictionary for the custom sample, with customized ROI and time points:
        # three input time points of paired S1/S2 patches plus the mosaiced targets
        s1_inputs = [
            get_pairedS1(targ_s2, config.root1, mod="s1", time=t) for t in range(3)
        ]
        s2_inputs = [
            get_pairedS1(targ_s2, config.root1, mod="s2", time=t) for t in range(3)
        ]
        custom = [
            {
                "input": {"S1": s1_inputs, "S2": s2_inputs},
                "target": {
                    "S1": [get_pairedS1(targ_s2, config.root1, mod="s1")],
                    "S2": [targ_s2],
                },
            }
        ]
def main(config):
    """Run image-reconstruction inference with a pre-trained model.

    Builds the model described by `config`, restores its checkpoint weights,
    iterates once over the test split, prints the image metrics and saves them
    under the experiment's results directory.

    NOTE(review): relies on the module-level `writer` (SummaryWriter) being
    defined before this is called — confirm when importing from elsewhere.
    """
    device = torch.device(config.device)

    prepare_output(config)
    model = get_model(config)
    model = model.to(device)
    config.N_params = utils.get_ntrainparams(model)
    print(f"TOTAL TRAINABLE PARAMETERS: {config.N_params}\n")

    # get data loader
    if config.pretrain:
        dt_test = SEN12MSCR(
            config.root3,
            split="test",
            region=config.region,
            sample_type=config.sample_type,
        )
    else:
        # previously a silent `pass` left dt_test undefined and crashed with an
        # opaque NameError just below; fail with an explicit message instead
        raise NotImplementedError(
            "test_reconstruct.py only provides a SEN12MS-CR test loader; "
            "re-run with the --pretrain flag set"
        )

    # optionally cap the number of evaluated samples
    dt_test = torch.utils.data.Subset(
        dt_test, range(0, min(config.max_samples_count, len(dt_test)))
    )
    test_loader = torch.utils.data.DataLoader(
        dt_test, batch_size=config.batch_size, shuffle=False
    )

    # load weights (optionally from a specific epoch's checkpoint)
    ckpt_n = f"_epoch_{config.resume_at}" if config.resume_at > 0 else ""
    load_checkpoint(config, config.weight_folder, model, f"model{ckpt_n}")

    # inference
    print("Testing . . .")
    model.eval()
    _, test_img_metrics = iterate(
        model,
        data_loader=test_loader,
        config=config,
        writer=writer,
        mode="test",
        epoch=1,
        device=device,
    )
    print(f"\nTest image metrics: {test_img_metrics}")
    save_results(
        test_img_metrics,
        os.path.join(config.res_dir, config.experiment_name),
        split="test",
    )
    print(
        f"\nLogged test metrics to path {os.path.join(config.res_dir, config.experiment_name)}"
    )
if __name__ == "__main__":
    main(config)
    # use sys.exit(): the bare `exit()` helper is injected by the `site` module
    # and is not guaranteed to exist (e.g. under `python -S` or in frozen apps)
    sys.exit()