File size: 4,875 Bytes
739eb5d
cc76d9f
 
 
739eb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc76d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739eb5d
 
 
 
 
7ca1694
739eb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ca1694
739eb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ca1694
739eb5d
 
 
 
a60fe31
cc76d9f
a60fe31
739eb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cde8dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
from PIL import Image,ImageFilter
import cv2
import numpy as np
import base64
import spaces
from loadimg import load_img
from io import BytesIO
import numpy as np
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import uuid
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# Person-detector globals; both stay None until load_model() populates them
# on first use (see detect_and_segment_persons).
model = detector = None
def load_model():
    """
    Download the SCRFD person-detection ONNX model and install it globally.

    The weights ``models/scrfd_person_2.5g.onnx`` are fetched from the
    ``public-data/insightface`` Hub repo and opened in an ONNX Runtime
    session that uses 8 intra-op and 8 inter-op threads. The session is
    wrapped in insightface's RetinaFace helper, then prepared with an NMS
    threshold of 0.5 and a 640x640 input size. The prepared instance is
    stored in both the ``model`` and ``detector`` module globals.
    """
    global model, detector

    onnx_path = huggingface_hub.hf_hub_download(
        "public-data/insightface", "models/scrfd_person_2.5g.onnx"
    )

    sess_opts = ort.SessionOptions()
    sess_opts.intra_op_num_threads = 8
    sess_opts.inter_op_num_threads = 8

    # ONNX Runtime takes providers in order of decreasing precedence, so CPU
    # is preferred over CUDA here.
    # NOTE(review): prepare() below is called with ctx_id=-1, which in
    # insightface presumably pins the session to CPU regardless - confirm
    # whether GPU inference was ever intended for the detector.
    onnx_session = ort.InferenceSession(
        onnx_path,
        sess_options=sess_opts,
        providers=["CPUExecutionProvider", "CUDAExecutionProvider"],
    )

    model = insightface.model_zoo.retinaface.RetinaFace(
        model_file=onnx_path, session=onnx_session
    )
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model

# BiRefNet background-segmentation model, loaded once at import time and
# moved to the GPU. The "high" matmul precision lets torch use faster
# (e.g. TF32) float32 matmul kernels.
torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied before BiRefNet:
# - resize to a fixed 1024x1024;
# - convert to a tensor;
# - normalise with the standard ImageNet channel mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def refine_edges(image):
    """
    Build an edge-based alpha channel for *image* and smooth the result.

    Steps:
    - convert the image to grayscale;
    - detect Canny edges (thresholds 50/150);
    - dilate them with a 3x3 kernel;
    - Gaussian-blur them with a 5x5 kernel.

    The blurred edge map becomes the alpha channel of an RGBA copy of the
    image. That copy is then passed through PIL's SMOOTH_MORE filter.

    NOTE(review): the alpha channel is *replaced* by the edge map. Pixels away
    from detected edges therefore become transparent, and any existing alpha
    (e.g. a background-removal mask) is discarded. Confirm this is the
    intended effect before applying it to segmentation output.

    Args:
        image: PIL image in any mode PIL can convert to RGBA.

    Returns:
        A new RGBA PIL image. The caller's image is left unmodified.
    """
    # convert() always returns a new image, so putalpha() below no longer
    # mutates the caller's object. It also makes RGB/L inputs work, which the
    # fixed RGBA->GRAY conversion on the raw array used to reject.
    rgba = image.convert("RGBA")
    gray = cv2.cvtColor(np.array(rgba), cv2.COLOR_RGBA2GRAY)

    edges = cv2.Canny(gray, threshold1=50, threshold2=150)
    # Thicken the detected edges slightly before smoothing.
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    # Soften the edge map so the alpha falls off gradually (anti-aliasing).
    blurred = cv2.GaussianBlur(edges, (5, 5), 0)

    rgba.putalpha(Image.fromarray(blurred).convert("L"))
    return rgba.filter(ImageFilter.SMOOTH_MORE)

def save_image(img):
    """
    Save *img* as a PNG under a random, UUID4-based filename.

    The file is written to the current working directory.

    Args:
        img: object with a ``save(path)`` method (a PIL image here).

    Returns:
        The generated filename, ``"<uuid4>.png"``.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename

@spaces.GPU
def rm_background(image):
    """
    Remove the background from *image* using the global BiRefNet model.

    The image is converted to RGB and preprocessed with ``transform_image``,
    then run through BiRefNet on CUDA. The sigmoid of the last output is
    resized back to the original size and attached as the alpha channel.

    Args:
        image: anything ``load_img`` accepts (callers in this file pass PIL
            images).

    Returns:
        The image with the predicted foreground mask as its alpha channel.
    """
    im = load_img(image, output_type="pil").convert("RGB")
    image_size = im.size
    # NOTE(review): im is already a PIL image; this second load_img() looks
    # redundant, but is kept because its exact return value is not visible
    # here. Confirm against loadimg before removing.
    image = load_img(im)
    # (The previous unused full-resolution `origin = im.copy()` was dropped.)
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    # Predicted mask at model resolution -> PIL -> original image size.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)
    image.putalpha(mask)
    return image

@spaces.GPU
def detect_and_segment_persons(image, clothes):
    """
    Detect people in *image* and segment the requested clothing on each one.

    For each detected person:
    - crop the person's box;
    - remove the background with ``rm_background``;
    - pass the result to ``segment_clothing``.

    The per-person results are concatenated into one list.

    Args:
        image: PIL image; presumably RGB (it is indexed as an H x W x C array
            and its channels are reversed for the detector) - TODO confirm
            other modes are never passed.
        clothes: clothing labels forwarded unchanged to ``segment_clothing``.

    Returns:
        A list of images. If no usable person box is found, the list instead
        holds a single background-removed version of the whole image.
    """
    rgb = np.array(image)
    bgr = rgb[:, :, ::-1]  # the detector is fed BGR channel order

    if detector is None:
        load_model()  # lazily initialise the global detector on first use

    bboxes, _keypoints = detector.detect(bgr)
    if bboxes.shape[0] == 0:
        return [rm_background(image)]

    # Round to integer pixel coordinates and clamp to the image bounds.
    height, width = bgr.shape[:2]
    boxes = np.round(bboxes[:, :4]).astype(int)
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    # Clamping (or rounding a tiny box) can leave zero-area boxes. Their empty
    # crop would make the downstream resize/segmentation fail, so drop them.
    boxes = boxes[(boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])]
    if boxes.shape[0] == 0:
        return [rm_background(image)]

    segmented = []
    for x1, y1, x2, y2 in boxes:
        # Crop from the RGB array directly; same pixels as un-reversing a
        # crop of the BGR view.
        person = Image.fromarray(rgb[y1:y2, x1:x2])
        segmented.extend(segment_clothing(rm_background(person), clothes))

    return segmented

@spaces.GPU
def process_image(input_image):
    """
    Run person detection and clothing segmentation on *input_image*.

    The fixed label set is "Upper-clothes", "Skirt", "Pants" and "Dress".

    Args:
        input_image: PIL image from the Gradio input.

    Returns:
        On success, the list of images from ``detect_and_segment_persons``.
        On any exception, the string ``"Error occurred: <message>"``. The full
        traceback is logged instead of being discarded.
    """
    import logging

    clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
    try:
        # refine_edges() is not applied to the results here (it was left
        # disabled as commented-out code).
        return detect_and_segment_persons(input_image, clothes)
    except Exception as e:
        # Top-level boundary for the UI: keep the string contract that
        # gradio_interface checks for, but record the traceback.
        logging.getLogger(__name__).exception("process_image failed")
        return f"Error occurred: {e}"

# Gradio Interface
def gradio_interface(image):
    """
    Gradio callback: run the pipeline and return images for the Gallery.

    Args:
        image: PIL image from ``gr.Image(type="pil")``, or None when the user
            submits without uploading anything.

    Returns:
        The list of segmented images produced by ``process_image``.

    Raises:
        gr.Error: when no image was given, or when ``process_image`` reports a
            failure. Gradio shows the message in the UI. Previously the error
            text was returned as the Gallery's value, which it cannot render.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    results = process_image(image)
    if isinstance(results, list):
        return results
    # process_image signals failure by returning a message string.
    raise gr.Error("Error: " + str(results))

# Create Gradio app: one PIL image in, a gallery of segmented images out.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Segmented Results"),
    title="Clothing Segmentation API",
)

if __name__ == "__main__":
    # Start the (blocking) server only when run as a script, so importing
    # this module does not launch it. Up to 20 requests may wait in the queue.
    interface.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)