Vaibhavnaik12 committed on
Commit d8bcb3f · verified · 1 Parent(s): 05db3d8

Delete app.py

Files changed (1)
  1. app.py +0 -350
app.py DELETED
@@ -1,350 +0,0 @@
- import argparse
- import os
- os.environ['CUDA_HOME'] = '/usr/local/cuda'
- os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
- from datetime import datetime
-
- import gradio as gr
- import spaces
- import numpy as np
- import torch
- from diffusers.image_processor import VaeImageProcessor
- from huggingface_hub import snapshot_download
- from PIL import Image
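- # No-op patch: replacing torch.jit.script with the identity function makes any
- # @torch.jit.script-decorated functions in the imported model code run eagerly instead
- # of being compiled (presumably a workaround for TorchScript errors in a dependency).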
- torch.jit.script = lambda f: f
- from model.cloth_masker import AutoMasker, vis_mask
- from model.pipeline import CatVTONPipeline
- from utils import init_weight_dtype, resize_and_crop, resize_and_padding
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(description="Gradio demo for CatVTON virtual try-on.")
-     parser.add_argument(
-         "--base_model_path",
-         type=str,
-         default="booksforcharlie/stable-diffusion-inpainting",
-         help=(
-             "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
-         ),
-     )
-     parser.add_argument(
-         "--resume_path",
-         type=str,
-         default="zhengchong/CatVTON",
-         help=(
-             "The path to the checkpoint of the trained try-on model."
-         ),
-     )
-     parser.add_argument(
-         "--output_dir",
-         type=str,
-         default="resource/demo/output",
-         help="The output directory where the model predictions will be written.",
-     )
-
-     parser.add_argument(
-         "--width",
-         type=int,
-         default=768,
-         help="The width to which all input images will be resized.",
-     )
-     parser.add_argument(
-         "--height",
-         type=int,
-         default=1024,
-         help="The height to which all input images will be resized.",
-     )
-     parser.add_argument(
-         "--repaint",
-         action="store_true",
-         help="Whether to repaint the result image with the original background."
-     )
-     parser.add_argument(
-         "--allow_tf32",
-         action="store_true",
-         default=True,
-         help=(
-             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information, see"
-             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
-         ),
-     )
-     parser.add_argument(
-         "--mixed_precision",
-         type=str,
-         default="bf16",
-         choices=["no", "fp16", "bf16"],
-         help=(
-             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
-             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-         ),
-     )
-     # --local_rank must be defined; without it, args.local_rank below raises AttributeError
-     # whenever a distributed launcher sets the LOCAL_RANK environment variable.
-     parser.add_argument(
-         "--local_rank",
-         type=int,
-         default=-1,
-         help="For distributed runs: local rank of the process.",
-     )
-
-     args = parser.parse_args()
-     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-     if env_local_rank != -1 and env_local_rank != args.local_rank:
-         args.local_rank = env_local_rank
-
-     return args
-
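- # Paste `rows * cols` equally sized PIL images into a single grid image.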
- def image_grid(imgs, rows, cols):
-     assert len(imgs) == rows * cols
-
-     w, h = imgs[0].size
-     grid = Image.new("RGB", size=(cols * w, rows * h))
-
-     for i, img in enumerate(imgs):
-         grid.paste(img, box=(i % cols * w, i // cols * h))
-     return grid
-
-
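- # One-time startup: parse the CLI arguments, download the CatVTON checkpoint from the
- # Hugging Face Hub, and build the try-on pipeline and the automatic masker on the GPU.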
- args = parse_args()
- repo_path = snapshot_download(repo_id=args.resume_path)
- # Pipeline
- pipeline = CatVTONPipeline(
-     base_ckpt=args.base_model_path,
-     attn_ckpt=repo_path,
-     attn_ckpt_version="mix",
-     weight_dtype=init_weight_dtype(args.mixed_precision),
-     use_tf32=args.allow_tf32,
-     device='cuda'
- )
- # AutoMasker
- mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
- automasker = AutoMasker(
-     densepose_ckpt=os.path.join(repo_path, "DensePose"),
-     schp_ckpt=os.path.join(repo_path, "SCHP"),
-     device='cuda',
- )
-
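- # Gradio handler for one try-on request. On Hugging Face ZeroGPU Spaces,
- # @spaces.GPU(duration=120) reserves a GPU for up to 120 seconds per call.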
- @spaces.GPU(duration=120)
- def submit_function(
-     person_image,
-     cloth_image,
-     cloth_type,
-     num_inference_steps,
-     guidance_scale,
-     seed,
-     show_type
- ):
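-     # With type="filepath", the ImageEditor value is a dict of file paths: "background"
-     # is the uploaded photo and the first layer holds the user-drawn mask. If the layer
-     # contains only one value (nothing drawn), fall back to the automatic masker below.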
-     person_image, mask = person_image["background"], person_image["layers"][0]
-     mask = Image.open(mask).convert("L")
-     if len(np.unique(np.array(mask))) == 1:
-         mask = None
-     else:
-         mask = np.array(mask)
-         mask[mask > 0] = 255
-         mask = Image.fromarray(mask)
-
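-     # Save each result under <output_dir>/<YYYYMMDD>/<HHMMSS>.png.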
-     tmp_folder = args.output_dir
-     date_str = datetime.now().strftime("%Y%m%d%H%M%S")
-     result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
-     if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
-         os.makedirs(os.path.join(tmp_folder, date_str[:8]))
-
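-     # A seed of -1 leaves the generator unseeded (non-deterministic); any other value
-     # makes the diffusion sampling reproducible.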
-     generator = None
-     if seed != -1:
-         generator = torch.Generator(device='cuda').manual_seed(seed)
-
-     person_image = Image.open(person_image).convert("RGB")
-     cloth_image = Image.open(cloth_image).convert("RGB")
-     person_image = resize_and_crop(person_image, (args.width, args.height))
-     cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
-
-     # Process mask
-     if mask is not None:
-         mask = resize_and_crop(mask, (args.width, args.height))
-     else:
-         mask = automasker(
-             person_image,
-             cloth_type  # This will now include 'bags' and 'footwear'
-         )['mask']
-     mask = mask_processor.blur(mask, blur_factor=9)
-
-     # Inference
-     result_image = pipeline(
-         image=person_image,
-         condition_image=cloth_image,
-         mask=mask,
-         num_inference_steps=num_inference_steps,
-         guidance_scale=guidance_scale,
-         generator=generator
-     )[0]
-
-     # Post-process
-     masked_person = vis_mask(person_image, mask)
-     save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
-     save_result_image.save(result_save_path)
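-     # Compose the displayed image: the selected inputs are stacked into a narrow strip on
-     # the left and the result is pasted to its right with a 6-pixel gap, unless only the
-     # result itself was requested.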
-     if show_type == "result only":
-         return result_image
-     else:
-         width, height = person_image.size
-         if show_type == "input & result":
-             condition_width = width // 2
-             conditions = image_grid([person_image, cloth_image], 2, 1)
-         else:
-             condition_width = width // 3
-             conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
-         conditions = conditions.resize((condition_width, height), Image.NEAREST)
-         new_result_image = Image.new("RGB", (width + condition_width + 6, height))
-         new_result_image.paste(conditions, (0, 0))
-         new_result_image.paste(result_image, (condition_width + 6, 0))
-         return new_result_image
-
-
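- # Copy the hidden example image path into the person ImageEditor when an example is selected.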
- def person_example_fn(image_path):
-     return image_path
-
-
- # Define the HTML content for the header
- HEADER = """
- <div style="text-align: center;">
- <img src="https://i.ibb.co/9bh36NJ/resource-De-XFIT.png" alt="DeX Logo" style="width: 40%; display: block; margin: 0 auto;">
- <h1 style="color: #101820;"> Virtual Try-On with Diffusion Models </h1>
- </div>
- """
-
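- # Build the Gradio Blocks UI and launch it with a request queue.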
- def app_gradio():
-     with gr.Blocks(title="DeXFit") as demo:
-         gr.Markdown(HEADER)
-         with gr.Row():
-             with gr.Column(scale=1, min_width=350):
-                 with gr.Row():
-                     image_path = gr.Image(
-                         type="filepath",
-                         interactive=True,
-                         visible=False,
-                     )
-                     person_image = gr.ImageEditor(
-                         interactive=True, label="Person Image", type="filepath"
-                     )
-
-                 with gr.Row():
-                     with gr.Column(scale=1, min_width=230):
-                         cloth_image = gr.Image(
-                             interactive=True, label="Condition Image", type="filepath"
-                         )
-                     with gr.Column(scale=1, min_width=120):
-                         gr.Markdown(
-                             '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
-                         )
-                         cloth_type = gr.Radio(
-                             label="Try-On Cloth Type",
-                             choices=["upper", "lower", "overall", "bags", "footwear"],
-                             value="upper",
-                         )
-
-
-                 submit = gr.Button("Submit")
-                 gr.Markdown(
-                     '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
-                 )
-
-                 gr.Markdown(
-                     '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
-                 )
-                 with gr.Accordion("Advanced Options", open=False):
-                     num_inference_steps = gr.Slider(
-                         label="Inference Step", minimum=10, maximum=100, step=5, value=50
-                     )
-                     # Guidance Scale
-                     guidance_scale = gr.Slider(
-                         label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
-                     )
-                     # Random Seed
-                     seed = gr.Slider(
-                         label="Seed", minimum=-1, maximum=10000, step=1, value=1000
-                     )
-                     show_type = gr.Radio(
-                         label="Show Type",
-                         choices=["result only", "input & result", "input & mask & result"],
-                         value="input & mask & result",
-                     )
-
-             with gr.Column(scale=2, min_width=500):
-                 result_image = gr.Image(interactive=False, label="Result")
-                 with gr.Row():
-                     # Photo Examples
-                     root_path = "resource/demo/example"
-                     with gr.Column():
-                         men_exm = gr.Examples(
-                             examples=[
-                                 os.path.join(root_path, "person", "men", _)
-                                 for _ in os.listdir(os.path.join(root_path, "person", "men"))
-                             ],
-                             examples_per_page=4,
-                             inputs=image_path,
-                             label="Person Examples ①",
-                         )
-                         women_exm = gr.Examples(
-                             examples=[
-                                 os.path.join(root_path, "person", "women", _)
-                                 for _ in os.listdir(os.path.join(root_path, "person", "women"))
-                             ],
-                             examples_per_page=4,
-                             inputs=image_path,
-                             label="Person Examples ②",
-                         )
-                         gr.Markdown(
-                             '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
-                         )
-                     with gr.Column():
-                         condition_upper_exm = gr.Examples(
-                             examples=[
-                                 os.path.join(root_path, "condition", "upper", _)
-                                 for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
-                             ],
-                             examples_per_page=4,
-                             inputs=cloth_image,
-                             label="Condition Upper Examples",
-                         )
-                         condition_overall_exm = gr.Examples(
-                             examples=[
-                                 os.path.join(root_path, "condition", "overall", _)
-                                 for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
-                             ],
-                             examples_per_page=4,
-                             inputs=cloth_image,
-                             label="Condition Overall Examples",
-                         )
-                         condition_person_exm = gr.Examples(
-                             examples=[
-                                 os.path.join(root_path, "condition", "person", _)
-                                 for _ in os.listdir(os.path.join(root_path, "condition", "person"))
-                             ],
-                             examples_per_page=4,
-                             inputs=cloth_image,
-                             label="Condition Reference Person Examples",
-                         )
-                         gr.Markdown(
-                             '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
-                         )
-
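-         # Wiring: choosing an example updates the hidden image_path, which is copied into
-         # the person editor; Submit runs submit_function and displays the returned image.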
-         image_path.change(
-             person_example_fn, inputs=image_path, outputs=person_image
-         )
-
-         submit.click(
-             submit_function,
-             [
-                 person_image,
-                 cloth_image,
-                 cloth_type,
-                 num_inference_steps,
-                 guidance_scale,
-                 seed,
-                 show_type,
-             ],
-             result_image,
-         )
-     demo.queue().launch(share=True, show_error=True)
-
-
- if __name__ == "__main__":
-     app_gradio()