File size: 7,595 Bytes
adf1965
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import torch
import argparse
import torchvision
import pytorch_lightning
import numpy as np

from PIL import Image
from torch import autocast
from einops import rearrange
from functools import partial
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

def un_norm(x):
    """Shift and scale ``x`` from the [-1, 1] range to [0, 1].

    Computes ``(x + 1) / 2`` and returns the result; works on anything that
    supports addition and division with floats (numbers, arrays, tensors).
    """
    shifted = x + 1.0
    return shifted / 2.0

def un_norm_clip(x):
    """Undo a per-channel mean/std normalisation on ``x`` in place.

    For each of the first three indices ``c`` along the leading axis,
    ``x[c]`` is replaced by ``x[c] * std[c] + mean[c]`` using the constants
    below (presumably the CLIP image-preprocessing statistics, going by the
    function name).

    The input is modified in place and the same object is returned.
    """
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    for channel, (std, mean) in enumerate(zip(channel_std, channel_mean)):
        x[channel, :, :] = x[channel, :, :] * std + mean
    return x

class DataModuleFromConfig(pytorch_lightning.LightningDataModule):
    """Data module that instantiates a test dataset from config and wraps it in a DataLoader.

    After construction, ``self.datasets`` holds the object built by
    ``instantiate_from_config(test)`` and ``self.dataloader`` is a
    ``torch.utils.data.DataLoader`` over it.
    """

    def __init__(self,
                 batch_size,
                 test=None,
                 wrap=False,
                 shuffle=False,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False):
        """Build the dataset and its loader.

        Args:
            batch_size: Samples per batch; the loader uses ``batch_size * 2``
                worker processes.
            test: Config node passed to ``instantiate_from_config`` to create
                the dataset.
            wrap: Stored on the instance; not otherwise used here.
            shuffle: Forwarded to the DataLoader's ``shuffle`` argument.
            shuffle_test_loader: Accepted but not used here.
            use_worker_init_fn: Stored on the instance; the loader is always
                created with ``worker_init_fn=None``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        self.wrap = wrap
        self.datasets = instantiate_from_config(test)

        loader_options = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "shuffle": shuffle,
            "worker_init_fn": None,
        }
        self.dataloader = torch.utils.data.DataLoader(self.datasets, **loader_options)



def _save_image_png(image, path):
    """Save a single (C, H, W) float image tensor to ``path`` as an 8-bit PNG.

    Values are expected to lie in [0, 1]; they are scaled by 255 and
    truncated to ``uint8``. The image is passed through ``make_grid`` as a
    one-image batch, exactly as the original pipeline did.
    """
    grid = torchvision.utils.make_grid(image.unsqueeze(0))
    pixels = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    Image.fromarray(pixels.astype(np.uint8)).save(path)


if __name__ == "__main__":
    # =============================================================
    # Parse command-line options
    # =============================================================
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--base", type=str, default="configs/test.yaml")
    parser.add_argument("-c", "--ckpt", type=str, default="./model.ckpt")
    parser.add_argument("-s", "--seed", type=int, default=42)
    parser.add_argument("-d", "--ddim", type=int, default=64)
    opt = parser.parse_args()

    # =============================================================
    # Set the random seed
    # =============================================================
    seed_everything(opt.seed)

    # =============================================================
    # Load the config
    # =============================================================
    config = OmegaConf.load(opt.base)

    # =============================================================
    # Build the dataloader
    # =============================================================
    data = instantiate_from_config(config.data)
    print(f"{data.__class__.__name__}, {len(data.dataloader)}")

    # =============================================================
    # Load the model
    # =============================================================
    # Pick the device once and use it everywhere, instead of calling
    # model.cuda() unconditionally (which fails on CPU-only machines).
    # NOTE(review): the CPU path is best-effort; inputs are still cast to
    # float16 below.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = instantiate_from_config(config.model)
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(opt.ckpt, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    model.eval()
    model = model.to(device)
    sampler = DDIMSampler(model)

    # =============================================================
    # Precision (mixed precision via autocast)
    # =============================================================
    precision_scope = autocast

    # =============================================================
    # Run the test loop
    # =============================================================
    direct_dir = "results/Unpaired_Direst"
    composite_dir = "results/Unpaired_Concatenation"
    # exist_ok=True so that re-running the script does not crash when the
    # output directories are already there.
    os.makedirs(direct_dir, exist_ok=True)
    os.makedirs(composite_dir, exist_ok=True)

    # Loop-invariant settings.
    shape = (model.channels, model.image_size, model.image_size)
    output_size = [512, 384]

    with torch.no_grad(), precision_scope(device.type):
        for i, batch in enumerate(data.dataloader):
            # Load the batch as float16 on the target device
            inpaint = batch["inpaint_image"].to(torch.float16).to(device)
            reference = batch["ref_imgs"].to(torch.float16).to(device)
            mask = batch["inpaint_mask"].to(torch.float16).to(device)
            hint = batch["hint"].to(torch.float16).to(device)
            truth = batch["GT"].to(torch.float16).to(device)

            # Encode the masked image and resize the mask to the latent's
            # spatial size
            encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
            z_inpaint = model.scale_factor * (encoder_posterior_inpaint.sample()).detach()
            latent_hw = [z_inpaint.shape[-2], z_inpaint.shape[-1]]
            mask_resize = torchvision.transforms.Resize(latent_hw)(mask)
            test_model_kwargs = {
                'inpaint_image': z_inpaint,
                'inpaint_mask': mask_resize,
            }

            # Sample and decode
            # NOTE(review): batch_size is fixed to 1 here and only the first
            # image of each batch is saved below; this assumes the dataloader
            # batch size is 1.
            samples, _ = sampler.sample(S=opt.ddim,
                                        batch_size=1,
                                        shape=shape,
                                        pose=hint,
                                        conditioning=reference,
                                        verbose=False,
                                        eta=0,
                                        test_model_kwargs=test_model_kwargs)
            samples = 1. / model.scale_factor * samples
            # Only the first 4 channels of the sample are decoded.
            x_samples = model.first_stage_model.decode(samples[:, :4, :, :])

            # Map from [-1, 1] to [0, 1] and move to the CPU for saving
            generated = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0).cpu()
            truth = torch.clamp((truth + 1.0) / 2.0, min=0.0, max=1.0).cpu()
            mask = mask.cpu()

            # Composite: ground truth where mask is 1, generated result where
            # mask is 0
            composite = generated * (1 - mask) + truth * mask

            generated = torch.nn.functional.interpolate(generated.float(), size=output_size)
            composite = torch.nn.functional.interpolate(composite.float(), size=output_size)

            # Save the images (first item of the batch only)
            _save_image_png(generated[0], os.path.join(direct_dir, f"{i}.png"))
            _save_image_png(composite[0], os.path.join(composite_dir, f"{i}.png"))