File size: 3,756 Bytes
62fa28c
3ff00f1
8414811
62fa28c
4459aee
 
 
 
 
7713aa5
025c436
4459aee
8da9807
 
4459aee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8da9807
 
 
 
 
 
 
 
 
 
 
 
1289ee4
 
 
 
 
 
 
 
 
 
 
 
ba27f48
 
4459aee
 
 
 
8da9807
2edf2fd
8da9807
26e18b2
 
4459aee
26e18b2
 
 
 
 
 
4459aee
c2ffaef
4459aee
 
 
 
26e18b2
ba27f48
4459aee
 
ba27f48
2edf2fd
 
ba27f48
c2ffaef
 
ba27f48
c2ffaef
62fa28c
2edf2fd
 
 
 
4459aee
 
8da9807
 
 
 
 
ba27f48
 
 
8da9807
 
 
513909f
8da9807
4459aee
62fa28c
2edf2fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from flask import Flask, request, jsonify
from PIL import Image
import base64
from io import BytesIO
import numpy as np
import cv2
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover

# Single module-level Flask application; routes are registered on it below.
app = Flask(__name__)

# Load the person-detection model.
def load_model():
    """Download the SCRFD person-detection ONNX model and wrap it in an
    insightface RetinaFace detector.

    Returns:
        insightface.model_zoo.retinaface.RetinaFace: an un-prepared detector
        (caller must invoke ``prepare`` before ``detect``).
    """
    path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
    options = ort.SessionOptions()
    options.intra_op_num_threads = 8
    options.inter_op_num_threads = 8
    # Providers are tried in priority order: prefer CUDA when available,
    # fall back to CPU otherwise.  (The original listed CPU first, which
    # made the CUDA entry unreachable.)
    session = ort.InferenceSession(
        path, sess_options=options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
    return model

# Module-level detector shared by all requests; loaded once at import time.
detector = load_model()
# NOTE(review): ctx_id=-1 presumably selects CPU execution (insightface
# convention) — confirm; NMS threshold 0.5, 640x640 detection input size.
detector.prepare(-1, nms_thresh=0.5, input_size=(640, 640))

# Decode a base64-encoded image into a PIL.Image.Image object.
def decode_image_from_base64(image_data):
    """Decode a base64 payload into an RGB ``PIL.Image.Image``.

    The image is forced to RGB so downstream code never has to deal
    with an alpha channel.
    """
    raw_bytes = base64.b64decode(image_data)
    return Image.open(BytesIO(raw_bytes)).convert("RGB")

# Encode a PIL image as a base64 string.
def encode_image_to_base64(image):
    """Serialize *image* to PNG and return it as a base64-encoded str."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')

#@spaces.GPU
# Lazily-created singleton: Remover() loads model weights, so constructing
# it once per call (as the original did) paid that cost on every request.
_remover = None

def remove_background(image):
    """Remove the background from *image* using transparent_background.

    Args:
        image: a ``PIL.Image.Image`` or a ``numpy.ndarray``.

    Returns:
        The background-removed image as produced by ``Remover.process``.

    Raises:
        TypeError: if *image* is neither a PIL image nor a numpy array.
    """
    global _remover
    if _remover is None:
        _remover = Remover()
    # Normalize ndarray input to PIL before processing.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if not isinstance(image, Image.Image):
        raise TypeError("Unsupported image type")
    return _remover.process(image)

# Detect persons and segment their clothing.
def detect_and_segment_persons(image, clothes):
    """Detect every person in *image* and segment the requested clothing.

    Args:
        image: RGB ``PIL.Image.Image`` to analyse.
        clothes: list of clothing-class names passed to ``segment_clothing``.

    Returns:
        A list combining the segmentation results for all detected persons;
        empty list when nothing is detected.
    """
    img = np.array(image)
    img = img[:, :, ::-1]  # RGB -> BGR (detector expects OpenCV ordering)

    bboxes, _ = detector.detect(img)  # keypoints are unused here
    if bboxes.shape[0] == 0:  # no person detected
        return []

    height, width, _ = img.shape

    bboxes = np.round(bboxes[:, :4]).astype(int)

    # Clamp bounding boxes within image boundaries.
    bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)   # x1
    bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)  # y1
    bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)   # x2
    bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)  # y2

    all_segmented_images = []
    for x1, y1, x2, y2 in bboxes:
        # A box fully outside the frame collapses to zero area after
        # clamping; an empty crop would crash Image.fromarray downstream.
        if x2 <= x1 or y2 <= y1:
            continue
        person_img = img[y1:y2, x1:x2]

        # Convert numpy array back to a PIL image (BGR -> RGB).
        pil_img = Image.fromarray(person_img[:, :, ::-1])

        # Remove the background, then segment clothing for this person.
        img_rm_background = remove_background(pil_img)
        segmented_result = segment_clothing(img_rm_background, clothes)

        # Combine the segmented images for all persons.
        all_segmented_images.extend(segmented_result)

    return all_segmented_images

@app.route('/', methods=['GET'])
def welcome():
    """Landing endpoint; doubles as a trivial health check."""
    return "Welcome to Clothing Segmentation API"

@app.route('/api/detect', methods=['POST'])
def detect():
    """POST endpoint: detect persons in a base64 image and segment clothing.

    Expects a JSON body ``{"image": "<base64>"}``.  Returns
    ``{"images": [...]}`` on success, a 400 with an error message when the
    payload is missing/invalid, and a 500 for unexpected failures.
    """
    try:
        data = request.json or {}
        image_base64 = data.get('image')
        if not image_base64:
            # Missing payload is a client error, not a server fault.
            return jsonify({'error': "missing 'image' field in JSON body"}), 400
        image = decode_image_from_base64(image_base64)

        # Clothing classes to segment for each detected person.
        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        person_images_base64 = detect_and_segment_persons(image, clothes)

        return jsonify({'images': person_images_base64})
    except Exception as e:
        # Log the full traceback instead of print() so failures are traceable.
        app.logger.exception("detection failed")
        return jsonify({'error': str(e)}), 500

if __name__ == "__main__":
    # NOTE(review): debug=True combined with host="0.0.0.0" exposes the
    # Werkzeug interactive debugger to the network (remote code execution
    # risk) — disable debug for any non-local deployment.
    app.run(debug=True, host="0.0.0.0", port=7860)