"""Clothing-segmentation and vision API (Flask).

Endpoints:
    GET  /                         -- welcome message
    POST /api/vision               -- run Phi-3-vision on an image + prompt, parse JSON answer
    POST /api/detect               -- detect persons, remove background, segment clothing
    GET  /api/get_image/<image_id> -- fetch a PNG written by the detect endpoint
"""
# NOTE: import order is kept as in the original; `spaces` is imported before
# `torch`, which some GPU-sharing setups rely on.
from flask import Flask, request, jsonify, send_file
from PIL import Image
import requests
import base64
import spaces
import multiprocessing
from loadimg import load_img
from io import BytesIO
import numpy as np
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import threading
import logging
import uuid
from transformers import AutoModelForImageSegmentation, AutoModelForCausalLM, AutoProcessor
import torch
from torchvision import transforms
import subprocess
import json
import os
import sys

# Configure logging first so that start-up messages (model loading) are emitted.
logging.basicConfig(level=logging.INFO)

# Install flash-attn before the Phi-3 model below requests the
# "flash_attention_2" implementation. The child environment must EXTEND
# os.environ: passing only the override would drop PATH and the virtualenv.
# Running pip through sys.executable installs into this interpreter.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

app = Flask(__name__)

VISION_MODEL_ID = "microsoft/Phi-3-vision-128k-instruct"

models = {
    VISION_MODEL_ID: AutoModelForCausalLM.from_pretrained(
        VISION_MODEL_ID,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation="flash_attention_2",
    ).cuda().eval()
}
processors = {
    VISION_MODEL_ID: AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
}

# Chat-template markers used to build the Phi-3-vision prompt.
user_prompt = '<|user|>\n'
assistant_prompt = '<|assistant|>\n'
prompt_suffix = "<|end|>\n"


def get_image_from_url(url, timeout=30):
    """Download ``url`` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: seconds to wait for the server before giving up.

    Raises:
        Any error from the download, HTTP status check or decoding is logged
        and re-raised.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Turn HTTP error statuses into exceptions.
        # Convert to RGB, matching decode_image_from_base64().
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        logging.error(f"Error fetching image from URL: {e}")
        raise


def decode_image_from_base64(image_data):
    """Decode a base64 string into an RGB PIL image."""
    raw = base64.b64decode(image_data)
    return Image.open(BytesIO(raw)).convert("RGB")


def encode_image_to_base64(image):
    """Encode a PIL image as a base64 PNG string (PNG keeps an RGBA alpha channel)."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def extract_image(image_data):
    """Return a PIL image from either an http(s) URL or a base64 payload."""
    if image_data.startswith('http://') or image_data.startswith('https://'):
        return get_image_from_url(image_data)
    return decode_image_from_base64(image_data)


@spaces.GPU
def process_vision(image, text_input=None, model_id=VISION_MODEL_ID):
    """Ask the vision-language model ``model_id`` about ``image``.

    Args:
        image: PIL image; it is converted to RGB before processing.
        text_input: the user's question; ``None`` is treated as an empty prompt.
        model_id: key into ``models`` / ``processors``.

    Returns:
        The decoded text generated after the prompt.
    """
    model = models[model_id]
    processor = processors[model_id]
    # `or ''` keeps a missing prompt from being rendered as the literal "None".
    prompt = f"{user_prompt}<|image_1|>\n{text_input or ''}{prompt_suffix}{assistant_prompt}"
    image = image.convert("RGB")
    inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=4128,
        eos_token_id=processor.tokenizer.eos_token_id,
    )
    # Drop the echoed prompt tokens; keep only the newly generated ones.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    return processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


def _strip_json_fence(text):
    """Remove a Markdown code fence (```json ... ``` or ``` ... ```) around ``text``."""
    # Strip first: model output often carries surrounding newlines, which would
    # otherwise make the startswith/endswith checks miss the fence.
    text = text.strip()
    if text.startswith("```json"):
        text = text[len("```json"):]
    elif text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


@app.route('/api/vision', methods=['POST'])
def process_api_vision():
    """Run the vision model on ``{"image": ..., "prompt": ...}`` and return its JSON answer."""
    try:
        data = request.json
        image = extract_image(data['image'])
        result = _strip_json_fence(process_vision(image, data['prompt']))
        try:
            logging.info(result)
            result_dict = json.loads(result)
        except json.JSONDecodeError as e:
            logging.error(f"JSON decoding error: {e}")
            return jsonify({'error': 'Invalid JSON format in the response'}), 500
        return jsonify(result_dict)
    except Exception as e:
        logging.exception(f"Error occurred: {e}")
        return jsonify({'error': str(e)}), 500


# Person detector, loaded lazily on first use by load_model().
model = None
detector = None


def load_model():
    """Download the SCRFD person detector and store it in the globals ``model``/``detector``."""
    global model, detector
    path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
    options = ort.SessionOptions()
    options.intra_op_num_threads = 8
    options.inter_op_num_threads = 8
    # NOTE(review): ONNX Runtime uses providers in priority order, so with CPU
    # listed first the CUDA provider is effectively unused (prepare(-1, ...)
    # below presumably also selects CPU). Confirm CPU inference is intended.
    session = ort.InferenceSession(
        path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
    )
    model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model
    logging.info("Model loaded successfully.")


torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")
# Preprocessing for BiRefNet: fixed 1024x1024 input with ImageNet normalisation.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def save_image(img):
    """Save ``img`` as ``<uuid4>.png`` in the working directory and return the file name."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name


@spaces.GPU
def rm_background(image):
    """Return ``image`` as RGBA with BiRefNet's foreground prediction as the alpha channel.

    ``image`` may be anything ``load_img`` accepts (a PIL image, at least, as
    used in this file).
    """
    im = load_img(image, output_type="pil").convert("RGB")
    image_size = im.size
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    with torch.no_grad():
        # Last output of the model, squashed to [0, 1] as a soft mask.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Scale the 1024x1024 mask back to the original image size.
    mask = transforms.ToPILImage()(pred).resize(image_size)
    image.putalpha(mask)
    return image


@spaces.GPU
def remove_background(image):
    """Remove the background using ``transparent_background.Remover``.

    Raises:
        TypeError: if ``image`` is neither a PIL image nor a numpy array.
    """
    remover = Remover()
    if isinstance(image, Image.Image):
        output = remover.process(image)
    elif isinstance(image, np.ndarray):
        output = remover.process(Image.fromarray(image))
    else:
        raise TypeError("Unsupported image type")
    return output


@spaces.GPU
def detect_and_segment_persons(image, clothes):
    """Detect people in ``image`` and segment the given clothing classes on each one.

    Returns:
        File names (see save_image) of the segmented clothing images. If no
        person is detected, a single background-removed copy of the whole image
        is saved instead.
    """
    img = np.array(image)
    img = img[:, :, ::-1]  # RGB -> BGR for the insightface detector.
    if detector is None:
        load_model()
    bboxes, _kpss = detector.detect(img)
    if bboxes.shape[0] == 0:
        return [save_image(rm_background(image))]

    height, width, _ = img.shape
    bboxes = np.round(bboxes[:, :4]).astype(int)
    bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, width)
    bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in bboxes:
        if x2 <= x1 or y2 <= y1:
            # Box collapsed to zero area after clipping: nothing to crop, and
            # Image.fromarray would fail on an empty array.
            continue
        person_img = img[y1:y2, x1:x2]
        pil_img = Image.fromarray(person_img[:, :, ::-1])  # Back to RGB.
        img_rm_background = rm_background(pil_img)
        segmented_result = segment_clothing(img_rm_background, clothes)
        image_paths = [save_image(segment) for segment in segmented_result]
        logging.info(image_paths)
        all_segmented_images.extend(image_paths)
    return all_segmented_images


@app.route('/', methods=['GET'])
def welcome():
    """Health-check / welcome endpoint."""
    return "Welcome to Clothing Segmentation API"


@app.route('/api/detect', methods=['POST'])
def detect():
    """Segment clothing on ``{"image": <base64>}`` and return the saved file names."""
    try:
        data = request.json
        image = decode_image_from_base64(data['image'])
        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        result = detect_and_segment_persons(image, clothes)
        return jsonify({'images': result})
    except Exception as e:
        logging.exception(f"Error occurred: {e}")
        return jsonify({'error': str(e)}), 500


@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """Serve a PNG previously written by save_image()."""
    # Only serve bare "*.png" names: never a path, and never other files in the
    # working directory (such as this source file).
    if os.path.basename(image_id) != image_id or not image_id.endswith(".png"):
        return jsonify({'error': 'Image not found'}), 404
    try:
        # save_image() writes relative to the process working directory; resolve
        # the same way so the lookup does not depend on Flask's root path.
        return send_file(os.path.abspath(image_id), mimetype='image/png')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404


if __name__ == "__main__":
    # The Werkzeug debugger permits arbitrary code execution, so it must not be
    # enabled by default on a server bound to all interfaces. Opt in with FLASK_DEBUG=1.
    app.run(debug=os.environ.get("FLASK_DEBUG") == "1", host="0.0.0.0", port=7860)