Spaces:
mikitona
/
Running on Zero

mikitona committed on
Commit
30456df
·
verified ·
1 Parent(s): 4f711d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -72
app.py CHANGED
@@ -178,81 +178,77 @@ def start_tryon(
178
  pose_img = pose_img[:, :, ::-1]
179
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
180
 
181
- with torch.no_grad():
182
- with torch.cuda.amp.autocast():
183
- prompt = "model is wearing " + garment_des
184
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
185
- (
186
- prompt_embeds,
187
- negative_prompt_embeds,
188
- pooled_prompt_embeds,
189
- negative_pooled_prompt_embeds,
190
- ) = pipe.encode_prompt(
191
- prompt,
192
- num_images_per_prompt=num_images,
193
- do_classifier_free_guidance=True,
194
- negative_prompt=negative_prompt,
195
- )
196
 
197
- prompt = "a photo of " + garment_des
198
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
199
- if not isinstance(prompt, List):
200
- prompt = [prompt] * num_images
201
- if not isinstance(negative_prompt, List):
202
- negative_prompt = [negative_prompt] * num_images
203
- (
204
- prompt_embeds_c,
205
- _,
206
- _,
207
- _,
208
- ) = pipe.encode_prompt(
209
- prompt,
210
- num_images_per_prompt=num_images,
211
- do_classifier_free_guidance=False,
212
- negative_prompt=negative_prompt,
213
- )
 
 
 
 
214
 
215
- pose_img_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
216
- garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
217
- pose_img_tensor = pose_img_tensor.repeat(num_images, 1, 1, 1)
218
- garm_tensor = garm_tensor.repeat(num_images, 1, 1, 1)
219
- human_imgs = [human_img] * num_images
220
- masks = [mask] * num_images
221
- ip_adapter_images = [garm_img.resize((768, 1024))] * num_images
222
-
223
- if seed is not None and seed != -1:
224
- generator = [torch.Generator(device).manual_seed(seed + i) for i in range(num_images)]
225
- else:
226
- generator = None
227
-
228
- images = pipe(
229
- prompt_embeds=prompt_embeds.to(device, torch.float16),
230
- negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
231
- pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
232
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
233
- num_inference_steps=denoise_steps,
234
- generator=generator,
235
- strength=1.0,
236
- pose_img=pose_img_tensor.to(device, torch.float16),
237
- text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
238
- cloth=garm_tensor.to(device, torch.float16),
239
- mask_image=masks,
240
- image=human_imgs,
241
- height=1024,
242
- width=768,
243
- ip_adapter_image=ip_adapter_images,
244
- guidance_scale=2.0,
245
- )[0]
246
 
247
- if is_checked_crop:
248
- output_images = []
249
- for img in images:
250
- out_img = img.resize(crop_size)
251
- human_img_orig.paste(out_img, (int(left), int(top)))
252
- output_images.append(human_img_orig.copy())
253
- return output_images, mask_gray
254
- else:
255
- return images, mask_gray
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
 
258
  garm_list = os.listdir(os.path.join(example_path, "cloth"))
 
178
  pose_img = pose_img[:, :, ::-1]
179
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
180
 
181
+ output_images = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
+ for i in range(int(num_images)):
184
+ current_seed = seed + i if seed is not None and seed != -1 else None
185
+ generator = (
186
+ torch.Generator(device).manual_seed(int(current_seed)) if current_seed is not None else None
187
+ )
188
+
189
+ with torch.no_grad():
190
+ with torch.cuda.amp.autocast():
191
+ prompt = "model is wearing " + garment_des
192
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
193
+ (
194
+ prompt_embeds,
195
+ negative_prompt_embeds,
196
+ pooled_prompt_embeds,
197
+ negative_pooled_prompt_embeds,
198
+ ) = pipe.encode_prompt(
199
+ prompt,
200
+ num_images_per_prompt=1,
201
+ do_classifier_free_guidance=True,
202
+ negative_prompt=negative_prompt,
203
+ )
204
 
205
+ prompt_c = "a photo of " + garment_des
206
+ negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality"
207
+ (
208
+ prompt_embeds_c,
209
+ _,
210
+ _,
211
+ _,
212
+ ) = pipe.encode_prompt(
213
+ prompt_c,
214
+ num_images_per_prompt=1,
215
+ do_classifier_free_guidance=False,
216
+ negative_prompt=negative_prompt_c,
217
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
+ pose_img_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
220
+ garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
221
+
222
+ images = pipe(
223
+ prompt_embeds=prompt_embeds.to(device, torch.float16),
224
+ negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
225
+ pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
226
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(
227
+ device, torch.float16
228
+ ),
229
+ num_inference_steps=denoise_steps,
230
+ generator=generator,
231
+ strength=1.0,
232
+ pose_img=pose_img_tensor.to(device, torch.float16),
233
+ text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
234
+ cloth=garm_tensor.to(device, torch.float16),
235
+ mask_image=mask,
236
+ image=human_img,
237
+ height=1024,
238
+ width=768,
239
+ ip_adapter_image=garm_img.resize((768, 1024)),
240
+ guidance_scale=2.0,
241
+ )[0]
242
+
243
+ if is_checked_crop:
244
+ out_img = images[0].resize(crop_size)
245
+ human_img_copy = human_img_orig.copy()
246
+ human_img_copy.paste(out_img, (int(left), int(top)))
247
+ output_images.append(human_img_copy)
248
+ else:
249
+ output_images.append(images[0])
250
+
251
+ return output_images, mask_gray
252
 
253
 
254
  garm_list = os.listdir(os.path.join(example_path, "cloth"))