"""Flask API that detects people in an image and segments their clothing."""

import base64
import logging
import os
import threading
from functools import lru_cache
from io import BytesIO

import cv2  # noqa: F401  (kept from the original imports; not used directly here)
import huggingface_hub
import insightface
import numpy as np
import onnxruntime as ort
from flask import Flask, jsonify, request
from PIL import Image
from transparent_background import Remover

from SegCloth import segment_clothing

logger = logging.getLogger(__name__)

app = Flask(__name__)

# Garment labels requested from ``segment_clothing`` for every detected person.
CLOTHES = ("Upper-clothes", "Skirt", "Pants", "Dress")


def load_model():
    """Download the SCRFD person detector and wrap it in insightface's RetinaFace.

    Returns:
        An ``insightface.model_zoo.retinaface.RetinaFace`` built on an
        onnxruntime session; ``prepare`` must be called before ``detect``.
    """
    path = huggingface_hub.hf_hub_download(
        "public-data/insightface", "models/scrfd_person_2.5g.onnx"
    )
    options = ort.SessionOptions()
    options.intra_op_num_threads = 8
    options.inter_op_num_threads = 8
    # onnxruntime tries execution providers in list order, so CPU wins here.
    # NOTE(review): ``detector.prepare(-1, ...)`` below passes ctx_id=-1, which
    # in insightface appears to pin the session to CPU regardless — confirm
    # before reordering these providers to get CUDA inference.
    session = ort.InferenceSession(
        path,
        sess_options=options,
        providers=["CPUExecutionProvider", "CUDAExecutionProvider"],
    )
    return insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)


detector = load_model()
detector.prepare(-1, nms_thresh=0.5, input_size=(640, 640))


def decode_image_from_base64(image_data):
    """Decode a base64-encoded image into an RGB ``PIL.Image.Image``.

    A leading ``data:<mime>;base64,`` URL header, if present, is stripped.

    Raises:
        binascii.Error: if the payload is not valid base64 (a ``ValueError``).
        PIL.UnidentifiedImageError: if the bytes are not a readable image
            (an ``OSError``).
    """
    if isinstance(image_data, str) and image_data.startswith("data:"):
        # b64decode (non-validating) would keep the header's alphabet
        # characters ("dataimage/png...base64") as payload and yield garbage.
        _, _, image_data = image_data.partition(",")
    raw = base64.b64decode(image_data)
    # Convert to RGB to avoid alpha-channel / palette / greyscale surprises.
    return Image.open(BytesIO(raw)).convert("RGB")


def encode_image_to_base64(image):
    """Encode a PIL image as a base64 PNG and return it as a UTF-8 string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


# Serializes use of the shared Remover across Flask's request threads. The
# original built a private Remover per call, so it never shared one; keep that
# safety until Remover.process is confirmed to be thread-safe.
_remover_lock = threading.Lock()


@lru_cache(maxsize=1)
def _get_remover():
    """Construct the background-removal model once and reuse it for every call."""
    return Remover()


# @spaces.GPU
def remove_background(image):
    """Remove the background from *image* using ``transparent_background``.

    Args:
        image: a ``PIL.Image.Image`` or a numpy array accepted by
            ``Image.fromarray``.

    Returns:
        Whatever ``Remover.process`` returns for the image.

    Raises:
        TypeError: if *image* is neither a PIL image nor a numpy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError("Unsupported image type")
    # Creation happens under the lock too, so concurrent first requests do not
    # each build their own model.
    with _remover_lock:
        return _get_remover().process(image)


def detect_and_segment_persons(image, clothes):
    """Detect people in *image* and segment the requested garments of each one.

    Args:
        image: PIL image of the scene; converted to RGB if it is not already.
        clothes: garment labels forwarded unchanged to ``segment_clothing``.

    Returns:
        A flat list concatenating ``segment_clothing``'s results for every
        detected person, in detection order; ``[]`` when nobody is detected.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # Greyscale/RGBA/palette inputs would break the 3-channel flip below.
        image = image.convert("RGB")
    rgb = np.asarray(image)
    # The detector is fed BGR (OpenCV channel order). Make the flipped view a
    # contiguous copy so downstream OpenCV calls never see negative strides.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    bboxes, _ = detector.detect(bgr)
    if bboxes.shape[0] == 0:  # no person detected
        return []

    height, width = rgb.shape[:2]
    # Columns 0-3 are x1, y1, x2, y2; round them to integer pixel coordinates.
    boxes = np.round(bboxes[:, :4]).astype(int)
    # Clamp to the image so slicing never wraps around via negative indices.
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in boxes:
        if x2 <= x1 or y2 <= y1:
            # Degenerate box (e.g. entirely outside the frame after clamping):
            # the crop would be empty and there is nothing to segment.
            continue
        # Crop straight from the RGB array; no BGR round-trip is needed.
        person = Image.fromarray(rgb[y1:y2, x1:x2])
        foreground = remove_background(person)
        all_segmented_images.extend(segment_clothing(foreground, clothes))
    return all_segmented_images


@app.route('/', methods=['GET'])
def welcome():
    """Health-check / landing endpoint."""
    return "Welcome to Clothing Segmentation API"


@app.route('/api/detect', methods=['POST'])
def detect():
    """Segment the clothing of every person in a posted image.

    Expects a JSON body ``{"image": "<base64 image>"}``.

    Returns:
        200 with ``{"images": [...]}`` on success; 400 with ``{"error": ...}``
        for a malformed request or undecodable image; 500 with
        ``{"error": ...}`` if detection or segmentation fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('image'), str):
        return jsonify(
            {'error': "Request body must be a JSON object with a base64 'image' string"}
        ), 400

    try:
        image = decode_image_from_base64(data['image'])
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        # binascii.Error is a ValueError; PIL.UnidentifiedImageError is an OSError.
        logger.warning("Could not decode input image: %s", e)
        return jsonify({'error': f"Invalid image: {e}"}), 400

    try:
        # Detect people and segment their clothing.
        person_images_base64 = detect_and_segment_persons(image, list(CLOTHES))
        return jsonify({'images': person_images_base64})
    except Exception as e:  # top-level request boundary: log and report
        logger.exception("Detection/segmentation failed")
        return jsonify({'error': str(e)}), 500


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser;
    # together with host "0.0.0.0" that would be exposed to the whole network.
    # Opt in explicitly with FLASK_DEBUG=1 for local development only.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(debug=debug, host="0.0.0.0", port=7860)