# Author: Saad0KH — commit 9cde8dc ("Update app.py", verified)
import gradio as gr
from PIL import Image,ImageFilter
import cv2
import numpy as np
import base64
import spaces
from loadimg import load_img
from io import BytesIO
import numpy as np
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import uuid
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Lazily-initialised person detector. Both globals stay None until
# load_model() runs, after which they refer to the same RetinaFace instance.
model = detector = None


def load_model():
    """
    Download and initialise the SCRFD person detector and publish it globally.

    Fetches ``models/scrfd_person_2.5g.onnx`` from the ``public-data/insightface``
    Hub repository, wraps it in an ONNX Runtime session and an insightface
    ``RetinaFace`` detector, and stores that detector in both the ``model`` and
    ``detector`` module globals.
    """
    global model, detector
    onnx_file = huggingface_hub.hf_hub_download(
        "public-data/insightface", "models/scrfd_person_2.5g.onnx"
    )

    session_opts = ort.SessionOptions()
    session_opts.intra_op_num_threads = 8
    session_opts.inter_op_num_threads = 8

    # NOTE(review): ONNX Runtime takes `providers` in decreasing precedence, so
    # listing the CPU provider first presumably means CUDA is never selected;
    # prepare(-1) below presumably also requests CPU execution. Confirm that
    # CPU-only detection is intended.
    ort_session = ort.InferenceSession(
        onnx_file,
        sess_options=session_opts,
        providers=["CPUExecutionProvider", "CUDAExecutionProvider"],
    )

    model = insightface.model_zoo.retinaface.RetinaFace(
        model_file=onnx_file, session=ort_session
    )
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model
# Load the BiRefNet background-segmentation model once, at import time, and
# move it to the GPU. "high" lets PyTorch use faster reduced-precision
# (e.g. TF32) float32 matmuls where available.
torch.set_float32_matmul_precision("high")
# trust_remote_code=True executes model code shipped in the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# BiRefNet preprocessing: resize to 1024x1024, convert to a tensor, then
# normalise with the ImageNet channel statistics.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def refine_edges(image):
    """
    Refine the contours of an output image using edge detection and smoothing.

    The work is done on an RGBA copy of ``image``. This has two effects:
    the caller's image is never modified (previously its alpha channel was
    overwritten in place), and RGB / grayscale inputs are accepted (previously
    anything that was not RGBA crashed in ``cv2.cvtColor``).

    Args:
        image: PIL image to refine.

    Returns:
        A new RGBA PIL image. Its alpha channel is the dilated, Gaussian-blurred
        Canny edge map of the input, and the whole image is then smoothed with
        ``ImageFilter.SMOOTH_MORE``.
    """
    # convert() always returns a new image (a copy even when already RGBA).
    rgba = image.convert("RGBA")
    img_np = np.array(rgba)

    # Edge detection works on a single grayscale channel.
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGBA2GRAY)
    edges = cv2.Canny(gray, threshold1=50, threshold2=150)

    # Dilate the edges to thicken the contours.
    kernel = np.ones((3, 3), np.uint8)
    edges_dilated = cv2.dilate(edges, kernel, iterations=1)

    # Blur the edge map to anti-alias it.
    blurred = cv2.GaussianBlur(edges_dilated, (5, 5), 0)

    # NOTE(review): this *replaces* the existing alpha with the edge map, so only
    # pixels near edges stay opaque. Confirm that is the intended effect (the only
    # call site, in process_image, is currently commented out).
    alpha = Image.fromarray(blurred).convert("L")
    rgba.putalpha(alpha)

    # Extra smoothing pass for a cleaner look.
    return rgba.filter(ImageFilter.SMOOTH_MORE)
def save_image(img):
    """
    Save ``img`` under a random, collision-resistant file name.

    The file is written to the current working directory as ``<uuid4>.png``.
    PIL infers the format from the extension.

    Returns:
        The file name that was written.
    """
    file_name = f"{uuid.uuid4()}.png"
    img.save(file_name)
    return file_name
@spaces.GPU
def rm_background(image):
    """
    Remove the background of an image with BiRefNet.

    Args:
        image: anything ``load_img`` accepts. Presumably a PIL image, path, URL
            or array — confirm against the loadimg documentation.

    Returns:
        The RGB image with BiRefNet's predicted foreground mask, resized to
        the original size, attached as its alpha channel.
    """
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    # (The former unused `origin = im.copy()` full-resolution copy was removed.)
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    # Prediction: take the last model output (presumably the final, finest
    # prediction) and map logits to [0, 1] probabilities.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # The mask is predicted at 1024x1024; resize it back to the input size.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image
@spaces.GPU
def detect_and_segment_persons(image, clothes):
    """
    Detect every person in ``image`` and segment each person's clothing.

    For each person, the crop's background is removed and the requested
    clothing categories are segmented.

    Args:
        image: input image. A PIL image (as delivered by the Gradio UI) is
            converted to RGB first. Other inputs are assumed to become an
            HxWx3 RGB array via ``np.array``.
        clothes: clothing labels forwarded unchanged to ``segment_clothing``.

    Returns:
        Flat list of the results of ``segment_clothing`` for all detected
        persons. If no person is detected, returns a one-element list holding
        the background-removed full image.
    """
    # Normalise PIL inputs to 3-channel RGB so the channel flip below yields
    # BGR. RGBA or grayscale images would otherwise break the detector input.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    rgb = np.array(image)
    bgr = rgb[:, :, ::-1]  # RGB -> BGR for the detector

    if detector is None:
        load_model()  # Ensure the model is loaded
    bboxes, _ = detector.detect(bgr)
    if bboxes.shape[0] == 0:
        return [rm_background(image)]

    height, width = bgr.shape[:2]
    # Only the first four columns are box coordinates (x1, y1, x2, y2).
    boxes = np.round(bboxes[:, :4]).astype(int)
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in boxes:
        if x2 <= x1 or y2 <= y1:
            # Box has zero area after rounding/clipping; an empty crop would
            # fail downstream in rm_background.
            continue
        # Crop the RGB array directly (equivalent to un-flipping a BGR crop).
        person = Image.fromarray(rgb[y1:y2, x1:x2])
        person_no_bg = rm_background(person)
        all_segmented_images.extend(segment_clothing(person_no_bg, clothes))
    return all_segmented_images
@spaces.GPU
def process_image(input_image):
    """
    Run clothing segmentation for every person in ``input_image``.

    Segments the labels "Upper-clothes", "Skirt", "Pants" and "Dress".

    Returns:
        On success, the list produced by ``detect_and_segment_persons``.
        If any exception is raised, the string ``"Error occurred: <message>"``.
        The full traceback is logged rather than silently discarded.
    """
    import logging

    clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
    try:
        # NOTE: refine_edges post-processing is currently not applied.
        return detect_and_segment_persons(input_image, clothes)
    except Exception as e:
        # Top-level boundary: keep the error-string contract callers rely on,
        # but record the traceback so failures are diagnosable.
        logging.getLogger(__name__).exception("process_image failed")
        return f"Error occurred: {e}"
# Gradio Interface
def gradio_interface(image):
    """
    Gradio callback that returns the segmented images for the Gallery output.

    ``process_image`` reports failures by returning a string. A Gallery output
    expects a list of images, so returning that string produced a broken
    result. Instead, the failure is raised as ``gr.Error`` and Gradio shows the
    message to the user.
    """
    results = process_image(image)
    if isinstance(results, list):
        return results
    raise gr.Error(str(results))
# Build the Gradio app: one PIL image in, a gallery of segmented images out.
interface = gr.Interface(
    fn=gradio_interface,
    title="Clothing Segmentation API",
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Segmented Results"),
)

# Queue at most 20 pending requests, then serve on all interfaces at port 7860.
interface.queue(max_size=20)
interface.launch(server_name="0.0.0.0", server_port=7860)