Vaibhavnaik12 committed
Commit 7627f49 · verified · 1 Parent(s): 56c2873

Update app.py

Files changed (1): app.py (+35 −229)
app.py CHANGED
@@ -1,228 +1,34 @@
 import argparse
 import os
-os.environ['CUDA_HOME'] = '/usr/local/cuda'
-os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
 from datetime import datetime
-
 import gradio as gr
-import spaces
 import numpy as np
 import torch
 from diffusers.image_processor import VaeImageProcessor
 from huggingface_hub import snapshot_download
 from PIL import Image
-torch.jit.script = lambda f: f
 from model.cloth_masker import AutoMasker, vis_mask
 from model.pipeline import CatVTONPipeline
 from utils import init_weight_dtype, resize_and_crop, resize_and_padding
 
-
-def parse_args():
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument(
-        "--base_model_path",
-        type=str,
-        default="booksforcharlie/stable-diffusion-inpainting",
-        help=(
-            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
-        ),
-    )
-    parser.add_argument(
-        "--resume_path",
-        type=str,
-        default="zhengchong/CatVTON",
-        help=(
-            "The Path to the checkpoint of trained tryon model."
-        ),
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="resource/demo/output",
-        help="The output directory where the model predictions will be written.",
-    )
-
-    parser.add_argument(
-        "--width",
-        type=int,
-        default=768,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--height",
-        type=int,
-        default=1024,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--repaint",
-        action="store_true",
-        help="Whether to repaint the result image with the original background."
-    )
-    parser.add_argument(
-        "--allow_tf32",
-        action="store_true",
-        default=True,
-        help=(
-            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
-            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
-        ),
-    )
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="bf16",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-        ),
-    )
-
-    args = parser.parse_args()
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
-    return args
-
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-
-
-args = parse_args()
-repo_path = snapshot_download(repo_id=args.resume_path)
-# Pipeline
-pipeline = CatVTONPipeline(
-    base_ckpt=args.base_model_path,
-    attn_ckpt=repo_path,
-    attn_ckpt_version="mix",
-    weight_dtype=init_weight_dtype(args.mixed_precision),
-    use_tf32=args.allow_tf32,
-    device='cuda'
-)
-# AutoMasker
-mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
-automasker = AutoMasker(
-    densepose_ckpt=os.path.join(repo_path, "DensePose"),
-    schp_ckpt=os.path.join(repo_path, "SCHP"),
-    device='cuda',
-)
-
-@spaces.GPU(duration=120)
-def submit_function(
-    person_image,
-    cloth_image,
-    cloth_type,
-    num_inference_steps,
-    guidance_scale,
-    seed,
-    show_type
-):
-    person_image, mask = person_image["background"], person_image["layers"][0]
-    mask = Image.open(mask).convert("L")
-    if len(np.unique(np.array(mask))) == 1:
-        mask = None
-    else:
-        mask = np.array(mask)
-        mask[mask > 0] = 255
-        mask = Image.fromarray(mask)
-
-    tmp_folder = args.output_dir
-    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
-    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
-    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
-        os.makedirs(os.path.join(tmp_folder, date_str[:8]))
-
-    generator = None
-    if seed != -1:
-        generator = torch.Generator(device='cuda').manual_seed(seed)
-
-    person_image = Image.open(person_image).convert("RGB")
-    cloth_image = Image.open(cloth_image).convert("RGB")
-    person_image = resize_and_crop(person_image, (args.width, args.height))
-    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
-
-    # Process mask
-    if mask is not None:
-        mask = resize_and_crop(mask, (args.width, args.height))
-    else:
-        mask = automasker(
-            person_image,
-            cloth_type
-        )['mask']
-    mask = mask_processor.blur(mask, blur_factor=9)
-
-    # Inference
-    # try:
-    result_image = pipeline(
-        image=person_image,
-        condition_image=cloth_image,
-        mask=mask,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-        generator=generator
-    )[0]
-    # except Exception as e:
-    #     raise gr.Error(
-    #         "An error occurred. Please try again later: {}".format(e)
-    #     )
-
-    # Post-process
-    masked_person = vis_mask(person_image, mask)
-    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
-    save_result_image.save(result_save_path)
-    if show_type == "result only":
-        return result_image
-    else:
-        width, height = person_image.size
-        if show_type == "input & result":
-            condition_width = width // 2
-            conditions = image_grid([person_image, cloth_image], 2, 1)
-        else:
-            condition_width = width // 3
-            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
-        conditions = conditions.resize((condition_width, height), Image.NEAREST)
-        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
-        new_result_image.paste(conditions, (0, 0))
-        new_result_image.paste(result_image, (condition_width + 5, 0))
-        return new_result_image
-
-
-def person_example_fn(image_path):
-    return image_path
+# ... (rest of your imports and function definitions remain unchanged)
 
 HEADER = """
-<h1 style="text-align: center;"> DEX FIT Virtual Try-On with Diffusion Models </h1>
+<p style="text-align: center;">
+    <img src="resource/DeXFIT.png" alt="DeX Logo" style="height: 100px;">
+</p>
+<h1 style="text-align: center; color: #101820;"> DEX FIT Virtual Try-On with Diffusion Models </h1>
 <br>
-· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
+<p style="color: #101820;">· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span style="color: #00685E;">seed</span> for normal outcomes.</p>
 """
 
 def app_gradio():
-    with gr.Blocks(title="CatVTON") as demo:
+    with gr.Blocks(title="CatVTON", css=".gradio-container {background-color: #F4F4F1;} #submit-button {background-color: #00685E; color: #FFFFFF;}") as demo:
         gr.Markdown(HEADER)
+
         with gr.Row():
             with gr.Column(scale=1, min_width=350):
                 with gr.Row():
-                    image_path = gr.Image(
-                        type="filepath",
-                        interactive=True,
-                        visible=False,
-                    )
                     person_image = gr.ImageEditor(
                         interactive=True, label="Person Image", type="filepath"
                     )
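
For context: the submit_function that this hunk folds into the "remain unchanged" comment derives the inpainting mask from the ImageEditor payload before falling back to the AutoMasker. A minimal sketch of that step, assuming Pillow and NumPy as in the old code (the helper name extract_drawn_mask is illustrative, not from the repo):

import numpy as np
from PIL import Image

def extract_drawn_mask(editor_payload):
    # gr.ImageEditor (type="filepath") returns the base photo plus drawing layers.
    person_path = editor_payload["background"]
    layer = Image.open(editor_payload["layers"][0]).convert("L")
    arr = np.array(layer)
    if len(np.unique(arr)) == 1:
        return person_path, None  # nothing drawn: caller falls back to AutoMasker
    arr[arr > 0] = 255            # binarize every drawn stroke
    return person_path, Image.fromarray(arr)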
@@ -240,10 +46,10 @@ def app_gradio():
                         label="Try-On Cloth Type",
                         choices=["upper", "lower", "overall"],
                         value="upper",
+                        elem_id="cloth-type",  # gr.Radio has no label_style kwarg; color it via the Blocks css
                     )
 
-
-                submit = gr.Button("Submit")
+                submit = gr.Button("Submit", elem_id="submit-button", variant="primary")  # colors moved to Blocks css; Button takes no style kwarg
                 gr.Markdown(
                     '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                 )
@@ -255,11 +61,9 @@ def app_gradio():
                     num_inference_steps = gr.Slider(
                         label="Inference Step", minimum=10, maximum=100, step=5, value=50
                     )
-                    # Guidence Scale
                     guidance_scale = gr.Slider(
-                        label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
+                        label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                     )
-                    # Random Seed
                     seed = gr.Slider(
                         label="Seed", minimum=-1, maximum=10000, step=1, value=1000
                     )
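
The Seed slider above still bottoms out at -1 because the (elided) submit_function treats -1 as "no fixed seed". A minimal sketch of that convention, assuming a CUDA device as the old code did:

import torch

def make_generator(seed: int):
    # seed == -1 -> nondeterministic sampling; any other value is reproducible
    if seed == -1:
        return None
    return torch.Generator(device="cuda").manual_seed(seed)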
@@ -272,7 +76,6 @@ def app_gradio():
             with gr.Column(scale=2, min_width=500):
                 result_image = gr.Image(interactive=False, label="Result")
                 with gr.Row():
-                    # Photo Examples
                     root_path = "resource/demo/example"
                     with gr.Column():
                         men_exm = gr.Examples(
@@ -281,7 +84,7 @@ def app_gradio():
                                 for _ in os.listdir(os.path.join(root_path, "person", "men"))
                             ],
                             examples_per_page=4,
-                            inputs=image_path,
+                            inputs=person_image,
                             label="Person Examples ①",
                         )
                         women_exm = gr.Examples(
@@ -289,12 +92,12 @@ def app_gradio():
                                 os.path.join(root_path, "person", "women", _)
                                 for _ in os.listdir(os.path.join(root_path, "person", "women"))
                             ],
                             examples_per_page=4,
-                            inputs=image_path,
+                            inputs=person_image,
                             label="Person Examples ②",
                         )
                         gr.Markdown(
-                            '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
+                            '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion" style="color: #00685E;">OOTDiffusion</a> and <a href="https://www.outfitanyone.org" style="color: #00685E;">OutfitAnyone</a>.</span>'
                         )
                     with gr.Column():
                         condition_upper_exm = gr.Examples(
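
Both example galleries are populated by listing directories, as above and in the condition examples that follow. os.listdir returns entries in arbitrary order and includes any stray file, so a sorted, extension-filtered variant (an assumption for robustness, not something this commit does) would keep the gallery pages stable:

import os

def list_examples(folder, exts=(".jpg", ".jpeg", ".png", ".webp")):
    # Deterministic, image-only example list for gr.Examples
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(exts)
    )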
@@ -325,28 +128,28 @@ def app_gradio():
                             label="Condition Reference Person Examples",
                         )
                         gr.Markdown(
-                            '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
+                            '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet.</span>'
                         )
 
-        image_path.change(
-            person_example_fn, inputs=image_path, outputs=person_image
-        )
+        # Person examples now load straight into person_image, so the removed
+        # image_path component and person_example_fn callback are not re-wired.
 
-        submit.click(
-            submit_function,
-            [
-                person_image,
-                cloth_image,
-                cloth_type,
-                num_inference_steps,
-                guidance_scale,
-                seed,
-                show_type,
-            ],
-            result_image,
-        )
-        demo.queue().launch(share=True, show_error=True)
+        # Connect the submit button to the inference function
+        submit.click(
+            submit_function,
+            [
+                person_image,
+                cloth_image,
+                cloth_type,
+                num_inference_steps,
+                guidance_scale,
+                seed,
+                show_type,
+            ],
+            result_image,
+        )
 
+        demo.queue().launch(share=True, show_error=True)
 
 if __name__ == "__main__":
     app_gradio()
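
With the elided definitions restored unchanged, the wiring in the final hunk is the standard Blocks pattern: inputs are passed to the callback in the order they are listed in click(). A self-contained sketch of that pattern (the stub components and echo function are illustrative, not the app's real ones):

import gradio as gr

def echo(text):
    # Stand-in for submit_function: arguments arrive in the order listed in click()
    return text

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Submit").click(echo, inputs=[box], outputs=out)

demo.queue().launch(share=True, show_error=True)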