Saad0KH committed on
Commit b92eece · verified · 1 Parent(s): c9a29b0

Update app.py

Files changed (1):
  1. app.py +331 -173
app.py CHANGED
@@ -1,21 +1,27 @@
  import os
- from flask import Flask, request, jsonify, send_file
  from PIL import Image
  from io import BytesIO
- import base64
  import torch
- import requests
  import numpy as np
- import uuid
  import spaces
  from transformers import (
      CLIPImageProcessor,
      CLIPVisionModelWithProjection,
      CLIPTextModel,
      CLIPTextModelWithProjection,
-     AutoTokenizer
  )
- from diffusers import DDPMScheduler, AutoencoderKL, UNet2DConditionModel
  from utils_mask import get_mask_location
  from torchvision import transforms
  import apply_net
@@ -26,78 +32,114 @@ from torchvision.transforms.functional import to_pil_image

  app = Flask(__name__)

- # Global variables to store the models
- models_loaded = False

- def load_models():
-     global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two
-     global image_encoder, vae, UNet_Encoder, parsing_model, openpose_model, pipe
-     global models_loaded
-
-     if not models_loaded:
-         base_path = 'yisol/IDM-VTON'
-         unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
-         unet.requires_grad_(False)
-
-         tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
-         tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
-
-         noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-         text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
-         text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
-         image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
-         vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
-
-         # Set the correct encoder_hid_dim_type here
-         UNet_Encoder = UNet2DConditionModel.from_pretrained(
-             base_path,
-             subfolder="unet_encoder",
-             torch_dtype=torch.float16,
-             encoder_hid_dim_type="text_proj",  # Update based on model type
-             force_download=False
-         )
-
-         parsing_model = Parsing(0)
-         openpose_model = OpenPose(0)
-
-         UNet_Encoder.requires_grad_(False)
-         image_encoder.requires_grad_(False)
-         vae.requires_grad_(False)
-         unet.requires_grad_(False)
-         text_encoder_one.requires_grad_(False)
-         text_encoder_two.requires_grad_(False)
-
-         tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
-
-         pipe = TryonPipeline.from_pretrained(
-             base_path,
-             unet=unet,
-             vae=vae,
-             feature_extractor=CLIPImageProcessor(),
-             text_encoder=text_encoder_one,
-             text_encoder_2=text_encoder_two,
-             tokenizer=tokenizer_one,
-             tokenizer_2=tokenizer_two,
-             scheduler=noise_scheduler,
-             image_encoder=image_encoder,
-             torch_dtype=torch.float16,
-             force_download=False
-         )
-         pipe.unet_encoder = UNet_Encoder
-
-         models_loaded = True

  def pil_to_binary_mask(pil_image, threshold=0):
-     np_image = np.array(pil_image.convert("L"))  # Convert to grayscale directly
-     binary_mask = np_image > threshold
-     mask = np.uint8(binary_mask * 255)
-     return Image.fromarray(mask)

  def get_image_from_url(url):
      try:
          response = requests.get(url)
-         response.raise_for_status()
-         return Image.open(BytesIO(response.content))
      except Exception as e:
          logging.error(f"Error fetching image from URL: {e}")
          raise
@@ -105,7 +147,8 @@ def get_image_from_url(url):
  def decode_image_from_base64(base64_str):
      try:
          img_data = base64.b64decode(base64_str)
-         return Image.open(BytesIO(img_data))
      except Exception as e:
          logging.error(f"Error decoding image: {e}")
          raise
@@ -114,142 +157,257 @@ def encode_image_to_base64(img):
      try:
          buffered = BytesIO()
          img.save(buffered, format="PNG")
-         return base64.b64encode(buffered.getvalue()).decode("utf-8")
      except Exception as e:
          logging.error(f"Error encoding image: {e}")
          raise

  def save_image(img):
-     unique_name = f"{uuid.uuid4()}.webp"
-     img.save(unique_name, format="WEBP", lossless=True)
      return unique_name

- def clear_gpu_memory():
-     torch.cuda.empty_cache()
-     torch.cuda.ipc_collect()
-
  @spaces.GPU
- def start_tryon(human_dict, garment_image, garment_description, use_auto_mask, use_auto_crop, denoise_steps, seed, category='upper_body'):
      device = "cuda"
      openpose_model.preprocessor.body_estimation.model.to(device)
      pipe.to(device)
      pipe.unet_encoder.to(device)

-     garment_image = garment_image.convert("RGB").resize((768, 1024))
-     human_image_orig = human_dict["background"].convert("RGB")

-     if use_auto_crop:
-         width, height = human_image_orig.size
          target_width = int(min(width, height * (3 / 4)))
          target_height = int(min(height, width * (4 / 3)))
-         left, top = (width - target_width) / 2, (height - target_height) / 2
-         right, bottom = (width + target_width) / 2, (height + target_height) / 2
-         cropped_img = human_image_orig.crop((left, top, right, bottom)).resize((768, 1024))
      else:
-         cropped_img = human_image_orig.resize((768, 1024))

-     if use_auto_mask:
-         keypoints = openpose_model(cropped_img.resize((384, 512)))
-         model_parse, _ = parsing_model(cropped_img.resize((384, 512)))
-         mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints)
          mask = mask.resize((768, 1024))
      else:
-         mask = pil_to_binary_mask(human_dict['layers'][0].convert("RGB").resize((768, 1024)))
-
-     mask_gray = (1 - transforms.ToTensor()(mask)) * transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(cropped_img)
      mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

-     human_image_arg = _apply_exif_orientation(cropped_img.resize((384, 512)))
-     human_image_arg = convert_PIL_to_numpy(human_image_arg, format="BGR")
-
-     args = apply_net.create_argument_parser().parse_args(
-         ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
-     pose_image = args.func(args, human_image_arg)
-     pose_image = Image.fromarray(pose_image[:, :, ::-1]).resize((768, 1024))
-
-     with torch.no_grad(), torch.cuda.amp.autocast():
-         prompt = "model is wearing " + garment_description
-         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-         prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
-             prompt, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt
-         )
-
-         prompt_c = "a photo of " + garment_description
-         negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality"
-         prompt_embeds_c, _, _, _ = pipe.encode_prompt(
-             prompt_c, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=negative_prompt_c
-         )
-
-         pose_image = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(pose_image).unsqueeze(0).to(device, torch.float16)
-         garment_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(garment_image).unsqueeze(0).to(device, torch.float16)
-
-         images = pipe(
-             prompt_embeds=prompt_embeds.to(device, torch.float16),
-             negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
-             pose_image=pose_image,
-             garment_image=garment_tensor,
-             mask_image=mask_gray.to(device, torch.float16),
-             generator=torch.Generator(device).manual_seed(seed),
-             num_inference_steps=denoise_steps
-         ).images
-
-         if images:
-             output_image = images[0]
-             output_base64 = encode_image_to_base64(output_image)
-             mask_image = mask
-             mask_base64 = encode_image_to_base64(mask_image)
-             return output_image, mask_image
          else:
-             raise ValueError("Failed to generate image")


- # Route to retrieve the generated image
- @app.route('/api/get_image/<image_id>', methods=['GET'])
- def get_image(image_id):
-     # Build the full path of the image
-     image_path = image_id  # Make sure the filename matches the one used when saving
-
-     # Return the image
-     try:
-         return send_file(image_path, mimetype='image/webp')
-     except FileNotFoundError:
-         return jsonify({'error': 'Image not found'}), 404

  @app.route('/tryon', methods=['POST'])
- def tryon_handler():
      try:
-         data = request.json
-         human_image = decode_image_from_base64(data['human_image'])
-         garment_image = decode_image_from_base64(data['garment_image'])
-         description = data.get('description')
-         use_auto_mask = data.get('use_auto_mask', True)
-         use_auto_crop = data.get('use_auto_crop', False)
-         denoise_steps = int(data.get('denoise_steps', 30))
-         seed = int(data.get('seed', 42))
-         category = data.get('category', 'upper_body')
-
-         human_dict = {
-             'background': human_image,
-             'layers': [human_image] if not use_auto_mask else None,
-             'composite': None
-         }
-         clear_gpu_memory()

-         output_image, mask_image = start_tryon(
-             human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, category
-         )

-         output_base64 = encode_image_to_base64(output_image)
-         mask_base64 = encode_image_to_base64(mask_image)

          return jsonify({
-             'output_image': output_base64,
              'mask_image': mask_base64
-         })
      except Exception as e:
-         logging.error(f"Error in tryon_handler: {e}")
          return jsonify({'error': str(e)}), 500

  if __name__ == "__main__":
-     load_models()  # Load the models at startup
-     app.run(host='0.0.0.0', port=7860)
 
  import os
+ from flask import Flask, request, jsonify, send_file
  from PIL import Image
  from io import BytesIO
  import torch
+ import base64
+ import io
+ import logging
+ import gradio as gr
  import numpy as np
  import spaces
+ import uuid
+ import random
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+ from src.unet_hacked_tryon import UNet2DConditionModel
  from transformers import (
      CLIPImageProcessor,
      CLIPVisionModelWithProjection,
      CLIPTextModel,
      CLIPTextModelWithProjection,
+     AutoTokenizer,
  )
+ from diffusers import DDPMScheduler, AutoencoderKL
  from utils_mask import get_mask_location
  from torchvision import transforms
  import apply_net
 

  app = Flask(__name__)

+ base_path = 'yisol/IDM-VTON'
+ example_path = os.path.join(os.path.dirname(__file__), 'example')

+ unet = UNet2DConditionModel.from_pretrained(
+     base_path,
+     subfolder="unet",
+     torch_dtype=torch.float16,
+     force_download=False
+ )
+ unet.requires_grad_(False)
+ tokenizer_one = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer",
+     revision=None,
+     use_fast=False,
+     force_download=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer_2",
+     revision=None,
+     use_fast=False,
+     force_download=False
+ )
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+
+ text_encoder_one = CLIPTextModel.from_pretrained(
+     base_path,
+     subfolder="text_encoder",
+     torch_dtype=torch.float16,
+     force_download=False
+ )
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="text_encoder_2",
+     torch_dtype=torch.float16,
+     force_download=False
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="image_encoder",
+     torch_dtype=torch.float16,
+     force_download=False
+ )
+ vae = AutoencoderKL.from_pretrained(
+     base_path,
+     subfolder="vae",
+     torch_dtype=torch.float16,
+     force_download=False
+ )
+
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+     base_path,
+     subfolder="unet_encoder",
+     torch_dtype=torch.float16,
+     force_download=False
+ )
+
+ parsing_model = Parsing(0)
+ openpose_model = OpenPose(0)
+
+ UNet_Encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ tensor_transfrom = transforms.Compose(
+     [
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ]
+ )
+
+ pipe = TryonPipeline.from_pretrained(
+     base_path,
+     unet=unet,
+     vae=vae,
+     feature_extractor=CLIPImageProcessor(),
+     text_encoder=text_encoder_one,
+     text_encoder_2=text_encoder_two,
+     tokenizer=tokenizer_one,
+     tokenizer_2=tokenizer_two,
+     scheduler=noise_scheduler,
+     image_encoder=image_encoder,
+     torch_dtype=torch.float16,
+     force_download=False
+ )
+ pipe.unet_encoder = UNet_Encoder

  def pil_to_binary_mask(pil_image, threshold=0):
+     np_image = np.array(pil_image)
+     grayscale_image = Image.fromarray(np_image).convert("L")
+     binary_mask = np.array(grayscale_image) > threshold
+     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
+     for i in range(binary_mask.shape[0]):
+         for j in range(binary_mask.shape[1]):
+             if binary_mask[i, j]:
+                 mask[i, j] = 1
+     mask = (mask * 255).astype(np.uint8)
+     output_mask = Image.fromarray(mask)
+     return output_mask

  def get_image_from_url(url):
      try:
          response = requests.get(url)
+         response.raise_for_status()  # Check for HTTP errors
+         img = Image.open(BytesIO(response.content))
+         return img
      except Exception as e:
          logging.error(f"Error fetching image from URL: {e}")
          raise

  def decode_image_from_base64(base64_str):
      try:
          img_data = base64.b64decode(base64_str)
+         img = Image.open(BytesIO(img_data))
+         return img
      except Exception as e:
          logging.error(f"Error decoding image: {e}")
          raise

      try:
          buffered = BytesIO()
          img.save(buffered, format="PNG")
+         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+         return img_str
      except Exception as e:
          logging.error(f"Error encoding image: {e}")
          raise

  def save_image(img):
+     unique_name = str(uuid.uuid4()) + ".webp"
+     img.save(unique_name, format="WEBP", lossless=True)
      return unique_name

  @spaces.GPU
+ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
      device = "cuda"
      openpose_model.preprocessor.body_estimation.model.to(device)
      pipe.to(device)
      pipe.unet_encoder.to(device)

+     garm_img = garm_img.convert("RGB").resize((768, 1024))
+     human_img_orig = dict["background"].convert("RGB")

+     if is_checked_crop:
+         width, height = human_img_orig.size
          target_width = int(min(width, height * (3 / 4)))
          target_height = int(min(height, width * (4 / 3)))
+         left = (width - target_width) / 2
+         top = (height - target_height) / 2
+         right = (width + target_width) / 2
+         bottom = (height + target_height) / 2
+         cropped_img = human_img_orig.crop((left, top, right, bottom))
+         crop_size = cropped_img.size
+         human_img = cropped_img.resize((768, 1024))
      else:
+         human_img = human_img_orig.resize((768, 1024))

+     if is_checked:
+         keypoints = openpose_model(human_img.resize((384, 512)))
+         model_parse, _ = parsing_model(human_img.resize((384, 512)))
+         mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
          mask = mask.resize((768, 1024))
      else:
+         mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+     mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
      mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

+     human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
+     human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
+     args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+     pose_img = args.func(args, human_img_arg)
+     pose_img = pose_img[:, :, ::-1]
+     pose_img = Image.fromarray(pose_img).resize((768, 1024))
+
+     with torch.no_grad():
+         with torch.cuda.amp.autocast():
+             prompt = "model is wearing " + garment_des
+             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+             with torch.inference_mode():
+                 (
+                     prompt_embeds,
+                     negative_prompt_embeds,
+                     pooled_prompt_embeds,
+                     negative_pooled_prompt_embeds,
+                 ) = pipe.encode_prompt(
+                     prompt,
+                     num_images_per_prompt=1,
+                     do_classifier_free_guidance=True,
+                     negative_prompt=negative_prompt,
+                 )
+
+             prompt = "a photo of " + garment_des
+             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+             if not isinstance(prompt, list):
+                 prompt = [prompt] * 1
+             if not isinstance(negative_prompt, list):
+                 negative_prompt = [negative_prompt] * 1
+             with torch.inference_mode():
+                 (
+                     prompt_embeds_c,
+                     _,
+                     _,
+                     _,
+                 ) = pipe.encode_prompt(
+                     prompt,
+                     num_images_per_prompt=1,
+                     do_classifier_free_guidance=False,
+                     negative_prompt=negative_prompt,
+                 )
+
+             pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+             garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+             generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+             images = pipe(
+                 prompt_embeds=prompt_embeds.to(device, torch.float16),
+                 negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+                 pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+                 negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+                 num_inference_steps=denoise_steps,
+                 generator=generator,
+                 strength=1.0,
+                 pose_img=pose_img.to(device, torch.float16),
+                 text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
+                 cloth=garm_tensor.to(device, torch.float16),
+                 mask_image=mask,
+                 image=human_img,
+                 height=1024,
+                 width=768,
+                 ip_adapter_image=garm_img.resize((768, 1024)),
+                 guidance_scale=2.0,
+             )[0]
+
+         if is_checked_crop:
+             out_img = images[0].resize(crop_size)
+             human_img_orig.paste(out_img, (int(left), int(top)))
+             return human_img_orig, mask_gray
          else:
+             return images[0], mask_gray

 
279
+ def clear_gpu_memory():
280
+ torch.cuda.empty_cache()
281
+ torch.cuda.synchronize()
 
 
282
 
283
+ def process_image(image_data):
284
+ # Vérifie si l'image est en base64 ou URL
285
+ if image_data.startswith('http://') or image_data.startswith('https://'):
286
+ return get_image_from_url(image_data) # Télécharge l'image depuis l'URL
287
+ else:
288
+ return decode_image_from_base64(image_data) # Décode l'image base64
289
 
290
  @app.route('/tryon', methods=['POST'])
291
+ def tryon():
292
+ data = request.json
293
+ human_image = process_image(data['human_image'])
294
+ garment_image = process_image(data['garment_image'])
295
+ description = data.get('description')
296
+ use_auto_mask = data.get('use_auto_mask', True)
297
+ use_auto_crop = data.get('use_auto_crop', False)
298
+ denoise_steps = int(data.get('denoise_steps', 30))
299
+ seed = int(data.get('seed', 42))
300
+ categorie = data.get('categorie' , 'upper_body')
301
+ human_dict = {
302
+ 'background': human_image,
303
+ 'layers': [human_image] if not use_auto_mask else None,
304
+ 'composite': None
305
+ }
306
+ #clear_gpu_memory()
307
+
308
+ output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed , categorie)
309
+
310
+ output_base64 = encode_image_to_base64(output_image)
311
+ mask_base64 = encode_image_to_base64(mask_image)
312
+
313
+ return jsonify({
314
+ 'output_image': output_base64,
315
+ 'mask_image': mask_base64
316
+ })
317
+
+ @app.route('/tryon-v2', methods=['POST'])
+ def tryon_v2():
+
+     data = request.json
+     human_image_data = data['human_image']
+     garment_image_data = data['garment_image']
+
+     # Process images (base64 or URL)
+     human_image = process_image(human_image_data)
+     garment_image = process_image(garment_image_data)
+
+     description = data.get('description')
+     use_auto_mask = data.get('use_auto_mask', True)
+     use_auto_crop = data.get('use_auto_crop', False)
+     denoise_steps = int(data.get('denoise_steps', 30))
+     seed = int(data.get('seed', random.randint(0, 9999999)))
+     categorie = data.get('categorie', 'upper_body')
+
+     # Check whether 'mask_image' is present in the data
+     mask_image = None
+     if 'mask_image' in data:
+         mask_image_data = data['mask_image']
+         mask_image = process_image(mask_image_data)
+
+     human_dict = {
+         'background': human_image,
+         'layers': [mask_image] if not use_auto_mask else None,
+         'composite': None
+     }
+     output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie)
+     return jsonify({
+         'image_id': save_image(output_image)
+     })
+
+ @spaces.GPU
+ def generate_mask(human_img, categorie='upper_body'):
+     device = "cuda"
+     openpose_model.preprocessor.body_estimation.model.to(device)
+     pipe.to(device)
+
      try:
+         # Resize the image for the model
+         human_img_resized = human_img.convert("RGB").resize((384, 512))

+         # Generate the keypoints and the mask
+         keypoints = openpose_model(human_img_resized)
+         model_parse, _ = parsing_model(human_img_resized)
+         mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)

+         # Resize the mask to the original size of the image
+         mask_resized = mask.resize(human_img.size)
+
+         return mask_resized
+     except Exception as e:
+         logging.error(f"Error generating mask: {e}")
+         raise e

+
+ @app.route('/generate_mask', methods=['POST'])
+ def generate_mask_api():
+     try:
+         # Get the image data from the request
+         data = request.json
+         base64_image = data.get('human_image')
+         categorie = data.get('categorie', 'upper_body')
+
+         # Decode the image from base64
+         human_img = process_image(base64_image)
+
+         # Call the function that generates the mask
+         mask_resized = generate_mask(human_img, categorie)
+
+         # Encode the mask as base64 for the response
+         mask_base64 = encode_image_to_base64(mask_resized)
+
          return jsonify({
              'mask_image': mask_base64
+         }), 200
      except Exception as e:
+         logging.error(f"Error generating mask: {e}")
          return jsonify({'error': str(e)}), 500

+ # Route to retrieve the generated image
+ @app.route('/api/get_image/<image_id>', methods=['GET'])
+ def get_image(image_id):
+     # Build the full path of the image
+     image_path = image_id  # Make sure the filename matches the one used when saving
+
+     # Return the image
+     try:
+         return send_file(image_path, mimetype='image/webp')
+     except FileNotFoundError:
+         return jsonify({'error': 'Image not found'}), 404
+
  if __name__ == "__main__":
+     app.run(debug=False, host="0.0.0.0", port=7860)
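
For reference, here is a minimal client sketch for the endpoints this commit adds (/tryon, /tryon-v2, /api/get_image/<image_id>, /generate_mask). The JSON field names are taken from the diff above; the base URL, local file names, and timeouts are assumptions for illustration and are not part of the commit.

# Hypothetical client for the endpoints added in this commit.
# Assumes the Flask app is reachable at BASE_URL (the diff binds port 7860).
import base64
import requests

BASE_URL = "http://localhost:7860"  # assumption: adjust to your Space's URL

def to_b64(path):
    # Encode a local image file as base64, as the API expects.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "human_image": to_b64("person.jpg"),      # hypothetical local files
    "garment_image": to_b64("garment.jpg"),
    "description": "red t-shirt",
    "use_auto_mask": True,
    "use_auto_crop": False,
    "denoise_steps": 30,
    "seed": 42,
    "categorie": "upper_body",
}

# /tryon returns the result and the mask as base64-encoded PNGs.
resp = requests.post(f"{BASE_URL}/tryon", json=payload, timeout=600)
resp.raise_for_status()
with open("output.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["output_image"]))

# /tryon-v2 returns an image_id; the WEBP is then fetched via /api/get_image/<image_id>.
resp = requests.post(f"{BASE_URL}/tryon-v2", json=payload, timeout=600)
resp.raise_for_status()
image_id = resp.json()["image_id"]
img = requests.get(f"{BASE_URL}/api/get_image/{image_id}", timeout=60)
with open("output.webp", "wb") as f:
    f.write(img.content)

# /generate_mask returns only the parsing mask for the given category.
resp = requests.post(
    f"{BASE_URL}/generate_mask",
    json={"human_image": payload["human_image"], "categorie": "upper_body"},
    timeout=600,
)
resp.raise_for_status()
with open("mask.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["mask_image"]))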