basso4 committed
Commit a7e0ced
1 Parent(s): b7a16b6

Upload app.py

Files changed (1):
  app.py +267 -0
app.py ADDED
@@ -0,0 +1,267 @@
+ import os
+ import sys
+ sys.path.append('./')
+ import numpy as np
+ import argparse
+
+ import torch
+ import torchvision
+ import pytorch_lightning
+ from torch import autocast
+ from torchvision import transforms
+ from pytorch_lightning import seed_everything
+
+
+ from einops import rearrange
+ from functools import partial
+ from omegaconf import OmegaConf
+
+ from PIL import Image
+ from typing import List
+ import matplotlib.pyplot as plt
+
+ import gradio as gr
+ import apply_net
+
+ from torchvision.transforms.functional import to_pil_image
+ # from tools.mask_vitonhd import get_img_agnostic
+
+ from utils_mask import get_mask_location
+ from preprocess.humanparsing.run_parsing import Parsing
+ from preprocess.openpose.run_openpose import OpenPose
+ from ldm.util import instantiate_from_config, get_obj_from_str
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
+
+
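+ # Helpers to map normalized tensors back to the [0, 1] image range:
+ # un_norm undoes the (x - 0.5) / 0.5 scaling, un_norm_clip undoes the
+ # CLIP per-channel normalization.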
+ def un_norm(x):
+     return (x + 1.0) / 2.0
+
+
+ def un_norm_clip(x):
+     x[0, :, :] = x[0, :, :] * 0.26862954 + 0.48145466
+     x[1, :, :] = x[1, :, :] * 0.26130258 + 0.4578275
+     x[2, :, :] = x[2, :, :] * 0.27577711 + 0.40821073
+     return x
+
+
+ class DataModuleFromConfig(pytorch_lightning.LightningDataModule):
+     def __init__(self,
+                  batch_size,
+                  test=None,
+                  wrap=False,
+                  shuffle=False,
+                  shuffle_test_loader=False,
+                  use_worker_init_fn=False):
+         super().__init__()
+         self.batch_size = batch_size
+         self.num_workers = batch_size * 2
+         self.use_worker_init_fn = use_worker_init_fn
+         self.wrap = wrap
+         self.datasets = instantiate_from_config(test)
+         self.dataloader = torch.utils.data.DataLoader(self.datasets,
+                                                       batch_size=self.batch_size,
+                                                       num_workers=self.num_workers,
+                                                       shuffle=shuffle,
+                                                       worker_init_fn=None)
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser(description="Script for demo model")
+     parser.add_argument("-b", "--base", type=str, default=r"configs/test_vitonhd.yaml")
+     parser.add_argument("-c", "--ckpt", type=str, default=r"checkpoints/hitonhd.ckpt")
+     parser.add_argument("-s", "--seed", type=int, default=42)
+     parser.add_argument("-d", "--ddim", type=int, default=64)
+     args = parser.parse_args()
+
+     seed_everything(args.seed)
+     config = OmegaConf.load(f"{args.base}")
+     # data = instantiate_from_config(config.data)
+
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+     # load the diffusion model and DDIM sampler used by start_tryon below
+     model = instantiate_from_config(config.model)
+     model.load_state_dict(torch.load(args.ckpt, map_location="cpu")["state_dict"], strict=False)
+     model.eval()
+     model = model.to(device)
+     sampler = DDIMSampler(model)
+
+     precision_scope = autocast
+
+
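+     # start_tryon runs the whole pipeline for one person/garment pair:
+     # build the inpainting mask (human parsing + OpenPose), estimate DensePose,
+     # sample with DDIM, and return the raw sample, the blended result and the
+     # masked-person preview.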
+     def start_tryon(human_img, garm_img):
+         # load human image
+         human_img = human_img.convert("RGB").resize((768, 1024))
+
+         # mask
+         tensor_transform = transforms.Compose(
+             [
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.5], [0.5]),
+             ]
+         )
+
+         parsing_model = Parsing(0)
+         openpose_model = OpenPose(0)
+         openpose_model.preprocessor.body_estimation.model.to(device)
+
+         keypoints = openpose_model(human_img.resize((384, 512)))
+         model_parse, _ = parsing_model(human_img.resize((384, 512)))
+         mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+         mask = mask.resize((768, 1024))
+         mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
+         mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
+         # mask_gray.save(r'D:\Capstone_Project\cat_dm\gradio_demo\output\maskgray_output.png')
+
+         # densepose (kept in its own `dp_args` so the global `args` is not shadowed)
+         human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
+         human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+         dp_args = apply_net.create_argument_parser().parse_args(('show',
+                                                                  './configs/configs_densepose/densepose_rcnn_R_50_FPN_s1x.yaml',
+                                                                  './ckpt/densepose/model_final_162be9.pkl',
+                                                                  'dp_segm', '-v',
+                                                                  '--opts',
+                                                                  'MODEL.DEVICE',
+                                                                  'cuda'))
+         # verbosity = getattr(dp_args, "verbosity", None)
+         pose_img = dp_args.func(dp_args, human_img_arg)
+         pose_img = pose_img[:, :, ::-1]
+         pose_img = Image.fromarray(pose_img).resize((768, 1024))
+
+         # preprocess images
+         human_img = human_img.convert("RGB").resize((512, 512))
+         human_img = torchvision.transforms.ToTensor()(human_img)
+
+         garm_img = garm_img.convert("RGB").resize((512, 512))
+         garm_img = torchvision.transforms.ToTensor()(garm_img)
+
+         mask = mask.convert("L").resize((512, 512))
+         mask = torchvision.transforms.ToTensor()(mask)
+         mask = 1 - mask
+
+         pose_img = pose_img.convert("RGB").resize((512, 512))
+         pose_img = torchvision.transforms.ToTensor()(pose_img)
+
+         # normalize
+         human_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(human_img)
+         garm_img = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                                                     (0.26862954, 0.26130258, 0.27577711))(garm_img)
+         pose_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(pose_img)
+
+         # create inpaint & hint
+         inpaint = human_img * mask
+         hint = torchvision.transforms.Resize((512, 512))(garm_img)
+         hint = torch.cat((hint, pose_img), dim=0)
+
+         # {"human_img": human_img,        # [3, 512, 512]
+         #  "inpaint_image": inpaint,      # [3, 512, 512]
+         #  "inpaint_mask": mask,          # [1, 512, 512]
+         #  "garm_img": garm_img,          # [3, 224, 224]
+         #  "hint": hint,                  # [6, 512, 512]
+         # }
+
+
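+         # The sampler gets the garment as conditioning, the 6-channel hint
+         # (garment + densepose) as pose, and the VAE-encoded inpaint image plus
+         # the downscaled mask via test_model_kwargs.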
+         with torch.no_grad():
+             with precision_scope("cuda"):
+                 # load data (adding a batch dimension)
+                 inpaint = inpaint.unsqueeze(0).to(torch.float16).to(device)
+                 reference = garm_img.unsqueeze(0).to(torch.float16).to(device)
+                 mask = mask.unsqueeze(0).to(torch.float16).to(device)
+                 hint = hint.unsqueeze(0).to(torch.float16).to(device)
+                 truth = human_img.unsqueeze(0).to(torch.float16).to(device)
+
+                 # data preprocessing
+                 encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
+                 z_inpaint = model.scale_factor * (encoder_posterior_inpaint.sample()).detach()
+                 mask_resize = torchvision.transforms.Resize([z_inpaint.shape[-2], z_inpaint.shape[-1]])(mask)
+                 test_model_kwargs = {}
+                 test_model_kwargs['inpaint_image'] = z_inpaint
+                 test_model_kwargs['inpaint_mask'] = mask_resize
+                 shape = (model.channels, model.image_size, model.image_size)
+
+                 # predict
+                 samples, _ = sampler.sample(S=args.ddim,
+                                             batch_size=1,
+                                             shape=shape,
+                                             pose=hint,
+                                             conditioning=reference,
+                                             verbose=False,
+                                             eta=0,
+                                             test_model_kwargs=test_model_kwargs)
+                 samples = 1. / model.scale_factor * samples
+                 x_samples = model.first_stage_model.decode(samples[:, :4, :, :])
+
+                 x_samples_ddim = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                 x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+                 x_checked_image = x_samples_ddim
+                 x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+                 # process and return img and img_C
+                 mask = mask.cpu().permute(0, 2, 3, 1).numpy()
+                 mask = torch.from_numpy(mask).permute(0, 3, 1, 2)
+                 truth = torch.clamp((truth + 1.0) / 2.0, min=0.0, max=1.0)
+                 truth = truth.cpu().permute(0, 2, 3, 1).numpy()
+                 truth = torch.from_numpy(truth).permute(0, 3, 1, 2)
+
+                 # keep the original person outside the inpainted garment region
+                 x_checked_image_torch_C = x_checked_image_torch * (1 - mask) + truth.cpu() * mask
+                 x_checked_image_torch = torch.nn.functional.interpolate(x_checked_image_torch.float(), size=[512, 384])
+                 x_checked_image_torch_C = torch.nn.functional.interpolate(x_checked_image_torch_C.float(), size=[512, 384])
+
+                 # create and return the images img and img_C
+                 img = x_checked_image_torch[0].cpu().numpy().transpose(1, 2, 0)  # convert to HWC format
+                 img_C = x_checked_image_torch_C[0].cpu().numpy().transpose(1, 2, 0)  # convert to HWC format
+
+         return img, img_C, mask_gray
+
+
+     example_path = os.path.join(os.path.dirname(__file__), 'example')
+
+     garm_list = os.listdir(os.path.join(example_path, "cloth"))
+     garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
+
+     human_list = os.listdir(os.path.join(example_path, "human"))
+     human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
+
+
+
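+     # Gradio UI: the Try-on button sends the selected person and garment images
+     # through start_tryon and shows the raw sample, the blended result and the
+     # masked-person preview.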
+     image_blocks = gr.Blocks().queue()
+     with image_blocks as demo:
+         gr.Markdown("## CAT-DM 👕👔👚")
+         gr.Markdown("Virtual try-on with your own image and a garment image")
+         with gr.Row():
+             with gr.Column():
+                 imgs = gr.ImageEditor(sources='upload', type="pil", label='Human Picture or use Examples below', interactive=True)
+
+                 example = gr.Examples(
+                     inputs=imgs,
+                     examples_per_page=10,
+                     examples=human_list_path
+                 )
+
+             with gr.Column():
+                 garm_img = gr.Image(label="Garment", sources='upload', type="pil")
+
+                 example = gr.Examples(
+                     inputs=garm_img,
+                     examples_per_page=8,
+                     examples=garm_list_path
+                 )
+
+             with gr.Column():
+                 image_out = gr.Image(label="Output", elem_id="output-img", show_download_button=False)
+                 try_button = gr.Button(value="Try-on")
+
+             with gr.Column():
+                 image_out_c = gr.Image(label="Output (blended)", elem_id="output-img-c", show_download_button=False)
+
+             with gr.Column():
+                 masked_img = gr.Image(label="Masked image output", elem_id="masked_img", show_download_button=False)
+
+
+         try_button.click(fn=start_tryon, inputs=[imgs, garm_img], outputs=[image_out, image_out_c, masked_img], api_name='tryon')
+
+
+
+     image_blocks.launch()