import gradio as gr
from PIL import Image, ImageFilter
import cv2
import numpy as np
import base64
import spaces
from loadimg import load_img
from io import BytesIO
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import uuid
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Load the person detector lazily
model = None
detector = None

def load_model():
    global model, detector
    path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
    options = ort.SessionOptions()
    options.intra_op_num_threads = 8
    options.inter_op_num_threads = 8
    # Providers are tried in order, so list CUDA first and fall back to CPU
    session = ort.InferenceSession(
        path, sess_options=options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model
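
# --- Example detector call (an illustrative sketch, not part of the original app) ---
# Once load_model() has run, the SCRFD person detector can be invoked directly on a
# BGR numpy image (the same convention used in detect_and_segment_persons below):
#
#   load_model()
#   boxes, keypoints = detector.detect(cv2.imread("people.jpg"))
#   # boxes: (N, 5) array of [x1, y1, x2, y2, score]; keypoints: per-person landmarks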
# Load the segmentation model (BiRefNet)
torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
birefnet.to("cuda")

transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
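
# Note: `birefnet.to("cuda")` above raises at import time on CPU-only machines.
# A defensive variant (an illustrative sketch, not part of the original app) would
# pick the device at runtime and reuse it for the input tensors in rm_background:
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   birefnet.to(device)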
def refine_edges(image):
    """
    Refine the edges of the output image using edge detection and smoothing.
    """
    # Convert the PIL image to a numpy array for OpenCV
    img_np = np.array(image)

    # Convert to grayscale to process the edges
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGBA2GRAY)

    # Detect edges with Canny
    edges = cv2.Canny(gray, threshold1=50, threshold2=150)

    # Dilate the edges to reinforce the contours
    kernel = np.ones((3, 3), np.uint8)
    edges_dilated = cv2.dilate(edges, kernel, iterations=1)

    # Smooth the edges (anti-aliasing)
    blurred = cv2.GaussianBlur(edges_dilated, (5, 5), 0)

    # Apply the edges as an alpha mask
    alpha = Image.fromarray(blurred).convert("L")
    image.putalpha(alpha)

    # Extra filtering to improve the final look
    refined_image = image.filter(ImageFilter.SMOOTH_MORE)
    return refined_image
def save_image(img):
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name
def rm_background(image):
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    origin = im.copy()
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Resize the predicted mask back to the original size and apply it as alpha
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image
def detect_and_segment_persons(image, clothes):
    img = np.array(image)
    img = img[:, :, ::-1]  # RGB -> BGR

    if detector is None:
        load_model()  # Ensure the detector is loaded

    bboxes, kpss = detector.detect(img)
    if bboxes.shape[0] == 0:
        # No person detected: return the whole image with its background removed
        return [rm_background(image)]

    height, width, _ = img.shape
    # Keep the box coordinates (dropping the score column) and clip to image bounds
    bboxes = np.round(bboxes[:, :4]).astype(int)
    bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
    bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
    bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
    bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)

    all_segmented_images = []
    for i in range(bboxes.shape[0]):
        x1, y1, x2, y2 = bboxes[i]
        # Crop each detected person, remove the background, then segment the clothing
        person_img = img[y1:y2, x1:x2]
        pil_img = Image.fromarray(person_img[:, :, ::-1])  # BGR -> RGB
        img_rm_background = rm_background(pil_img)
        segmented_result = segment_clothing(img_rm_background, clothes)
        all_segmented_images.extend(segmented_result)

    return all_segmented_images
# The `spaces` import indicates a ZeroGPU Space; on ZeroGPU, CUDA work must run
# inside a function decorated with @spaces.GPU.
@spaces.GPU
def process_image(input_image):
    try:
        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        results = detect_and_segment_persons(input_image, clothes)
        # results = [refine_edges(image) for image in results]
        return results
    except Exception as e:
        return f"Error occurred: {e}"
# Gradio interface
def gradio_interface(image):
    results = process_image(image)
    if isinstance(results, list):
        return results
    else:
        return "Error: " + results
# Create the Gradio app
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Segmented Results"),
    title="Clothing Segmentation API",
)

interface.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)
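
# --- Example client call (an illustrative sketch, not part of the original app) ---
# With the app running, it can be queried from a separate process using the
# official gradio_client package; "/predict" is the default endpoint name for a
# single-function gr.Interface, and "person.jpg" is a placeholder input:
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   results = client.predict(handle_file("person.jpg"), api_name="/predict")
#   print(results)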