# Listing captured from a hosted "Spaces" file viewer (status shown: Running on Zero).
# File size: 7,595 bytes; revision adf1965.
import os
import torch
import argparse
import torchvision
import pytorch_lightning
import numpy as np
from PIL import Image
from torch import autocast
from einops import rearrange
from functools import partial
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
def un_norm(x):
    """Return ``(x + 1) / 2``, i.e. shift ``[-1, 1]`` values onto ``[0, 1]``.

    Works on anything that supports ``+`` and ``/`` with floats (plain
    numbers, arrays, tensors); the input is not modified.
    """
    shifted = x + 1.0
    return shifted / 2.0
def un_norm_clip(x):
    """Undo a per-channel ``(value - offset) / scale`` normalisation.

    ``x`` is indexed as ``x[channel, :, :]`` for channels 0, 1 and 2, so it
    must be a 3-D, channel-first array or tensor with at least three channels.
    The constants are presumably the CLIP image-preprocessing std/mean (going
    by the function name) -- confirm against the data pipeline.

    Unlike the previous implementation, the input is NOT modified: the result
    is written into a copy. Mutating in place meant that calling this twice on
    the same tensor silently un-normalised it twice.

    Returns:
        A new array/tensor of the same type as ``x`` holding
        ``x[c] * scale[c] + offset[c]`` for each of the first three channels.
    """
    # (scale, offset) pairs for channels 0, 1, 2.
    channel_stats = (
        (0.26862954, 0.48145466),
        (0.26130258, 0.4578275),
        (0.27577711, 0.40821073),
    )
    # torch tensors expose ``clone``; numpy arrays expose ``copy``.
    out = x.clone() if hasattr(x, "clone") else x.copy()
    for channel, (scale, offset) in enumerate(channel_stats):
        out[channel, :, :] = out[channel, :, :] * scale + offset
    return out
class DataModuleFromConfig(pytorch_lightning.LightningDataModule):
    """Data module that builds one dataset from a config and wraps it in a DataLoader.

    The dataset is created with ``instantiate_from_config(test)`` and exposed
    as ``self.datasets``; the ready-made loader is exposed as
    ``self.dataloader``.

    Notes on the parameters, as visible in this class:
      * ``num_workers`` is derived as ``batch_size * 2``.
      * ``wrap`` and ``use_worker_init_fn`` are only stored on the instance;
        the loader is always created with ``worker_init_fn=None``.
      * ``shuffle_test_loader`` is accepted but not used.
      * ``test`` defaults to ``None`` but is passed straight to
        ``instantiate_from_config`` -- presumably a real config is always
        supplied; verify against callers.
    """

    def __init__(self,
                 batch_size,  # 1
                 test=None,  # {...}
                 wrap=False,  # False
                 shuffle=False,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False):
        super().__init__()

        # Plain configuration values.
        self.batch_size = batch_size
        self.num_workers = batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        self.wrap = wrap

        # Dataset first, then the loader that iterates over it.
        self.datasets = instantiate_from_config(test)
        loader_options = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "shuffle": shuffle,
            "worker_init_fn": None,
        }
        self.dataloader = torch.utils.data.DataLoader(self.datasets, **loader_options)
def _save_image(image, path):
    """Save a single channel-first image tensor to *path*.

    The tensor is passed through ``torchvision.utils.make_grid`` (as a batch
    of one), rearranged to ``h w c``, scaled by 255 and cast to ``uint8``
    before being written with PIL.
    """
    grid = torchvision.utils.make_grid(torch.stack([image], 0))
    pixels = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    Image.fromarray(pixels.astype(np.uint8)).save(path)


def main():
    """Run DDIM sampling over the configured test loader and save the results.

    For every batch two PNGs are written, named after the batch index:
      * ``results/Unpaired_Direst/<i>.png`` -- the decoded sample;
      * ``results/Unpaired_Concatenation/<i>.png`` -- the sample composited
        with the ground truth as ``sample * (1 - mask) + truth * mask``.
    Both are resized to 512x384 before saving.
    """
    # =============================================================
    # Parse command-line options
    # =============================================================
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--base", type=str, default="configs/test.yaml")
    parser.add_argument("-c", "--ckpt", type=str, default="./model.ckpt")
    parser.add_argument("-s", "--seed", type=int, default=42)
    parser.add_argument("-d", "--ddim", type=int, default=64)
    opt = parser.parse_args()

    # =============================================================
    # Seed all RNGs
    # =============================================================
    seed_everything(opt.seed)

    # =============================================================
    # Load the config
    # =============================================================
    config = OmegaConf.load(f"{opt.base}")

    # =============================================================
    # Build the dataloader
    # =============================================================
    data = instantiate_from_config(config.data)
    print(f"{data.__class__.__name__}, {len(data.dataloader)}")

    # =============================================================
    # Build the model
    # =============================================================
    # Pick the device first and move the model exactly once. The previous
    # unconditional ``model.cuda()`` crashed on CPU-only hosts before the
    # availability check below was ever reached.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = instantiate_from_config(config.model)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(opt.ckpt, map_location="cpu")["state_dict"], strict=False)
    model.eval()
    model = model.to(device)
    sampler = DDIMSampler(model)

    # =============================================================
    # Precision
    # =============================================================
    # Half-precision inputs on CUDA (unchanged behaviour); float32 otherwise.
    # NOTE(review): whether the sampler/model actually run on CPU cannot be
    # verified from this file -- confirm before relying on the fallback.
    input_dtype = torch.float16 if device.type == "cuda" else torch.float32

    # =============================================================
    # Run the test loop
    # =============================================================
    # exist_ok=True so that re-running the script does not fail on the
    # directories created by a previous run.
    os.makedirs("results/Unpaired_Direst", exist_ok=True)
    os.makedirs("results/Unpaired_Concatenation", exist_ok=True)

    # Latent shape requested from the sampler; does not change per batch.
    shape = (model.channels, model.image_size, model.image_size)

    with torch.no_grad(), autocast(device.type):
        for i, batch in enumerate(data.dataloader):
            # Load the batch
            inpaint = batch["inpaint_image"].to(input_dtype).to(device)
            reference = batch["ref_imgs"].to(input_dtype).to(device)
            mask = batch["inpaint_mask"].to(input_dtype).to(device)
            hint = batch["hint"].to(input_dtype).to(device)
            truth = batch["GT"].to(input_dtype).to(device)

            # Encode the inpaint image and resize the mask to the latent size
            encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
            z_inpaint = model.scale_factor * (encoder_posterior_inpaint.sample()).detach()
            mask_resize = torchvision.transforms.Resize(
                [z_inpaint.shape[-2], z_inpaint.shape[-1]])(mask)
            test_model_kwargs = {
                'inpaint_image': z_inpaint,
                'inpaint_mask': mask_resize,
            }

            # Sample
            # NOTE(review): batch_size is fixed to 1 here and only element 0
            # of each batch is saved below, so the loader is presumably
            # configured with batch_size=1 -- confirm against the config.
            samples, _ = sampler.sample(S=opt.ddim,
                                        batch_size=1,
                                        shape=shape,
                                        pose=hint,
                                        conditioning=reference,
                                        verbose=False,
                                        eta=0,
                                        test_model_kwargs=test_model_kwargs)
            samples = 1. / model.scale_factor * samples
            # Only the first four latent channels are decoded.
            x_samples = model.first_stage_model.decode(samples[:, :4, :, :])

            # Map [-1, 1] -> [0, 1] and move everything to the CPU. The old
            # torch -> numpy -> torch round trips (with permutes that
            # cancelled out) produced exactly these tensors.
            generated = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0).cpu()
            truth_cpu = torch.clamp((truth + 1.0) / 2.0, min=0.0, max=1.0).cpu()
            mask_cpu = mask.cpu()

            # Composite: generated pixels where mask == 0, ground truth where mask == 1.
            composite = generated * (1 - mask_cpu) + truth_cpu * mask_cpu

            generated = torch.nn.functional.interpolate(generated.float(), size=[512, 384])
            composite = torch.nn.functional.interpolate(composite.float(), size=[512, 384])

            # Save the images
            _save_image(generated[0], f"results/Unpaired_Direst/{i}.png")
            _save_image(composite[0], f"results/Unpaired_Concatenation/{i}.png")


if __name__ == "__main__":
    main()