import gradio as gr
from PIL import Image, ImageFilter
import cv2
import numpy as np
import base64
import spaces
from loadimg import load_img
from io import BytesIO
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import uuid
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Person detector (SCRFD), loaded lazily on first use
model = None
detector = None
def load_model():
    global model, detector
    path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
    options = ort.SessionOptions()
    options.intra_op_num_threads = 8
    options.inter_op_num_threads = 8
    session = ort.InferenceSession(
        path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
    )
    model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model

# Load the segmentation model
torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
birefnet.to("cuda")
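# Preprocessing for BiRefNet: resize to the model's 1024x1024 input and
# normalize with the standard ImageNet mean/std.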
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def refine_edges(image):
    """
    Refine the contours of the output image using edge detection and smoothing.
    """
    # Convert the PIL image to a numpy array for OpenCV
    img_np = np.array(image)
    # Convert to grayscale to work on the contours
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGBA2GRAY)
    # Detect edges with Canny
    edges = cv2.Canny(gray, threshold1=50, threshold2=150)
    # Dilate the edges to reinforce the contours
    kernel = np.ones((3, 3), np.uint8)
    edges_dilated = cv2.dilate(edges, kernel, iterations=1)
    # Smooth the edges (anti-aliasing)
    blurred = cv2.GaussianBlur(edges_dilated, (5, 5), 0)
    # Apply the edges as an alpha mask
    alpha = Image.fromarray(blurred).convert("L")
    image.putalpha(alpha)
    # Additional filtering to improve the visual result
    refined_image = image.filter(ImageFilter.SMOOTH_MORE)
    return refined_image

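# Save a PIL image to a uniquely named PNG in the working directory and return the filename.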
def save_image(img):
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name

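# Remove the background from an image: BiRefNet predicts a foreground mask,
# which is resized back to the original resolution and applied as the alpha channel.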
@spaces.GPU
def rm_background(image):
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    origin = im.copy()
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image

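# Detect people with the person detector, crop each detection, remove the
# background of each crop, then run clothing segmentation on the crop.
# If no person is found, the whole image is returned with its background removed.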
@spaces.GPU
def detect_and_segment_persons(image, clothes):
    img = np.array(image)
    img = img[:, :, ::-1]  # RGB -> BGR
    if detector is None:
        load_model()  # Ensure the model is loaded
    bboxes, kpss = detector.detect(img)
    if bboxes.shape[0] == 0:
        return [rm_background(image)]
    height, width, _ = img.shape
    bboxes = np.round(bboxes[:, :4]).astype(int)
    bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
    bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
    bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
    bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)
    all_segmented_images = []
    for i in range(bboxes.shape[0]):
        bbox = bboxes[i]
        x1, y1, x2, y2 = bbox
        person_img = img[y1:y2, x1:x2]
        pil_img = Image.fromarray(person_img[:, :, ::-1])
        img_rm_background = rm_background(pil_img)
        segmented_result = segment_clothing(img_rm_background, clothes)
        all_segmented_images.extend(segmented_result)
    return all_segmented_images

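# Full pipeline for one input image: person detection + clothing segmentation,
# followed by edge refinement on every segmented crop.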
@spaces.GPU
def process_image(input_image):
    try:
        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        results = detect_and_segment_persons(input_image, clothes)
        refined_results = [refine_edges(image) for image in results]
        return refined_results
    except Exception as e:
        return f"Error occurred: {e}"

# Gradio interface
def gradio_interface(image):
    results = process_image(image)
    if isinstance(results, list):
        return results
    # process_image returned an error string instead of a list of images
    raise gr.Error(results)

# Create Gradio app
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Segmented Results"),
    title="Clothing Segmentation API",
)
interface.launch(server_name="0.0.0.0", server_port=7860)
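
# Example client call (a minimal sketch, assuming the app is reachable on
# http://localhost:7860 and that the gradio_client package is installed;
# api_name "/predict" is the gr.Interface default):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860/")
#   gallery = client.predict(handle_file("photo.jpg"), api_name="/predict")
#   print(gallery)  # list of segmented clothing images returned by the Space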