# INR-Harmon / inference.py
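"""Evaluate a pretrained image-harmonization model on the iHarmony4 test split.

Reports MSE, fMSE, PSNR, and SSIM for each sub-dataset (HAdobe5k, HCOCO,
Hday2night, HFlickr) and over the whole test set, for both the INR-decoded
prediction and the LUT-transformed prediction.
"""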
import os
import argparse
import albumentations
from albumentations import Resize
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from model.build_model import build_model
from datasets.build_dataset import dataset_generator
from utils import misc, metrics
def parse_args():
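    """Parse command-line arguments for inference."""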
parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, default=1,
metavar='N', help='Dataloader threads.')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Override the model batch size by specifying a positive number.')
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device to run on: 'cuda' or 'cpu'.")
parser.add_argument('--save_path', type=str, default="./logs",
help='Where to save logs and checkpoints.')
parser.add_argument('--dataset_path', type=str, default=r".\iHarmony4",
help='Dataset path.')
    parser.add_argument('--base_size', type=int, default=256,
                        help='Base size: resolution of the image fed into the encoder.')
    parser.add_argument('--input_size', type=int, default=256,
                        help='Input size: resolution of the image to be generated by the decoder.')
    parser.add_argument('--INR_input_size', type=int, default=256,
                        help='INR input size: resolution of the image to be generated by the decoder. '
                             'Should match `input_size`.')
parser.add_argument('--INR_MLP_dim', type=int, default=32,
help='Number of channels for INR linear layer.')
    parser.add_argument('--LUT_dim', type=int, default=7,
                        help='Dimension of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')
parser.add_argument('--activation', type=str, default='leakyrelu_pe',
help='INR activation layer type: leakyrelu_pe, sine')
parser.add_argument('--pretrained', type=str,
default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth',
help='Pretrained weight path')
    parser.add_argument('--param_factorize_dim', type=int,
                        default=10,
                        help='Intermediate dimension for factorizing the predicted MLP parameters. '
                             'Refer to https://arxiv.org/abs/2011.12026')
parser.add_argument('--embedding_type', type=str,
default="CIPS_embed",
help='Which embedding_type to use.')
parser.add_argument('--optim', type=str,
default='adamw',
help='Which optimizer to use.')
    parser.add_argument('--INRDecode', action="store_false",
                        help='Whether to use the INR decoder (defaults to True; pass this flag to disable '
                             'it and test the baseline, https://github.com/SamsungLabs/image_harmonization).')
    parser.add_argument('--isMoreINRInput', action="store_false",
                        help='Whether to concatenate RGB and mask as extra INR input (defaults to True; '
                             'pass this flag to disable). See Section 3.4 of the paper.')
    parser.add_argument('--hr_train', action="store_true",
                        help='Whether to use high-resolution training. See Section 3.4 of the paper.')
    parser.add_argument('--isFullRes', action="store_true",
                        help='Whether to run at the original full resolution. See Section 3.4 of the paper.')
opt = parser.parse_args()
opt.save_path = misc.increment_path(os.path.join(opt.save_path, "test1"))
return opt
def inference(val_loader, model, logger, opt):
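    """Evaluate `model` on `val_loader` and log averaged metrics per sub-dataset."""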
    current_process = 10  # next progress milestone (in percent) to log
model.eval()
metric_log = {
'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
}
lut_metric_log = {
'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
}
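    # Accumulate metrics over the validation set, bucketed by sub-dataset.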
for step, batch in enumerate(val_loader):
composite_image = batch['composite_image'].to(opt.device)
real_image = batch['real_image'].to(opt.device)
mask = batch['mask'].to(opt.device)
category = batch['category']
fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)
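        # Forward pass without gradient tracking; the model returns the INR
        # reconstruction, an unused intermediate, and the LUT-transformed image.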
with torch.no_grad():
fg_content_bg_appearance_construct, _, lut_transform_image = model(
composite_image,
mask,
fg_INR_coordinates,
)
if opt.INRDecode:
pred_fg_image = fg_content_bg_appearance_construct[-1]
else:
pred_fg_image = misc.lin2img(fg_content_bg_appearance_construct,
val_loader.dataset.INR_dataset.size) if fg_content_bg_appearance_construct is not None else None
if not opt.INRDecode:
pred_harmonized_image = None
else:
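            # Blend prediction and ground truth with the mask: pixels whose mask
            # value exceeds 100/255 are foreground, the rest keep the real image.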
pred_harmonized_image = pred_fg_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.))
lut_transform_image = lut_transform_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.))
misc.visualize(real_image, composite_image, mask, pred_fg_image,
pred_harmonized_image, lut_transform_image, opt, -1, show=False,
wandb=False, isAll=True, step=step)
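        # Compute metrics on de-normalized images (misc.normalize(..., 'inv')).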
if opt.INRDecode:
mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'),
misc.normalize(real_image, opt, 'inv'), mask)
lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(misc.normalize(lut_transform_image, opt, 'inv'),
misc.normalize(real_image, opt, 'inv'), mask)
for idx in range(len(category)):
if opt.INRDecode:
metric_log[category[idx]]['Samples'] += 1
metric_log[category[idx]]['MSE'] += mse[idx]
metric_log[category[idx]]['fMSE'] += fmse[idx]
metric_log[category[idx]]['PSNR'] += psnr[idx]
metric_log[category[idx]]['SSIM'] += ssim[idx]
metric_log['All']['Samples'] += 1
metric_log['All']['MSE'] += mse[idx]
metric_log['All']['fMSE'] += fmse[idx]
metric_log['All']['PSNR'] += psnr[idx]
metric_log['All']['SSIM'] += ssim[idx]
lut_metric_log[category[idx]]['Samples'] += 1
lut_metric_log[category[idx]]['MSE'] += lut_mse[idx]
lut_metric_log[category[idx]]['fMSE'] += lut_fmse[idx]
lut_metric_log[category[idx]]['PSNR'] += lut_psnr[idx]
lut_metric_log[category[idx]]['SSIM'] += lut_ssim[idx]
lut_metric_log['All']['Samples'] += 1
lut_metric_log['All']['MSE'] += lut_mse[idx]
lut_metric_log['All']['fMSE'] += lut_fmse[idx]
lut_metric_log['All']['PSNR'] += lut_psnr[idx]
lut_metric_log['All']['SSIM'] += lut_ssim[idx]
if (step + 1) / len(val_loader) * 100 >= current_process:
            logger.info(f'Processing: {current_process}%')
current_process += 10
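    # Log averaged metrics for each sub-dataset and for the whole test set.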
logger.info('=========================')
for key in metric_log.keys():
if opt.INRDecode:
msg = f"{key}-'MSE': {metric_log[key]['MSE'] / metric_log[key]['Samples']:.2f}\n" \
f"{key}-'fMSE': {metric_log[key]['fMSE'] / metric_log[key]['Samples']:.2f}\n" \
f"{key}-'PSNR': {metric_log[key]['PSNR'] / metric_log[key]['Samples']:.2f}\n" \
f"{key}-'SSIM': {metric_log[key]['SSIM'] / metric_log[key]['Samples']:.4f}\n" \
f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"
else:
msg = f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"
logger.info(msg)
logger.info('=========================')
def main_process(opt):
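    """Build the dataloader and model, load pretrained weights, and run inference."""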
logger = misc.create_logger(os.path.join(opt.save_path, "log.txt"))
    cudnn.benchmark = True  # let cuDNN pick the fastest kernels for fixed-size inputs
valset_path = os.path.join(opt.dataset_path, "IHD_test.txt")
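    # A per-channel mean/std of 0.5 maps inputs to [-1, 1]; misc.normalize(..., 'inv')
    # undoes this before computing metrics.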
opt.transform_mean = [.5, .5, .5]
opt.transform_var = [.5, .5, .5]
torch_transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(opt.transform_mean, opt.transform_var)])
valset_alb_transform = albumentations.Compose([Resize(opt.input_size, opt.input_size)],
additional_targets={'real_image': 'image', 'object_mask': 'image'})
valset = dataset_generator(valset_path, valset_alb_transform, torch_transform, opt, mode='Val')
val_loader = DataLoader(valset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
num_workers=opt.workers, persistent_workers=True)
model = build_model(opt).to(opt.device)
logger.info(f"Load pretrained weight from {opt.pretrained}")
    # map_location keeps loading robust when the checkpoint was saved on a different device.
    load_dict = torch.load(opt.pretrained, map_location=opt.device)['model']
    # Report checkpoint keys the current model does not expect; strict=False then ignores them.
    for k in load_dict.keys():
        if k not in model.state_dict().keys():
            logger.info(f"Skip {k}")
    model.load_state_dict(load_dict, strict=False)
inference(val_loader, model, logger, opt)
if __name__ == '__main__':
opt = parse_args()
os.makedirs(opt.save_path, exist_ok=True)
main_process(opt)
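# Example (paths are the argparse defaults; adjust to your setup):
#   python inference.py --dataset_path ./iHarmony4 \
#       --pretrained ./pretrained_models/Resolution_RAW_iHarmony4.pth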