mikitona commited on
Commit
711c33e
·
verified ·
1 Parent(s): f77837a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -40
app.py CHANGED
@@ -122,7 +122,7 @@ pipe = TryonPipeline.from_pretrained(
122
  pipe.unet_encoder = UNet_Encoder
123
 
124
 
125
- @spaces.GPU(duration=120) # 実行時間を120秒に設定
126
  def start_tryon(
127
  dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, num_images
128
  ):
@@ -142,7 +142,7 @@ def start_tryon(
142
  left = (width - target_width) / 2
143
  top = (height - target_height) / 2
144
  right = (width + target_width) / 2
145
- bottom = (height + target_height) / 2
146
  cropped_img = human_img_orig.crop((left, top, right, bottom))
147
  crop_size = cropped_img.size
148
  human_img = cropped_img.resize((768, 1024))
@@ -178,7 +178,42 @@ def start_tryon(
178
  pose_img = pose_img[:, :, ::-1]
179
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
180
 
181
- output_images = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  for i in range(int(num_images)):
184
  current_seed = seed + i if seed is not None and seed != -1 else None
@@ -188,37 +223,6 @@ def start_tryon(
188
 
189
  with torch.no_grad():
190
  with torch.cuda.amp.autocast():
191
- prompt = "model is wearing " + garment_des
192
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
193
- (
194
- prompt_embeds,
195
- negative_prompt_embeds,
196
- pooled_prompt_embeds,
197
- negative_pooled_prompt_embeds,
198
- ) = pipe.encode_prompt(
199
- prompt,
200
- num_images_per_prompt=1,
201
- do_classifier_free_guidance=True,
202
- negative_prompt=negative_prompt,
203
- )
204
-
205
- prompt_c = "a photo of " + garment_des
206
- negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality"
207
- (
208
- prompt_embeds_c,
209
- _,
210
- _,
211
- _,
212
- ) = pipe.encode_prompt(
213
- prompt_c,
214
- num_images_per_prompt=1,
215
- do_classifier_free_guidance=False,
216
- negative_prompt=negative_prompt_c,
217
- )
218
-
219
- pose_img_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
220
- garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
221
-
222
  images = pipe(
223
  prompt_embeds=prompt_embeds.to(device, torch.float16),
224
  negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
@@ -244,11 +248,9 @@ def start_tryon(
244
  out_img = images[0].resize(crop_size)
245
  human_img_copy = human_img_orig.copy()
246
  human_img_copy.paste(out_img, (int(left), int(top)))
247
- output_images.append(human_img_copy)
248
  else:
249
- output_images.append(images[0])
250
-
251
- return output_images, mask_gray
252
 
253
 
254
  garm_list = os.listdir(os.path.join(example_path, "cloth"))
@@ -299,9 +301,11 @@ with image_blocks as demo:
299
  )
300
  example = gr.Examples(inputs=garm_img, examples_per_page=8, examples=garm_list_path)
301
  with gr.Column():
302
- masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
 
 
303
  with gr.Column():
304
- image_out = gr.Gallery(label="Output", elem_id="output-img", show_share_button=False)
305
 
306
  with gr.Column():
307
  try_button = gr.Button(value="Try-on")
 
122
  pipe.unet_encoder = UNet_Encoder
123
 
124
 
125
+ @spaces.GPU(duration=110) # 実行時間を110秒に設定
126
  def start_tryon(
127
  dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, num_images
128
  ):
 
142
  left = (width - target_width) / 2
143
  top = (height - target_height) / 2
144
  right = (width + target_width) / 2
145
+ bottom = (height + target_width) / 2
146
  cropped_img = human_img_orig.crop((left, top, right, bottom))
147
  crop_size = cropped_img.size
148
  human_img = cropped_img.resize((768, 1024))
 
178
  pose_img = pose_img[:, :, ::-1]
179
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
180
 
181
+ # テキストエンコーディングは一度だけ行う
182
+ with torch.no_grad():
183
+ with torch.cuda.amp.autocast():
184
+ prompt = "model is wearing " + garment_des
185
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
186
+ (
187
+ prompt_embeds,
188
+ negative_prompt_embeds,
189
+ pooled_prompt_embeds,
190
+ negative_pooled_prompt_embeds,
191
+ ) = pipe.encode_prompt(
192
+ prompt,
193
+ num_images_per_prompt=1,
194
+ do_classifier_free_guidance=True,
195
+ negative_prompt=negative_prompt,
196
+ )
197
+
198
+ prompt_c = "a photo of " + garment_des
199
+ negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality"
200
+ (
201
+ prompt_embeds_c,
202
+ _,
203
+ _,
204
+ _,
205
+ ) = pipe.encode_prompt(
206
+ prompt_c,
207
+ num_images_per_prompt=1,
208
+ do_classifier_free_guidance=False,
209
+ negative_prompt=negative_prompt_c,
210
+ )
211
+
212
+ pose_img_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
213
+ garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
214
+
215
+ # 最初にマスク画像を一度だけ出力
216
+ yield None, mask_gray
217
 
218
  for i in range(int(num_images)):
219
  current_seed = seed + i if seed is not None and seed != -1 else None
 
223
 
224
  with torch.no_grad():
225
  with torch.cuda.amp.autocast():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  images = pipe(
227
  prompt_embeds=prompt_embeds.to(device, torch.float16),
228
  negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
 
248
  out_img = images[0].resize(crop_size)
249
  human_img_copy = human_img_orig.copy()
250
  human_img_copy.paste(out_img, (int(left), int(top)))
251
+ yield human_img_copy, None
252
  else:
253
+ yield images[0], None
 
 
254
 
255
 
256
  garm_list = os.listdir(os.path.join(example_path, "cloth"))
 
301
  )
302
  example = gr.Examples(inputs=garm_img, examples_per_page=8, examples=garm_list_path)
303
  with gr.Column():
304
+ masked_img = gr.Image(
305
+ label="Masked image output", elem_id="masked-img", show_share_button=False
306
+ )
307
  with gr.Column():
308
+ image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
309
 
310
  with gr.Column():
311
  try_button = gr.Button(value="Try-on")