# Header scraped from the file's hosting page (kept as a comment so this module parses):
# author Saad0KH; commit 2edf2fd ("Update app.py", verified); file size 3.76 kB.
from flask import Flask, request, jsonify
from PIL import Image
import base64
from io import BytesIO
import numpy as np
import cv2
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
# Flask WSGI application; the route handlers below are registered on it.
app = Flask(import_name=__name__)
# Charger le modèle
def load_model():
path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
options = ort.SessionOptions()
options.intra_op_num_threads = 8
options.inter_op_num_threads = 8
session = ort.InferenceSession(
path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
)
model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
return model
# Build the person detector once at import time; every request reuses it.
detector = load_model()
detector.prepare(
    -1,  # ctx_id; NOTE(review): negative is understood to select CPU in insightface, confirm
    nms_thresh=0.5,  # non-maximum-suppression threshold
    input_size=(640, 640),  # detector input resolution
)
# Decode a base64-encoded image into a PIL.Image.Image object.
def decode_image_from_base64(image_data):
    """Decode a base64-encoded image into an RGB ``PIL.Image.Image``.

    Args:
        image_data: Base64 text (``str`` or ASCII ``bytes``). A data-URI
            prefix such as ``"data:image/png;base64,"`` is also accepted and
            stripped before decoding.

    Returns:
        The decoded image converted to RGB. Converting avoids alpha-channel
        problems downstream, where three channels are assumed.

    Raises:
        binascii.Error: if the payload is not valid base64 (a ``ValueError``
            subclass).
        PIL.UnidentifiedImageError: if the bytes are not a readable image
            (an ``OSError`` subclass).
    """
    if isinstance(image_data, str) and image_data.startswith("data:"):
        # The base64 alphabet never contains ',', so everything after the
        # first comma of a data URI is the payload.
        image_data = image_data.partition(",")[2]
    raw_bytes = base64.b64decode(image_data)
    image = Image.open(BytesIO(raw_bytes)).convert("RGB")
    return image
# Encode a PIL image as base64 PNG text.
def encode_image_to_base64(image):
    """Serialize *image* as PNG and return those bytes as base64 text.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method,
            normally a ``PIL.Image.Image``.

    Returns:
        The base64 encoding of the PNG bytes, decoded to a ``str``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
# Shared background remover, created on first use (see _get_remover).
_remover = None


def _get_remover():
    """Return the process-wide ``Remover``, constructing it on first call."""
    global _remover
    if _remover is None:
        _remover = Remover()
    return _remover


# NOTE: a `@spaces.GPU` decorator (Hugging Face `spaces` package) was disabled
# here; presumably re-enable it when running on a ZeroGPU Space.
def remove_background(image):
    """Run the ``transparent_background`` remover on *image*.

    The original code built a new ``Remover()`` on every call, i.e. once per
    detected person per request. A single instance is now created lazily and
    reused.

    Args:
        image: A ``PIL.Image.Image`` or a NumPy array accepted by
            ``Image.fromarray``.

    Returns:
        Whatever ``Remover.process`` returns for the image. Presumably this is
        the image with its background removed; confirm against
        transparent_background.

    Raises:
        TypeError: if *image* is neither a PIL image nor a NumPy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError("Unsupported image type")
    # NOTE(review): the single Remover is now shared by concurrent requests
    # (Flask serves requests on threads); confirm Remover.process is safe to
    # call concurrently, and guard it with a lock if not.
    return _get_remover().process(image)
# Detect persons and segment their clothing.
def detect_and_segment_persons(image, clothes):
    """Detect every person in *image* and segment the requested clothing items.

    For each detected person, the crop is converted to a PIL image, passed
    through ``remove_background``, and then given to ``segment_clothing``
    together with *clothes*.

    Args:
        image: RGB image, either a ``PIL.Image.Image`` or an ``(H, W, 3)`` RGB
            array.
        clothes: Clothing labels forwarded unchanged to ``segment_clothing``.

    Returns:
        A flat list that concatenates the ``segment_clothing`` results for all
        persons, in detection order. Returns ``[]`` when no person is detected.
    """
    rgb = np.array(image)
    height, width = rgb.shape[:2]

    # The detector is fed BGR channel order (reversed RGB view).
    bboxes, _ = detector.detect(rgb[:, :, ::-1])
    if bboxes.shape[0] == 0:  # no person detected
        return []

    # Keep the (x1, y1, x2, y2) columns, round them to pixel indices, and
    # clamp them to the image bounds.
    boxes = np.round(bboxes[:, :4]).astype(int)
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in boxes:
        # A box lying outside the frame collapses to zero width or height
        # after clamping. The empty crop has no pixels to segment and would
        # likely fail in background removal or segmentation, so skip it.
        if x2 <= x1 or y2 <= y1:
            continue
        # Crop straight from the RGB array; no BGR->RGB round trip is needed.
        person_img = Image.fromarray(rgb[y1:y2, x1:x2])
        img_rm_background = remove_background(person_img)
        # Accumulate the segmented results of all persons in one flat list.
        all_segmented_images.extend(segment_clothing(img_rm_background, clothes))
    return all_segmented_images
@app.route("/", methods=["GET"])
def welcome():
    """Answer ``GET /`` with a fixed plain-text greeting."""
    greeting = "Welcome to Clothing Segmentation API"
    return greeting
@app.route('/api/detect', methods=['POST'])
def detect():
    """Segment the clothing of every person in a base64-encoded image.

    Expects a JSON body of the form ``{"image": "<base64 image>"}``.

    Returns:
        ``200 {"images": [...]}``: the results of
            ``detect_and_segment_persons`` for the decoded image.
        ``400 {"error": ...}``: the body is not a JSON object with a string
            ``image`` field, or the image cannot be decoded.
        ``500 {"error": ...}``: any other failure, with the traceback logged.
    """
    # silent=True yields None instead of raising for a missing or invalid
    # JSON body, so the client gets a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('image'), str):
        return jsonify(
            {'error': "Request body must be a JSON object with a base64 'image' string"}
        ), 400

    try:
        image = decode_image_from_base64(data['image'])
    except (ValueError, OSError) as e:
        # binascii.Error (bad base64) subclasses ValueError; PIL raises an
        # OSError subclass (UnidentifiedImageError) for unreadable bytes.
        return jsonify({'error': f"Invalid image: {e}"}), 400

    try:
        # Detect persons and segment their clothing.
        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        person_images_base64 = detect_and_segment_persons(image, clothes)
        return jsonify({'images': person_images_base64})
    except Exception as e:
        # Top-level request boundary: keep the full traceback in the server
        # log, and return the message as JSON as before.
        app.logger.exception("Clothing segmentation failed")
        return jsonify({'error': str(e)}), 500
if __name__ == "__main__":
    import os

    # Never enable the Werkzeug debugger by default. Combined with
    # host="0.0.0.0" it exposes an interactive console (arbitrary code
    # execution) to the network. Its reloader also re-runs this module in a
    # child process, which loads the models a second time. Opt in explicitly
    # with FLASK_DEBUG=1 for local development.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug, host="0.0.0.0", port=7860)