NikhilJoson committed on
Commit 6f118c9 · verified
1 Parent(s): ba6ca56

Update app.py

Files changed (1)
  1. app.py +37 -61
app.py CHANGED
@@ -4,12 +4,7 @@ from PIL import Image
 from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
 from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
 from src.unet_hacked_tryon import UNet2DConditionModel
-from transformers import (
-    CLIPImageProcessor,
-    CLIPVisionModelWithProjection,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
-)
+from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPTextModel, CLIPTextModelWithProjection,)
 from diffusers import DDPMScheduler,AutoencoderKL
 from typing import List
 
@@ -25,6 +20,9 @@ from preprocess.openpose.run_openpose import OpenPose
 from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
 from torchvision.transforms.functional import to_pil_image
 
+topwears = ["shirt", "t-shirt", "top", "blouse", "sweatshirt"]
+bottomwears = ["short", "shorts", "trousers", "leggings", "sweatshirt", "jeans", "skirts", "joggers", "pants", "dhoti", "lungi", "capris", "palazzos"]
+
 
 def pil_to_binary_mask(pil_image, threshold=0):
     np_image = np.array(pil_image)
@@ -40,55 +38,22 @@ def pil_to_binary_mask(pil_image, threshold=0):
     return output_mask
 
 
-base_path = 'yisol/IDM-VTON'
+base_path = './IDM-VTON'
 example_path = os.path.join(os.path.dirname(__file__), 'example')
 
-unet = UNet2DConditionModel.from_pretrained(
-    base_path,
-    subfolder="unet",
-    torch_dtype=torch.float16,
-)
+unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16,)
 unet.requires_grad_(False)
-tokenizer_one = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer",
-    revision=None,
-    use_fast=False,
-)
-tokenizer_two = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer_2",
-    revision=None,
-    use_fast=False,
-)
+tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", revision=None, use_fast=False,)
+tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", revision=None, use_fast=False,)
 noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
 
-text_encoder_one = CLIPTextModel.from_pretrained(
-    base_path,
-    subfolder="text_encoder",
-    torch_dtype=torch.float16,
-)
-text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="text_encoder_2",
-    torch_dtype=torch.float16,
-)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="image_encoder",
-    torch_dtype=torch.float16,
-)
-vae = AutoencoderKL.from_pretrained(base_path,
-    subfolder="vae",
-    torch_dtype=torch.float16,
-)
+text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16,)
+text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16,)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16,)
+vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16,)
 
 # "stabilityai/stable-diffusion-xl-base-1.0",
-UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
-    base_path,
-    subfolder="unet_encoder",
-    torch_dtype=torch.float16,
-)
+UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16,)
 
 parsing_model = Parsing(0)
 openpose_model = OpenPose(0)
@@ -99,12 +64,7 @@ vae.requires_grad_(False)
 unet.requires_grad_(False)
 text_encoder_one.requires_grad_(False)
 text_encoder_two.requires_grad_(False)
-tensor_transfrom = transforms.Compose(
-    [
-        transforms.ToTensor(),
-        transforms.Normalize([0.5], [0.5]),
-    ]
-)
+tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]),])
 
 pipe = TryonPipeline.from_pretrained(
     base_path,
@@ -150,7 +110,15 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
     if is_checked:
         keypoints = openpose_model(human_img.resize((384,512)))
         model_parse, _ = parsing_model(human_img.resize((384,512)))
-        mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+
+        # using lambda functions to check if the description contains any -
+        contains_word = lambda s, l: any(map(lambda x: x in s, l))
+        # topwears
+        if contains_word(desc,topwears):
+            mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+        # bottomwears
+        if contains_word(desc,bottomwears):
+            mask, mask_gray = get_mask_location('hd', "lower_body", model_parse, keypoints)
         mask = mask.resize((768,1024))
     else:
         mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
@@ -277,14 +245,25 @@ with image_blocks as demo:
             )
 
         with gr.Column():
-            garm_img = gr.Image(label="Garment", sources='upload', type="pil")
+            garm_img = gr.Image(label="Topwear", sources='upload', type="pil")
             with gr.Row(elem_id="prompt-container"):
                 with gr.Row():
-                    prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
+                    prompt = gr.Textbox(placeholder="Description of topwear ex) Short Sleeve Black Round Neck T-shirts", show_label=False, elem_id="prompt")
             example = gr.Examples(
                 inputs=garm_img,
                 examples_per_page=8,
                 examples=garm_list_path)
+
+        with gr.Column():
+            garm_img = gr.Image(label="Bottomwear", sources='upload', type="pil")
+            with gr.Row(elem_id="prompt-container"):
+                with gr.Row():
+                    prompt = gr.Textbox(placeholder="Description of bottomwear ex) Olive Cargo Pants", show_label=False, elem_id="prompt")
+            example = gr.Examples(
+                inputs=garm_img,
+                examples_per_page=8,
+                examples=garm_list_path)
+
         with gr.Column():
             # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
             masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
@@ -294,7 +273,6 @@ with image_blocks as demo:
 
 
 
-
    with gr.Column():
        try_button = gr.Button(value="Try-on")
        with gr.Accordion(label="Advanced Settings", open=False):
@@ -306,7 +284,5 @@
 
    try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
 
-
-
-
+
 image_blocks.launch()