# fptvton1/test.py
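"""Inference script for a pose-conditioned try-on / inpainting model built on
latent diffusion (ldm). It loads an OmegaConf config and a checkpoint, runs
DDIM sampling over the test dataloader, and writes two result sets per batch:

  - results/Unpaired_Direst:        the decoded samples
  - results/Unpaired_Concatenation: samples composited with the ground truth
                                    via the inpainting mask

Example invocation (paths are the argparse defaults below):
    python test.py -b configs/test.yaml -c ./model.ckpt -s 42 -d 64
"""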
import os
import torch
import argparse
import torchvision
import pytorch_lightning
import numpy as np
from PIL import Image
from torch import autocast
from einops import rearrange
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
def un_norm(x):
    """Map a tensor from [-1, 1] back to [0, 1]."""
    return (x + 1.0) / 2.0
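
# The constants in un_norm_clip are the CLIP image-preprocessing mean/std
# (mean = (0.48145466, 0.4578275, 0.40821073),
#  std  = (0.26862954, 0.26130258, 0.27577711)); the function inverts a
# torchvision.transforms.Normalize applied with those values.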
def un_norm_clip(x):
    """Undo CLIP normalization in place on a (3, H, W) tensor."""
    x[0, :, :] = x[0, :, :] * 0.26862954 + 0.48145466
    x[1, :, :] = x[1, :, :] * 0.26130258 + 0.4578275
    x[2, :, :] = x[2, :, :] * 0.27577711 + 0.40821073
    return x
class DataModuleFromConfig(pytorch_lightning.LightningDataModule):
    """Thin LightningDataModule wrapping a single test DataLoader."""

    def __init__(self,
                 batch_size,
                 test=None,                  # dataset config, instantiated below
                 wrap=False,
                 shuffle=False,
                 shuffle_test_loader=False,  # accepted for config compatibility; unused
                 use_worker_init_fn=False):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        self.wrap = wrap
        self.datasets = instantiate_from_config(test)
        self.dataloader = torch.utils.data.DataLoader(self.datasets,
                                                      batch_size=self.batch_size,
                                                      num_workers=self.num_workers,
                                                      shuffle=shuffle,
                                                      worker_init_fn=None)
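
# A minimal sketch of what configs/test.yaml is expected to look like, assuming
# the usual ldm `instantiate_from_config` convention (`target` is a dotted
# import path, `params` its constructor kwargs). The targets below are
# placeholders, not values confirmed by this repo:
#
#   model:
#     target: <ldm model class, e.g. ldm.models.diffusion.ddpm.LatentDiffusion>
#     params: {...}
#   data:
#     target: test.DataModuleFromConfig
#     params:
#       batch_size: 1
#       test:
#         target: <your.dataset.Class>
#         params: {...}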
if __name__ == "__main__":
    # =============================================================
    # Parse command-line options
    # =============================================================
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--base", type=str, default="configs/test.yaml", help="path to the OmegaConf config")
    parser.add_argument("-c", "--ckpt", type=str, default="./model.ckpt", help="path to the model checkpoint")
    parser.add_argument("-s", "--seed", type=int, default=42, help="global random seed")
    parser.add_argument("-d", "--ddim", type=int, default=64, help="number of DDIM sampling steps")
    opt = parser.parse_args()
    # =============================================================
    # Set the random seed
    # =============================================================
seed_everything(opt.seed)
    # =============================================================
    # Load the config
    # =============================================================
    config = OmegaConf.load(opt.base)
    # =============================================================
    # Build the test dataloader
    # =============================================================
    data = instantiate_from_config(config.data)
    print(f"{data.__class__.__name__}, {len(data.dataloader)} batches")
    # =============================================================
    # Load the model
    # =============================================================
    model = instantiate_from_config(config.model)
    # strict=False tolerates missing/unexpected keys in the checkpoint
    model.load_state_dict(torch.load(opt.ckpt, map_location="cpu")["state_dict"], strict=False)
    # Pick the device first, then move the model once (the original called
    # model.cuda() unconditionally, which crashes on CPU-only machines)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    model.eval()
    sampler = DDIMSampler(model)
    # =============================================================
    # Set precision (autocast to fp16 inside the test loop)
    # =============================================================
    precision_scope = autocast
    # =============================================================
    # Run the test loop
    # =============================================================
    os.makedirs("results/Unpaired_Direst", exist_ok=True)
    os.makedirs("results/Unpaired_Concatenation", exist_ok=True)
with torch.no_grad():
with precision_scope("cuda"):
            for i, batch in enumerate(data.dataloader):
                # Move the batch tensors to the device in half precision
                inpaint = batch["inpaint_image"].to(torch.float16).to(device)
                reference = batch["ref_imgs"].to(torch.float16).to(device)
                mask = batch["inpaint_mask"].to(torch.float16).to(device)
                hint = batch["hint"].to(torch.float16).to(device)
                truth = batch["GT"].to(torch.float16).to(device)
                # Encode the inpaint image to latents and resize the mask to match
                encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
                z_inpaint = model.scale_factor * encoder_posterior_inpaint.sample().detach()
                mask_resize = torchvision.transforms.Resize([z_inpaint.shape[-2], z_inpaint.shape[-1]])(mask)
                test_model_kwargs = {}
                test_model_kwargs['inpaint_image'] = z_inpaint
                test_model_kwargs['inpaint_mask'] = mask_resize
                shape = (model.channels, model.image_size, model.image_size)
                # DDIM sampling conditioned on the reference image and pose hint;
                # the sampler returns the latents plus intermediates (discarded)
                samples, _ = sampler.sample(S=opt.ddim,
                                            batch_size=1,
                                            shape=shape,
                                            pose=hint,
                                            conditioning=reference,
                                            verbose=False,
                                            eta=0,
                                            test_model_kwargs=test_model_kwargs)
                # Decode the latents back to image space and map to [0, 1]
                # (the numpy round-trip left over from the removed safety
                # checker is dropped; the result is identical)
                samples = samples / model.scale_factor
                x_samples = model.first_stage_model.decode(samples[:, :4, :, :])
                x_checked_image_torch = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                # Composite the prediction with the ground truth using the
                # inpaint mask, then resize both outputs to 512x384
                all_img = []
                all_img_C = []
                # all_img.append(un_norm(truth[0]).cpu())
                # all_img.append(un_norm(inpaint[0]).cpu())
                # all_img.append(un_norm_clip(torchvision.transforms.Resize([512, 512])(reference)[0].cpu()))
                mask = mask.cpu()
                truth = torch.clamp((truth + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                x_checked_image_torch_C = x_checked_image_torch * (1 - mask) + truth * mask
                x_checked_image_torch = torch.nn.functional.interpolate(x_checked_image_torch.float(), size=[512, 384])
                x_checked_image_torch_C = torch.nn.functional.interpolate(x_checked_image_torch_C.float(), size=[512, 384])
all_img.append(x_checked_image_torch[0])
all_img_C.append(x_checked_image_torch_C[0])
                # Save the direct and the composited results as PNG grids
                grid = torch.stack(all_img, 0)
                grid = torchvision.utils.make_grid(grid)
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                Image.fromarray(grid.astype(np.uint8)).save(f"results/Unpaired_Direst/{i}.png")
                grid_C = torch.stack(all_img_C, 0)
                grid_C = torchvision.utils.make_grid(grid_C)
                grid_C = 255. * rearrange(grid_C, 'c h w -> h w c').cpu().numpy()
                Image.fromarray(grid_C.astype(np.uint8)).save(f"results/Unpaired_Concatenation/{i}.png")