NikhilJoson commited on
Commit
b07cf8d
ยท
verified ยท
1 Parent(s): 6f118c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -51
app.py CHANGED
@@ -20,9 +20,6 @@ from preprocess.openpose.run_openpose import OpenPose
20
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
21
  from torchvision.transforms.functional import to_pil_image
22
 
23
- topwears = ["shirt", "t-shirt", "top", "blouse", "sweatshirt"]
24
- bottomwears = ["short", "shorts", "trousers", "leggings", "sweatshirt", "jeans", "skirts", "joggers", "pants", "dhoti", "lungi", "capris", "palazzos"]
25
-
26
 
27
  def pil_to_binary_mask(pil_image, threshold=0):
28
  np_image = np.array(pil_image)
@@ -38,7 +35,7 @@ def pil_to_binary_mask(pil_image, threshold=0):
38
  return output_mask
39
 
40
 
41
- base_path = './IDM-VTON'
42
  example_path = os.path.join(os.path.dirname(__file__), 'example')
43
 
44
  unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16,)
@@ -64,7 +61,12 @@ vae.requires_grad_(False)
64
  unet.requires_grad_(False)
65
  text_encoder_one.requires_grad_(False)
66
  text_encoder_two.requires_grad_(False)
67
- tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]),])
 
 
 
 
 
68
 
69
  pipe = TryonPipeline.from_pretrained(
70
  base_path,
@@ -82,7 +84,7 @@ pipe = TryonPipeline.from_pretrained(
82
  pipe.unet_encoder = UNet_Encoder
83
 
84
  @spaces.GPU
85
- def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
86
  device = "cuda"
87
 
88
  openpose_model.preprocessor.body_estimation.model.to(device)
@@ -110,15 +112,7 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
110
  if is_checked:
111
  keypoints = openpose_model(human_img.resize((384,512)))
112
  model_parse, _ = parsing_model(human_img.resize((384,512)))
113
-
114
- # using lambda functions to check if the description contains any -
115
- contains_word = lambda s, l: any(map(lambda x: x in s, l))
116
- # topwears
117
- if contains_word(desc,topwears):
118
- mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
119
- # bottomwears
120
- if contains_word(desc,bottomwears):
121
- mask, mask_gray = get_mask_location('hd', "lower_body", model_parse, keypoints)
122
  mask = mask.resize((768,1024))
123
  else:
124
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
@@ -165,17 +159,7 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
165
  if not isinstance(negative_prompt, List):
166
  negative_prompt = [negative_prompt] * 1
167
  with torch.inference_mode():
168
- (
169
- prompt_embeds_c,
170
- _,
171
- _,
172
- _,
173
- ) = pipe.encode_prompt(
174
- prompt,
175
- num_images_per_prompt=1,
176
- do_classifier_free_guidance=False,
177
- negative_prompt=negative_prompt,
178
- )
179
 
180
 
181
 
@@ -209,6 +193,19 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
209
  return images[0], mask_gray
210
  # return images[0], mask_gray
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  garm_list = os.listdir(os.path.join(example_path,"cloth"))
213
  garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
214
 
@@ -229,7 +226,7 @@ for ex_human in human_list_path:
229
  image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
230
  with image_blocks as demo:
231
  gr.HTML("<center><h1>Virtual Try-On</h1></center>")
232
- gr.HTML("<center><p>Upload an image of a person and an image of a garment โœจ</p></center>")
233
  with gr.Row():
234
  with gr.Column():
235
  imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
@@ -237,39 +234,33 @@ with image_blocks as demo:
237
  is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
238
  with gr.Row():
239
  is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
240
-
241
- example = gr.Examples(
242
- inputs=imgs,
243
- examples_per_page=10,
244
- examples=human_ex_list
245
- )
246
 
247
  with gr.Column():
248
- garm_img = gr.Image(label="Topwear", sources='upload', type="pil")
249
  with gr.Row(elem_id="prompt-container"):
250
  with gr.Row():
251
- prompt = gr.Textbox(placeholder="Description of topwear ex) Short Sleeve Black Round Neck T-shirts", show_label=False, elem_id="prompt")
252
- example = gr.Examples(
253
- inputs=garm_img,
254
- examples_per_page=8,
255
- examples=garm_list_path)
256
-
257
  with gr.Column():
258
- garm_img = gr.Image(label="Bottomwear", sources='upload', type="pil")
259
  with gr.Row(elem_id="prompt-container"):
260
  with gr.Row():
261
- prompt = gr.Textbox(placeholder="Description of bottomwear ex) Olive Cargo Pants", show_label=False, elem_id="prompt")
262
- example = gr.Examples(
263
- inputs=garm_img,
264
- examples_per_page=8,
265
- examples=garm_list_path)
266
-
267
  with gr.Column():
268
- # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
269
- masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
 
 
 
 
270
  with gr.Column():
271
  # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
272
  image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
 
 
 
273
 
274
 
275
 
@@ -281,8 +272,10 @@ with image_blocks as demo:
281
  seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
282
 
283
 
 
 
 
 
284
 
285
- try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
286
 
287
-
288
  image_blocks.launch()
 
20
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
21
  from torchvision.transforms.functional import to_pil_image
22
 
 
 
 
23
 
24
  def pil_to_binary_mask(pil_image, threshold=0):
25
  np_image = np.array(pil_image)
 
35
  return output_mask
36
 
37
 
38
+ base_path = 'yisol/IDM-VTON'
39
  example_path = os.path.join(os.path.dirname(__file__), 'example')
40
 
41
  unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16,)
 
61
  unet.requires_grad_(False)
62
  text_encoder_one.requires_grad_(False)
63
  text_encoder_two.requires_grad_(False)
64
+ tensor_transfrom = transforms.Compose(
65
+ [
66
+ transforms.ToTensor(),
67
+ transforms.Normalize([0.5], [0.5]),
68
+ ]
69
+ )
70
 
71
  pipe = TryonPipeline.from_pretrained(
72
  base_path,
 
84
  pipe.unet_encoder = UNet_Encoder
85
 
86
  @spaces.GPU
87
+ def start_tryon(dict,garm_img,garment_des,cloth_type,is_checked,is_checked_crop,denoise_steps,seed):
88
  device = "cuda"
89
 
90
  openpose_model.preprocessor.body_estimation.model.to(device)
 
112
  if is_checked:
113
  keypoints = openpose_model(human_img.resize((384,512)))
114
  model_parse, _ = parsing_model(human_img.resize((384,512)))
115
+ mask, mask_gray = get_mask_location('hd', cloth_type, model_parse, keypoints)
 
 
 
 
 
 
 
 
116
  mask = mask.resize((768,1024))
117
  else:
118
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
 
159
  if not isinstance(negative_prompt, List):
160
  negative_prompt = [negative_prompt] * 1
161
  with torch.inference_mode():
162
+ (prompt_embeds_c,_,_,_,) = pipe.encode_prompt(prompt,num_images_per_prompt=1,do_classifier_free_guidance=False,negative_prompt=negative_prompt,)
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
 
 
193
  return images[0], mask_gray
194
  # return images[0], mask_gray
195
 
196
+
197
+ def main_(imgs, topwear_img, topwear_des, bottomwear_img, bottomwear_des, dress_img, dress_des prompt, is_checked,is_checked_crop, denoise_steps, seed):
198
+ if dress_img!=None:
199
+ return start_tryon(imgs,dress_img,dress_des,"dresses",is_checked,is_checked_crop,denoise_steps,seed)
200
+ elif topwear_img!=None and bottomwear_img==None:
201
+ return start_tryon(imgs,topwear_img,topwear_des,"upper_body",is_checked,is_checked_crop,denoise_steps,seed)
202
+ elif topwear_img==None and bottomwear_img!=None:
203
+ return start_tryon(imgs,bottomwear_img,bottomwear_des,"lower_body",is_checked,is_checked_crop,denoise_steps,seed)
204
+ elif topwear_img!=None and bottomwear_img!=None:
205
+ half_img, half_mask = start_tryon(imgs,topwear_img,topwear_des,"upper_body",is_checked,is_checked_crop,denoise_steps,seed)
206
+ return start_tryon(imgs,half_img,bottomwear_des,"lower_body",is_checked,is_checked_crop,denoise_steps,seed)
207
+
208
+
209
  garm_list = os.listdir(os.path.join(example_path,"cloth"))
210
  garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
211
 
 
226
  image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
227
  with image_blocks as demo:
228
  gr.HTML("<center><h1>Virtual Try-On</h1></center>")
229
+ gr.HTML("<center><p>Upload an image of a person and images of the clothesโœจ</p></center>")
230
  with gr.Row():
231
  with gr.Column():
232
  imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
 
234
  is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
235
  with gr.Row():
236
  is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
237
+ example = gr.Examples(inputs=imgs, examples_per_page=10, examples=human_ex_list)
 
 
 
 
 
238
 
239
  with gr.Column():
240
+ topwear_image = gr.Image(label="Topwear", sources='upload', type="pil")
241
  with gr.Row(elem_id="prompt-container"):
242
  with gr.Row():
243
+ topwear_desc = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
244
+ example = gr.Examples(inputs=topwear_img, examples_per_page=8,examples=garm_list_path)
 
 
 
 
245
  with gr.Column():
246
+ bottomwear_image = gr.Image(label="Bottomwear", sources='upload', type="pil")
247
  with gr.Row(elem_id="prompt-container"):
248
  with gr.Row():
249
+ bottomwear_desc = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
250
+ example = gr.Examples(inputs=bottomwear_img, examples_per_page=8,examples=garm_list_path)
 
 
 
 
251
  with gr.Column():
252
+ dress_image = gr.Image(label="Dress", sources='upload', type="pil")
253
+ with gr.Row(elem_id="prompt-container"):
254
+ with gr.Row():
255
+ dress_desc = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
256
+ example = gr.Examples(inputs=dress_img, examples_per_page=8,examples=garm_list_path)
257
+
258
  with gr.Column():
259
  # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
260
  image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
261
+ with gr.Accordion("Debug Info", open=False):
262
+ masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
263
+
264
 
265
 
266
 
 
272
  seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
273
 
274
 
275
+ try_button.click(fn=main_, inputs=[imgs,topwear_image,topwear_desc,bottomwear_image,bottomwear_desc,dress_image,dress_desc,prompt,is_checked,is_checked_crop,denoise_steps,seed],
276
+ outputs=[image_out, masked_img], api_name='tryon')
277
+
278
+
279
 
 
280
 
 
281
  image_blocks.launch()