Saad0KH committed
Commit b56c80f · verified · 1 Parent(s): b92eece

Update app.py

Files changed (1): app.py +135 -288
app.py CHANGED
@@ -1,145 +1,93 @@
  import os
- from flask import Flask, request, jsonify,send_file
- from PIL import Image
- from io import BytesIO
- import torch
  import base64
- import io
  import logging
- import gradio as gr
- import numpy as np
- import spaces
  import uuid
- import random
- from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
- from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
- from src.unet_hacked_tryon import UNet2DConditionModel
  from transformers import (
      CLIPImageProcessor,
      CLIPVisionModelWithProjection,
      CLIPTextModel,
      CLIPTextModelWithProjection,
-     AutoTokenizer,
  )
  from diffusers import DDPMScheduler, AutoencoderKL
- from utils_mask import get_mask_location
- from torchvision import transforms
- import apply_net
  from preprocess.humanparsing.run_parsing import Parsing
  from preprocess.openpose.run_openpose import OpenPose
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
- from torchvision.transforms.functional import to_pil_image

  app = Flask(__name__)

  base_path = 'yisol/IDM-VTON'
  example_path = os.path.join(os.path.dirname(__file__), 'example')

- unet = UNet2DConditionModel.from_pretrained(
-     base_path,
-     subfolder="unet",
-     torch_dtype=torch.float16,
-     force_download=False
- )
- unet.requires_grad_(False)
- tokenizer_one = AutoTokenizer.from_pretrained(
-     base_path,
-     subfolder="tokenizer",
-     revision=None,
-     use_fast=False,
-     force_download=False
- )
- tokenizer_two = AutoTokenizer.from_pretrained(
-     base_path,
-     subfolder="tokenizer_2",
-     revision=None,
-     use_fast=False,
-     force_download=False
- )
  noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

- text_encoder_one = CLIPTextModel.from_pretrained(
-     base_path,
-     subfolder="text_encoder",
-     torch_dtype=torch.float16,
-     force_download=False
- )
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-     base_path,
-     subfolder="text_encoder_2",
-     torch_dtype=torch.float16,
-     force_download=False
- )
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-     base_path,
-     subfolder="image_encoder",
-     torch_dtype=torch.float16,
-     force_download=False
- )
- vae = AutoencoderKL.from_pretrained(base_path,
-     subfolder="vae",
-     torch_dtype=torch.float16,
-     force_download=False
- )

- UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
-     base_path,
-     subfolder="unet_encoder",
-     torch_dtype=torch.float16,
-     force_download=False
- )

  parsing_model = Parsing(0)
  openpose_model = OpenPose(0)

- UNet_Encoder.requires_grad_(False)
- image_encoder.requires_grad_(False)
- vae.requires_grad_(False)
- unet.requires_grad_(False)
- text_encoder_one.requires_grad_(False)
- text_encoder_two.requires_grad_(False)
- tensor_transfrom = transforms.Compose(
-     [
-         transforms.ToTensor(),
-         transforms.Normalize([0.5], [0.5]),
-     ]
- )

  pipe = TryonPipeline.from_pretrained(
-     base_path,
-     unet=unet,
-     vae=vae,
-     feature_extractor= CLIPImageProcessor(),
-     text_encoder = text_encoder_one,
-     text_encoder_2 = text_encoder_two,
-     tokenizer = tokenizer_one,
-     tokenizer_2 = tokenizer_two,
-     scheduler = noise_scheduler,
-     image_encoder=image_encoder,
-     torch_dtype=torch.float16,
-     force_download=False
  )
  pipe.unet_encoder = UNet_Encoder

  def pil_to_binary_mask(pil_image, threshold=0):
-     np_image = np.array(pil_image)
-     grayscale_image = Image.fromarray(np_image).convert("L")
-     binary_mask = np.array(grayscale_image) > threshold
-     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
-     for i in range(binary_mask.shape[0]):
-         for j in range(binary_mask.shape[1]):
-             if binary_mask[i, j]:
-                 mask[i, j] = 1
-     mask = (mask * 255).astype(np.uint8)
-     output_mask = Image.fromarray(mask)
-     return output_mask

  def get_image_from_url(url):
      try:
          response = requests.get(url)
-         response.raise_for_status()  # Check for HTTP errors
-         img = Image.open(BytesIO(response.content))
-         return img
      except Exception as e:
          logging.error(f"Error fetching image from URL: {e}")
          raise
@@ -147,8 +95,7 @@ def get_image_from_url(url):
  def decode_image_from_base64(base64_str):
      try:
          img_data = base64.b64decode(base64_str)
-         img = Image.open(BytesIO(img_data))
-         return img
      except Exception as e:
          logging.error(f"Error decoding image: {e}")
          raise
@@ -157,36 +104,33 @@ def encode_image_to_base64(img):
      try:
          buffered = BytesIO()
          img.save(buffered, format="PNG")
-         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-         return img_str
      except Exception as e:
          logging.error(f"Error encoding image: {e}")
          raise

  def save_image(img):
-     unique_name = str(uuid.uuid4()) + ".webp"
-     img.save(unique_name, format="WEBP", lossless=True)
      return unique_name

  @spaces.GPU
- def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie = 'upper_body'):
      device = "cuda"
      openpose_model.preprocessor.body_estimation.model.to(device)
      pipe.to(device)
      pipe.unet_encoder.to(device)

      garm_img = garm_img.convert("RGB").resize((768, 1024))
-     human_img_orig = dict["background"].convert("RGB")

      if is_checked_crop:
          width, height = human_img_orig.size
-         target_width = int(min(width, height * (3 / 4)))
-         target_height = int(min(height, width * (4 / 3)))
          left = (width - target_width) / 2
          top = (height - target_height) / 2
-         right = (width + target_width) / 2
-         bottom = (height + target_height) / 2
-         cropped_img = human_img_orig.crop((left, top, right, bottom))
          crop_size = cropped_img.size
          human_img = cropped_img.resize((768, 1024))
      else:
@@ -195,78 +139,51 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
      if is_checked:
          keypoints = openpose_model(human_img.resize((384, 512)))
          model_parse, _ = parsing_model(human_img.resize((384, 512)))
-         mask, mask_gray = get_mask_location('hd', categorie , model_parse, keypoints)
          mask = mask.resize((768, 1024))
      else:
-         mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
      mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
      mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

      human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
      human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

-     args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
-     pose_img = args.func(args, human_img_arg)
-     pose_img = pose_img[:, :, ::-1]
      pose_img = Image.fromarray(pose_img).resize((768, 1024))

-     with torch.no_grad():
-         with torch.cuda.amp.autocast():
-             prompt = "model is wearing " + garment_des
-             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-             with torch.inference_mode():
-                 (
-                     prompt_embeds,
-                     negative_prompt_embeds,
-                     pooled_prompt_embeds,
-                     negative_pooled_prompt_embeds,
-                 ) = pipe.encode_prompt(
-                     prompt,
-                     num_images_per_prompt=1,
-                     do_classifier_free_guidance=True,
-                     negative_prompt=negative_prompt,
-                 )
-
-                 prompt = "a photo of " + garment_des
-                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-                 if not isinstance(prompt, list):
-                     prompt = [prompt] * 1
-                 if not isinstance(negative_prompt, list):
-                     negative_prompt = [negative_prompt] * 1
-                 with torch.inference_mode():
-                     (
-                         prompt_embeds_c,
-                         _,
-                         _,
-                         _,
-                     ) = pipe.encode_prompt(
-                         prompt,
-                         num_images_per_prompt=1,
-                         do_classifier_free_guidance=False,
-                         negative_prompt=negative_prompt,
-                     )
-
-                 pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
-                 garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
-                 generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-                 images = pipe(
-                     prompt_embeds=prompt_embeds.to(device, torch.float16),
-                     negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
-                     pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
-                     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
-                     num_inference_steps=denoise_steps,
-                     generator=generator,
-                     strength=1.0,
-                     pose_img=pose_img.to(device, torch.float16),
-                     text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
-                     cloth=garm_tensor.to(device, torch.float16),
-                     mask_image=mask,
-                     image=human_img,
-                     height=1024,
-                     width=768,
-                     ip_adapter_image=garm_img.resize((768, 1024)),
-                     guidance_scale=2.0,
-                 )[0]

      if is_checked_crop:
          out_img = images[0].resize(crop_size)
@@ -275,126 +192,56 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
      else:
          return images[0], mask_gray

-
  def clear_gpu_memory():
      torch.cuda.empty_cache()
      torch.cuda.synchronize()

  def process_image(image_data):
-     # Check whether the image is base64 or a URL
-     if image_data.startswith('http://') or image_data.startswith('https://'):
-         return get_image_from_url(image_data)  # Download the image from the URL
-     else:
-         return decode_image_from_base64(image_data)  # Decode the base64 image

  @app.route('/tryon', methods=['POST'])
  def tryon():
      data = request.json
-     human_image = process_image(data['human_image'])
-     garment_image = process_image(data['garment_image'])
-     description = data.get('description')
-     use_auto_mask = data.get('use_auto_mask', True)
-     use_auto_crop = data.get('use_auto_crop', False)
-     denoise_steps = int(data.get('denoise_steps', 30))
-     seed = int(data.get('seed', 42))
-     categorie = data.get('categorie' , 'upper_body')
-     human_dict = {
-         'background': human_image,
-         'layers': [human_image] if not use_auto_mask else None,
-         'composite': None
-     }
-     #clear_gpu_memory()
-
-     output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed , categorie)
-
-     output_base64 = encode_image_to_base64(output_image)
-     mask_base64 = encode_image_to_base64(mask_image)
-
-     return jsonify({
-         'output_image': output_base64,
-         'mask_image': mask_base64
-     })
-
- @app.route('/tryon-v2', methods=['POST'])
- def tryon_v2():
-
-     data = request.json
-     human_image_data = data['human_image']
-     garment_image_data = data['garment_image']
-
-     # Process images (base64 or URL)
-     human_image = process_image(human_image_data)
-     garment_image = process_image(garment_image_data)
-
-     description = data.get('description')
-     use_auto_mask = data.get('use_auto_mask', True)
-     use_auto_crop = data.get('use_auto_crop', False)
-     denoise_steps = int(data.get('denoise_steps', 30))
-     seed = int(data.get('seed', random.randint(0, 9999999)))
-     categorie = data.get('categorie', 'upper_body')
-
-     # Check whether 'mask_image' is present in the payload
-     mask_image = None
-     if 'mask_image' in data:
-         mask_image_data = data['mask_image']
-         mask_image = process_image(mask_image_data)
-
-     human_dict = {
-         'background': human_image,
-         'layers': [mask_image] if not use_auto_mask else None,
-         'composite': None
-     }
-     output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed , categorie)
-     return jsonify({
-         'image_id': save_image(output_image)
-     })
-
- @spaces.GPU
- def generate_mask(human_img, categorie='upper_body'):
-     device = "cuda"
-     openpose_model.preprocessor.body_estimation.model.to(device)
-     pipe.to(device)
-
      try:
-         # Resize the image for the model
-         human_img_resized = human_img.convert("RGB").resize((384, 512))
-
-         # Generate keypoints and the mask
-         keypoints = openpose_model(human_img_resized)
-         model_parse, _ = parsing_model(human_img_resized)
-         mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
-
-         # Resize the mask back to the original image size
-         mask_resized = mask.resize(human_img.size)
-
-         return mask_resized
-     except Exception as e:
-         logging.error(f"Error generating mask: {e}")
-         raise e

-
- @app.route('/generate_mask', methods=['POST'])
- def generate_mask_api():
-     try:
-         # Get the image data from the request
-         data = request.json
-         base64_image = data.get('human_image')
-         categorie = data.get('categorie', 'upper_body')
-
-         # Decode the image from base64
-         human_img = process_image(base64_image)
-
-         # Call the mask-generation function
-         mask_resized = generate_mask(human_img, categorie)
-
-         # Encode the mask as base64 for the response
-         mask_base64 = encode_image_to_base64(mask_resized)
-
          return jsonify({
-             'mask_image': mask_base64
-         }), 200
      except Exception as e:
-         logging.error(f"Error generating mask: {e}")
          return jsonify({'error': str(e)}), 500

  # Route to retrieve the generated image
@@ -410,4 +257,4 @@ def get_image(image_id):
      return jsonify({'error': 'Image not found'}), 404

  if __name__ == "__main__":
-     app.run(debug=False, host="0.0.0.0", port=7860)
 
  import os
  import base64
  import logging
  import uuid
+ import requests
+ import torch
+ import numpy as np
+ import spaces
+ from flask import Flask, request, jsonify, send_file
+ from PIL import Image
+ from io import BytesIO
+ from torchvision import transforms
+ from torchvision.transforms.functional import to_pil_image
  from transformers import (
      CLIPImageProcessor,
      CLIPVisionModelWithProjection,
      CLIPTextModel,
      CLIPTextModelWithProjection,
+     AutoTokenizer
  )
  from diffusers import DDPMScheduler, AutoencoderKL
  from preprocess.humanparsing.run_parsing import Parsing
  from preprocess.openpose.run_openpose import OpenPose
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+ from src.unet_hacked_tryon import UNet2DConditionModel
+ from utils_mask import get_mask_location
+ import apply_net

  app = Flask(__name__)

  base_path = 'yisol/IDM-VTON'
  example_path = os.path.join(os.path.dirname(__file__), 'example')

+ # Load models
+ def load_model(cls, subfolder, dtype=torch.float16):
+     # torch.load cannot resolve a Hub repo id and takes no dtype argument;
+     # from_pretrained downloads the checkpoint and casts it in one step.
+     return cls.from_pretrained(base_path, subfolder=subfolder, torch_dtype=dtype)
+
+ unet = load_model(UNet2DConditionModel, "unet")
+ tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False)
+ tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False)
  noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

+ text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
+ vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)

+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)

  parsing_model = Parsing(0)
  openpose_model = OpenPose(0)

+ # Disable gradient computation
+ for model in [unet, UNet_Encoder, image_encoder, vae, text_encoder_one, text_encoder_two]:
+     model.requires_grad_(False)
+
+ tensor_transfrom = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize([0.5], [0.5]),
+ ])

  pipe = TryonPipeline.from_pretrained(
+     base_path,
+     unet=unet,
+     vae=vae,
+     feature_extractor=CLIPImageProcessor(),
+     text_encoder=text_encoder_one,
+     text_encoder_2=text_encoder_two,
+     tokenizer=tokenizer_one,
+     tokenizer_2=tokenizer_two,
+     scheduler=noise_scheduler,
+     image_encoder=image_encoder,
+     torch_dtype=torch.float16
  )
  pipe.unet_encoder = UNet_Encoder

  def pil_to_binary_mask(pil_image, threshold=0):
+     np_image = np.array(pil_image.convert("L"))
+     binary_mask = np_image > threshold
+     mask = (binary_mask * 255).astype(np.uint8)
+     return Image.fromarray(mask)

  def get_image_from_url(url):
      try:
          response = requests.get(url)
+         response.raise_for_status()
+         return Image.open(BytesIO(response.content))
      except Exception as e:
          logging.error(f"Error fetching image from URL: {e}")
          raise

  def decode_image_from_base64(base64_str):
      try:
          img_data = base64.b64decode(base64_str)
+         return Image.open(BytesIO(img_data))
      except Exception as e:
          logging.error(f"Error decoding image: {e}")
          raise

  def encode_image_to_base64(img):
      try:
          buffered = BytesIO()
          img.save(buffered, format="PNG")
+         return base64.b64encode(buffered.getvalue()).decode("utf-8")
      except Exception as e:
          logging.error(f"Error encoding image: {e}")
          raise

  def save_image(img):
+     unique_name = f"{uuid.uuid4()}.webp"
+     img.save(unique_name, format="WEBP", lossless=True)
      return unique_name

  @spaces.GPU
+ def start_tryon(human_dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
      device = "cuda"
      openpose_model.preprocessor.body_estimation.model.to(device)
      pipe.to(device)
      pipe.unet_encoder.to(device)

      garm_img = garm_img.convert("RGB").resize((768, 1024))
+     human_img_orig = human_dict["background"].convert("RGB")

      if is_checked_crop:
          width, height = human_img_orig.size
+         target_width = min(width, height * (3 / 4))
+         target_height = min(height, width * (4 / 3))
          left = (width - target_width) / 2
          top = (height - target_height) / 2
+         cropped_img = human_img_orig.crop((left, top, width - left, height - top))
          crop_size = cropped_img.size
          human_img = cropped_img.resize((768, 1024))
      else:
          human_img = human_img_orig.resize((768, 1024))

      if is_checked:
          keypoints = openpose_model(human_img.resize((384, 512)))
          model_parse, _ = parsing_model(human_img.resize((384, 512)))
+         mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
          mask = mask.resize((768, 1024))
      else:
+         mask = pil_to_binary_mask(human_dict['layers'][0].convert("RGB").resize((768, 1024)))
+
      mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
      mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

      human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
      human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

+     args = apply_net.create_argument_parser().parse_args(
+         ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda')
+     )
+     pose_img = args.func(args, human_img_arg)[:, :, ::-1]
      pose_img = Image.fromarray(pose_img).resize((768, 1024))

+     with torch.no_grad(), torch.cuda.amp.autocast():
+         prompt = f"model is wearing {garment_des}"
+         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+         # encode_prompt returns a 4-tuple of embeddings; unpack it instead of
+         # indexing into a single return value
+         (
+             prompt_embeds,
+             negative_prompt_embeds,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = pipe.encode_prompt(prompt, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
+         prompt = f"a photo of {garment_des}"
+         prompt_embeds_c, _, _, _ = pipe.encode_prompt(prompt, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=negative_prompt)
+
+         pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+         garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+         generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+         images = pipe(
+             prompt_embeds=prompt_embeds.to(device, torch.float16),
+             negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+             pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+             num_inference_steps=denoise_steps,
+             generator=generator,
+             strength=1.0,
+             pose_img=pose_img,
+             text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
+             cloth=garm_tensor,
+             mask_image=mask,
+             image=human_img,
+             height=1024,
+             width=768,
+             ip_adapter_image=garm_img.resize((768, 1024)),
+             guidance_scale=2.0
+         )[0]

      if is_checked_crop:
          out_img = images[0].resize(crop_size)

      else:
          return images[0], mask_gray

  def clear_gpu_memory():
      torch.cuda.empty_cache()
      torch.cuda.synchronize()

  def process_image(image_data):
+     if image_data.startswith(('http://', 'https://')):
+         return get_image_from_url(image_data)
+     return decode_image_from_base64(image_data)

  @app.route('/tryon', methods=['POST'])
  def tryon():
      data = request.json
      try:
+         human_image_data = process_image(data['human_image'])
+         garment_image_data = process_image(data['garment_image'])
+         category = data.get('category', 'upper_body')
+         description = data.get('description', '')
+         checked = data.get('checked', False)
+         checked_crop = data.get('checked_crop', False)
+         denoise_steps = data.get('denoise_steps', 50)
+         seed = data.get('seed', None)
+
+         human_dict = {
+             "background": human_image_data,
+             "layers": [human_image_data],
+         }
+
+         result_img, mask_img = start_tryon(
+             human_dict,
+             garment_image_data,
+             description,
+             checked,
+             checked_crop,
+             denoise_steps,
+             seed,
+             category
+         )
+
+         encoded_image = encode_image_to_base64(result_img)
+         encoded_mask = encode_image_to_base64(mask_img)
+
+         #clear_gpu_memory()

          return jsonify({
+             'result_image': encoded_image,
+             'mask_image': encoded_mask,
+         })
+
      except Exception as e:
+         logging.error(f"Error in /tryon endpoint: {e}")
          return jsonify({'error': str(e)}), 500

  # Route to retrieve the generated image

      return jsonify({'error': 'Image not found'}), 404

  if __name__ == "__main__":
+     app.run(debug=False, host="0.0.0.0", port=7860)
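
For reference, here is a minimal client sketch for the /tryon endpoint as it stands after this commit. It is an illustration, not part of the commit: it assumes the Space is reachable at http://localhost:7860, the two image URLs are placeholders, and the request and response field names follow the handler above.

# Hypothetical client for the /tryon endpoint defined in app.py above.
# Assumes the server runs locally on port 7860; the image URLs are placeholders.
import base64
import requests

payload = {
    "human_image": "https://example.com/person.jpg",     # URL or base64 string
    "garment_image": "https://example.com/garment.jpg",  # URL or base64 string
    "description": "a red t-shirt",
    "category": "upper_body",
    "checked": True,         # auto-generate the mask via OpenPose + parsing
    "checked_crop": False,   # keep the original framing (no 3:4 auto-crop)
    "denoise_steps": 30,
    "seed": 42,
}

resp = requests.post("http://localhost:7860/tryon", json=payload, timeout=600)
resp.raise_for_status()
body = resp.json()

# The handler returns the try-on result and the mask as base64-encoded PNGs.
with open("result.png", "wb") as f:
    f.write(base64.b64decode(body["result_image"]))
with open("mask.png", "wb") as f:
    f.write(base64.b64decode(body["mask_image"]))

On failure the route responds with HTTP 500 and a JSON body containing an 'error' key, which raise_for_status surfaces as an exception on the client side.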