Vaibhavnaik12 committed on
Commit ea9eec7 · verified · 1 Parent(s): d8bcb3f

Upload app.py

Files changed (1)
  1. app.py +355 -0
app.py ADDED
@@ -0,0 +1,355 @@
+ import argparse
+ import os
+ os.environ['CUDA_HOME'] = '/usr/local/cuda'
+ os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
+ from datetime import datetime
+
+ import gradio as gr
+ import spaces
+ import numpy as np
+ import torch
+ from diffusers.image_processor import VaeImageProcessor
+ from huggingface_hub import snapshot_download
+ from PIL import Image
+ torch.jit.script = lambda f: f  # make torch.jit.script a no-op so decorated functions are left uncompiled
+ from model.cloth_masker import AutoMasker, vis_mask
+ from model.pipeline import CatVTONPipeline
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
+
+
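+ # ---------------------------------------------------------------------------
+ # Command-line arguments. Defaults point at public Hugging Face Hub repos:
+ # a Stable Diffusion inpainting base model and the CatVTON try-on checkpoint.
+ # ---------------------------------------------------------------------------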
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Gradio demo for CatVTON virtual try-on.")
+     parser.add_argument(
+         "--base_model_path",
+         type=str,
+         default="booksforcharlie/stable-diffusion-inpainting",
+         help=(
+             "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
+         ),
+     )
+     parser.add_argument(
+         "--resume_path",
+         type=str,
+         default="zhengchong/CatVTON",
+         help=(
+             "The path to the checkpoint of the trained try-on model."
+         ),
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="resource/demo/output",
+         help="The output directory where the model predictions will be written.",
+     )
+
+     parser.add_argument(
+         "--width",
+         type=int,
+         default=768,
+         help=(
+             "The width for input images; all images will be resized to this resolution."
+         ),
+     )
+     parser.add_argument(
+         "--height",
+         type=int,
+         default=1024,
+         help=(
+             "The height for input images; all images will be resized to this resolution."
+         ),
+     )
+     parser.add_argument(
+         "--repaint",
+         action="store_true",
+         help="Whether to repaint the result image with the original background."
+     )
+     parser.add_argument(
+         "--allow_tf32",
+         action="store_true",
+         default=True,
+         help=(
+             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information, see"
+             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+         ),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default="bf16",
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+
+     args = parser.parse_args()
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != getattr(args, "local_rank", -1):
+         args.local_rank = env_local_rank
+
+     return args
+
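+ # Helper: paste a list of equally sized PIL images into a rows x cols grid.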
+ def image_grid(imgs, rows, cols):
+     assert len(imgs) == rows * cols
+
+     w, h = imgs[0].size
+     grid = Image.new("RGB", size=(cols * w, rows * h))
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
+
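+ # ---------------------------------------------------------------------------
+ # Global setup (runs once when the Space starts): download the CatVTON repo,
+ # build the try-on pipeline on top of the SD inpainting base model, and load
+ # the AutoMasker (DensePose + SCHP checkpoints) used to generate masks
+ # automatically. attn_ckpt_version="mix" presumably selects the attention
+ # weights trained on the mixed dataset.
+ # ---------------------------------------------------------------------------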
+
+ args = parse_args()
+ repo_path = snapshot_download(repo_id=args.resume_path)
+ # Pipeline
+ pipeline = CatVTONPipeline(
+     base_ckpt=args.base_model_path,
+     attn_ckpt=repo_path,
+     attn_ckpt_version="mix",
+     weight_dtype=init_weight_dtype(args.mixed_precision),
+     use_tf32=args.allow_tf32,
+     device='cuda'
+ )
+ # AutoMasker
+ mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
+ automasker = AutoMasker(
+     densepose_ckpt=os.path.join(repo_path, "DensePose"),
+     schp_ckpt=os.path.join(repo_path, "SCHP"),
+     device='cuda',
+ )
+
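+ # ---------------------------------------------------------------------------
+ # Inference entry point. @spaces.GPU requests a ZeroGPU worker for up to 120 s
+ # per call. A mask drawn on the ImageEditor takes priority; if the drawn layer
+ # is empty, AutoMasker generates a mask from the selected cloth type.
+ # ---------------------------------------------------------------------------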
+ @spaces.GPU(duration=120)
+ def submit_function(
+     person_image,
+     cloth_image,
+     cloth_type,
+     num_inference_steps,
+     guidance_scale,
+     seed,
+     show_type
+ ):
+     person_image, mask = person_image["background"], person_image["layers"][0]
+     mask = Image.open(mask).convert("L")
+     if len(np.unique(np.array(mask))) == 1:
+         mask = None
+     else:
+         mask = np.array(mask)
+         mask[mask > 0] = 255
+         mask = Image.fromarray(mask)
+
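+     # Results are saved under output_dir/YYYYMMDD/HHMMSS.png. A seed of -1
+     # leaves the generator unset, i.e. a random seed per run. The person image
+     # is cropped and the cloth image padded to the target resolution.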
+     tmp_folder = args.output_dir
+     date_str = datetime.now().strftime("%Y%m%d%H%M%S")
+     result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
+     if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
+         os.makedirs(os.path.join(tmp_folder, date_str[:8]))
+
+     generator = None
+     if seed != -1:
+         generator = torch.Generator(device='cuda').manual_seed(seed)
+
+     person_image = Image.open(person_image).convert("RGB")
+     cloth_image = Image.open(cloth_image).convert("RGB")
+     person_image = resize_and_crop(person_image, (args.width, args.height))
+     cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
+
+     # Process mask
+     if mask is not None:
+         mask = resize_and_crop(mask, (args.width, args.height))
+     else:
+         mask = automasker(
+             person_image,
+             cloth_type
+         )['mask']
+     mask = mask_processor.blur(mask, blur_factor=9)
+
+     # Inference
+     # try:
+     result_image = pipeline(
+         image=person_image,
+         condition_image=cloth_image,
+         mask=mask,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=generator
+     )[0]
+     # except Exception as e:
+     #     raise gr.Error(
+     #         "An error occurred. Please try again later: {}".format(e)
+     #     )
+
+     # Post-process
+     masked_person = vis_mask(person_image, mask)
+     save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
+     save_result_image.save(result_save_path)
+     if show_type == "result only":
+         return result_image
+     else:
+         width, height = person_image.size
+         if show_type == "input & result":
+             condition_width = width // 2
+             conditions = image_grid([person_image, cloth_image], 2, 1)
+         else:
+             condition_width = width // 3
+             conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
+         conditions = conditions.resize((condition_width, height), Image.NEAREST)
+         new_result_image = Image.new("RGB", (width + condition_width + 6, height))
+         new_result_image.paste(conditions, (0, 0))
+         new_result_image.paste(result_image, (condition_width + 6, 0))
+         return new_result_image
+
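+ # person_example_fn forwards a selected example path to the hidden image_path
+ # component, whose change event loads it into the person ImageEditor below.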
+
+ def person_example_fn(image_path):
+     return image_path
+
+
+ # Define the HTML content for the header
+ HEADER = """
+ <div style="text-align: center;">
+ <img src="https://i.ibb.co/9bh36NJ/resource-De-XFIT.png" alt="DeX Logo" style="width: 40%; display: block; margin: 0 auto;">
+ <h1 style="color: #101820;"> Virtual Try-On with Diffusion Models </h1>
+ </div>
+ """
+
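+ # ---------------------------------------------------------------------------
+ # Gradio UI. Left column: person image editor, cloth image, cloth type and
+ # advanced options. Right column: result image plus example galleries read
+ # from resource/demo/example.
+ # ---------------------------------------------------------------------------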
+ def app_gradio():
+     with gr.Blocks(title="DeXFit") as demo:
+         gr.Markdown(HEADER)
+         with gr.Row():
+             with gr.Column(scale=1, min_width=350):
+                 with gr.Row():
+                     image_path = gr.Image(
+                         type="filepath",
+                         interactive=True,
+                         visible=False,
+                     )
+                     person_image = gr.ImageEditor(
+                         interactive=True, label="Person Image", type="filepath"
+                     )
+
+                 with gr.Row():
+                     with gr.Column(scale=1, min_width=230):
+                         cloth_image = gr.Image(
+                             interactive=True, label="Condition Image", type="filepath"
+                         )
+                     with gr.Column(scale=1, min_width=120):
+                         gr.Markdown(
+                             '<span style="color: #808080; font-size: small;">Two ways to provide a mask:<br>1. Upload the person image and use the `🖌️` tool above to draw the mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate one automatically</span>'
+                         )
+                         cloth_type = gr.Radio(
+                             label="Try-On Cloth Type",
+                             choices=["upper", "lower", "overall"],
+                             value="upper",
+                         )
+
+
+                 submit = gr.Button("Submit")
+                 gr.Markdown(
+                     '<center><span style="color: #FF0000">!!! Click only once and wait for the result !!!</span></center>'
+                 )
+
+                 gr.Markdown(
+                     '<span style="color: #808080; font-size: small;">Advanced options to adjust details:<br>1. More `Inference Step`s may enhance details;<br>2. `CFG Strength` is highly correlated with saturation;<br>3. Changing the `Seed` may improve pseudo-shadows.</span>'
+                 )
+                 with gr.Accordion("Advanced Options", open=False):
+                     num_inference_steps = gr.Slider(
+                         label="Inference Step", minimum=10, maximum=100, step=5, value=50
+                     )
+                     # Guidance Scale
+                     guidance_scale = gr.Slider(
+                         label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
+                     )
+                     # Random Seed
+                     seed = gr.Slider(
+                         label="Seed", minimum=-1, maximum=10000, step=1, value=1000
+                     )
+                     show_type = gr.Radio(
+                         label="Show Type",
+                         choices=["result only", "input & result", "input & mask & result"],
+                         value="input & mask & result",
+                     )
+
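+             # Right column: result display and example galleries. Example
+             # images are read from resource/demo/example at startup.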
+             with gr.Column(scale=2, min_width=500):
+                 result_image = gr.Image(interactive=False, label="Result")
+                 with gr.Row():
+                     # Photo Examples
+                     root_path = "resource/demo/example"
+                     with gr.Column():
+                         men_exm = gr.Examples(
+                             examples=[
+                                 os.path.join(root_path, "person", "men", _)
+                                 for _ in os.listdir(os.path.join(root_path, "person", "men"))
+                             ],
+                             examples_per_page=4,
+                             inputs=image_path,
+                             label="Person Examples ①",
+                         )
+                         women_exm = gr.Examples(
+                             examples=[
+                                 os.path.join(root_path, "person", "women", _)
+                                 for _ in os.listdir(os.path.join(root_path, "person", "women"))
+                             ],
+                             examples_per_page=4,
+                             inputs=image_path,
+                             label="Person Examples ②",
+                         )
+                         gr.Markdown(
+                             '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
+                         )
+                     with gr.Column():
+                         condition_upper_exm = gr.Examples(
+                             examples=[
+                                 os.path.join(root_path, "condition", "upper", _)
+                                 for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
+                             ],
+                             examples_per_page=4,
+                             inputs=cloth_image,
+                             label="Condition Upper Examples",
+                         )
+                         condition_overall_exm = gr.Examples(
+                             examples=[
+                                 os.path.join(root_path, "condition", "overall", _)
+                                 for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
+                             ],
+                             examples_per_page=4,
+                             inputs=cloth_image,
+                             label="Condition Overall Examples",
+                         )
+                         condition_person_exm = gr.Examples(
+                             examples=[
+                                 os.path.join(root_path, "condition", "person", _)
+                                 for _ in os.listdir(os.path.join(root_path, "condition", "person"))
+                             ],
+                             examples_per_page=4,
+                             inputs=cloth_image,
+                             label="Condition Reference Person Examples",
+                         )
+                         gr.Markdown(
+                             '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
+                         )
+
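+         # Event wiring: selecting a person example sets the hidden image_path,
+         # whose change event loads it into the ImageEditor; Submit runs
+         # submit_function and shows the returned image in result_image.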
+         image_path.change(
+             person_example_fn, inputs=image_path, outputs=person_image
+         )
+
+         submit.click(
+             submit_function,
+             [
+                 person_image,
+                 cloth_image,
+                 cloth_type,
+                 num_inference_steps,
+                 guidance_scale,
+                 seed,
+                 show_type,
+             ],
+             result_image,
+         )
+     demo.queue().launch(share=True, show_error=True)
+
+
+ if __name__ == "__main__":
+     app_gradio()