Vaibhavnaik12 committed
Commit b46c980 · verified · 1 Parent(s): 4622cdb

Update app.py

Files changed (1):
  1. app.py +249 -53
app.py CHANGED
@@ -1,34 +1,229 @@
 import argparse
 import os
+os.environ['CUDA_HOME'] = '/usr/local/cuda'
+os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
 from datetime import datetime
+
 import gradio as gr
+import spaces
 import numpy as np
 import torch
 from diffusers.image_processor import VaeImageProcessor
 from huggingface_hub import snapshot_download
 from PIL import Image
+torch.jit.script = lambda f: f
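+# ^ no-op stub: anything passed to torch.jit.script is returned unscripted and runs eagerly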
 from model.cloth_masker import AutoMasker, vis_mask
 from model.pipeline import CatVTONPipeline
 from utils import init_weight_dtype, resize_and_crop, resize_and_padding
 
-# ... (rest of your imports and function definitions remain unchanged)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Gradio demo for CatVTON virtual try-on.")
+    parser.add_argument(
+        "--base_model_path",
+        type=str,
+        default="booksforcharlie/stable-diffusion-inpainting",
+        help=(
+            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
+        ),
+    )
+    parser.add_argument(
+        "--resume_path",
+        type=str,
+        default="zhengchong/CatVTON",
+        help=(
+            "The path to the checkpoint of the trained try-on model."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="resource/demo/output",
+        help="The output directory where the model predictions will be written.",
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=768,
+        help=(
+            "The resolution for input images; all images will be resized to this"
+            " resolution."
+        ),
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=1024,
+        help=(
+            "The resolution for input images; all images will be resized to this"
+            " resolution."
+        ),
+    )
+    parser.add_argument(
+        "--repaint",
+        action="store_true",
+        help="Whether to repaint the result image with the original background."
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        default=True,
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="bf16",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--local_rank", type=int, default=-1, help="For distributed runs; passed by the launcher."
+    )
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    return args
+
+def image_grid(imgs, rows, cols):
+    assert len(imgs) == rows * cols
+
+    w, h = imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
+
+
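+# One-time startup: download the CatVTON checkpoints and build the pipeline and
+# AutoMasker as module-level globals that every request reuses.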
+args = parse_args()
+repo_path = snapshot_download(repo_id=args.resume_path)
+# Pipeline
+pipeline = CatVTONPipeline(
+    base_ckpt=args.base_model_path,
+    attn_ckpt=repo_path,
+    attn_ckpt_version="mix",
+    weight_dtype=init_weight_dtype(args.mixed_precision),
+    use_tf32=args.allow_tf32,
+    device='cuda'
+)
+# AutoMasker
+mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
+automasker = AutoMasker(
+    densepose_ckpt=os.path.join(repo_path, "DensePose"),
+    schp_ckpt=os.path.join(repo_path, "SCHP"),
+    device='cuda',
+)
+
+@spaces.GPU(duration=120)
+def submit_function(
+    person_image,
+    cloth_image,
+    cloth_type,
+    num_inference_steps,
+    guidance_scale,
+    seed,
+    show_type
+):
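+    # ImageEditor (type="filepath") returns a dict: "background" is the photo path,
+    # "layers"[0] the hand-drawn brush layer; a single-valued layer means no mask was drawn.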
+    person_image, mask = person_image["background"], person_image["layers"][0]
+    mask = Image.open(mask).convert("L")
+    if len(np.unique(np.array(mask))) == 1:
+        mask = None
+    else:
+        mask = np.array(mask)
+        mask[mask > 0] = 255
+        mask = Image.fromarray(mask)
+
+    tmp_folder = args.output_dir
+    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
+    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
+    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
+        os.makedirs(os.path.join(tmp_folder, date_str[:8]))
+
+    generator = None
+    if seed != -1:
+        generator = torch.Generator(device='cuda').manual_seed(seed)
+
+    person_image = Image.open(person_image).convert("RGB")
+    cloth_image = Image.open(cloth_image).convert("RGB")
+    person_image = resize_and_crop(person_image, (args.width, args.height))
+    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
+
+    # Process mask
+    if mask is not None:
+        mask = resize_and_crop(mask, (args.width, args.height))
+    else:
+        mask = automasker(
+            person_image,
+            cloth_type
+        )['mask']
+    mask = mask_processor.blur(mask, blur_factor=9)
+
+    # Inference
+    result_image = pipeline(
+        image=person_image,
+        condition_image=cloth_image,
+        mask=mask,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=generator
+    )[0]
+
+    # Post-process
+    masked_person = vis_mask(person_image, mask)
+    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
+    save_result_image.save(result_save_path)
+    if show_type == "result only":
+        return result_image
+    else:
+        width, height = person_image.size
+        if show_type == "input & result":
+            condition_width = width // 2
+            conditions = image_grid([person_image, cloth_image], 2, 1)
+        else:
+            condition_width = width // 3
+            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
+        conditions = conditions.resize((condition_width, height), Image.NEAREST)
+        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
+        new_result_image.paste(conditions, (0, 0))
+        new_result_image.paste(result_image, (condition_width + 5, 0))
+        return new_result_image
+
+
+def person_example_fn(image_path):
+    return image_path
 
 HEADER = """
-<p style="text-align: center;">
-<img src="resource/DeXFIT.png" alt="DeX Logo" style="height: 100px;">
-</p>
-<h1 style="text-align: center; color: #101820;"> DEX FIT Virtual Try-On with Diffusion Models </h1>
+<h1 style="text-align: center;"> DeX Fit - Try-On </h1>
 <br>
-<p style="color: #101820;">· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span style="color: #00685E;">seed</span> for normal outcomes.</p>
+
+· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
 """
 
 def app_gradio():
-    with gr.Blocks(title="CatVTON", css="#main {background-color: #F4F4F1;}") as demo:
+    with gr.Blocks(title="CatVTON") as demo:
         gr.Markdown(HEADER)
-
         with gr.Row():
             with gr.Column(scale=1, min_width=350):
                 with gr.Row():
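+                    # hidden path holder: example galleries write here, and its
+                    # change event (wired below) loads the pick into the editor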
+                    image_path = gr.Image(
+                        type="filepath",
+                        interactive=True,
+                        visible=False,
+                    )
                     person_image = gr.ImageEditor(
                         interactive=True, label="Person Image", type="filepath"
                     )
@@ -46,10 +241,10 @@ def app_gradio():
                             label="Try-On Cloth Type",
                             choices=["upper", "lower", "overall"],
                             value="upper",
-                            label_style={"color": "#101820"}
                         )
 
-                submit = gr.Button("Submit", elem_id="submit-button", style={"background-color": "#00685E", "color": "#FFFFFF"})
+
+                submit = gr.Button("Submit")
                 gr.Markdown(
                     '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                 )
@@ -61,9 +256,11 @@
                     num_inference_steps = gr.Slider(
                         label="Inference Step", minimum=10, maximum=100, step=5, value=50
                     )
+                    # Guidance Scale
                     guidance_scale = gr.Slider(
                         label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                     )
+                    # Random Seed
                     seed = gr.Slider(
                         label="Seed", minimum=-1, maximum=10000, step=1, value=1000
                     )
@@ -76,29 +273,30 @@
             with gr.Column(scale=2, min_width=500):
                 result_image = gr.Image(interactive=False, label="Result")
         with gr.Row():
+            # Photo Examples
             root_path = "resource/demo/example"
-            with gr.Column():
-                men_exm = gr.Examples(
-                    examples=[
-                        os.path.join(root_path, "person", "men", _)
-                        for _ in os.listdir(os.path.join(root_path, "person", "men"))
-                    ],
-                    examples_per_page=4,
-                    inputs=person_image,
-                    label="Person Examples ①",
-                )
-                women_exm = gr.Examples(
-                    examples=[
-                        os.path.join(root_path, "person", "women", _)
-                        for _ in os.listdir(os.path.join(root_path, "person", "women"))
-                    ],
-                    examples_per_page=4,
-                    inputs=person_image,
-                    label="Person Examples ②",
-                )
-                gr.Markdown(
-                    '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion" style="color: #00685E;">OOTDiffusion</a> and <a href="https://www.outfitanyone.org" style="color: #00685E;">OutfitAnyone</a>.</span>'
-                )
+            with gr.Column():
+                men_exm = gr.Examples(
+                    examples=[
+                        os.path.join(root_path, "person", "men", _)
+                        for _ in os.listdir(os.path.join(root_path, "person", "men"))
+                    ],
+                    examples_per_page=4,
+                    inputs=image_path,
+                    label="Person Examples ①",
+                )
+                women_exm = gr.Examples(
+                    examples=[
+                        os.path.join(root_path, "person", "women", _)
+                        for _ in os.listdir(os.path.join(root_path, "person", "women"))
+                    ],
+                    examples_per_page=4,
+                    inputs=image_path,
+                    label="Person Examples ②",
+                )
+                gr.Markdown(
+                    '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
+                )
             with gr.Column():
                 condition_upper_exm = gr.Examples(
                     examples=[
@@ -128,30 +326,28 @@ with gr.Column():
                     label="Condition Reference Person Examples",
                 )
                 gr.Markdown(
-                    '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet.</span>'
+                    '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                 )
 
-        # Update the image path change function
-        person_image.change(
-            person_example_fn, inputs=person_image, outputs=person_image
-        )
-
-        # Connect the submit button to the function
-        submit.click(
-            submit_function,
-            [
-                person_image,
-                cloth_image,
-                cloth_type,
-                num_inference_steps,
-                guidance_scale,
-                seed,
-                show_type,
-            ],
-            result_image,
-        )
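+        # route example clicks through the hidden image_path into the editor,
+        # and bind the submit button to the GPU-decorated inference function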
+        image_path.change(
+            person_example_fn, inputs=image_path, outputs=person_image
+        )
 
+        submit.click(
+            submit_function,
+            [
+                person_image,
+                cloth_image,
+                cloth_type,
+                num_inference_steps,
+                guidance_scale,
+                seed,
+                show_type,
+            ],
+            result_image,
+        )
     demo.queue().launch(share=True, show_error=True)
 
+
 if __name__ == "__main__":
     app_gradio()