import gradio as gr
from PIL import Image, ImageFilter
import cv2
import numpy as np
import base64
import spaces
from loadimg import load_img
from io import BytesIO
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import uuid
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# The person detector is loaded lazily on first use.
model = None
detector = None


def load_model():
    """Download and initialize the SCRFD person detector from InsightFace."""
    global model, detector
    path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
    options = ort.SessionOptions()
    options.intra_op_num_threads = 8
    options.inter_op_num_threads = 8
    session = ort.InferenceSession(
        path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
    )
    model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model


# Load the BiRefNet segmentation model used for background removal.
torch.set_float32_matmul_precision(["high", "highest"][0])
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
birefnet.to("cuda")

transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def refine_edges(image):
    """Refine the edges of the output image using edge detection and smoothing."""
    # Convert the PIL image to a NumPy array for OpenCV
    img_np = np.array(image)
    # Convert to grayscale before extracting contours
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGBA2GRAY)
    # Detect edges with Canny
    edges = cv2.Canny(gray, threshold1=50, threshold2=150)
    # Dilate the edges to reinforce the contours
    kernel = np.ones((3, 3), np.uint8)
    edges_dilated = cv2.dilate(edges, kernel, iterations=1)
    # Smooth the edges (anti-aliasing)
    blurred = cv2.GaussianBlur(edges_dilated, (5, 5), 0)
    # Apply the edges as an alpha mask
    alpha = Image.fromarray(blurred).convert("L")
    image.putalpha(alpha)
    # Extra filtering to improve the final look
    refined_image = image.filter(ImageFilter.SMOOTH_MORE)
    return refined_image


def save_image(img):
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name


@spaces.GPU
def rm_background(image):
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    origin = im.copy()
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction: BiRefNet outputs a saliency map that is used as the alpha mask
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


@spaces.GPU
def detect_and_segment_persons(image, clothes):
    img = np.array(image)
    img = img[:, :, ::-1]  # RGB -> BGR
    if detector is None:
        load_model()  # Ensure the detector is loaded
    bboxes, kpss = detector.detect(img)
    if bboxes.shape[0] == 0:
        # No person detected: fall back to background removal on the whole image
        return [rm_background(image)]

    # Clip the bounding boxes to the image dimensions
    height, width, _ = img.shape
    bboxes = np.round(bboxes[:, :4]).astype(int)
    bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
    bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
    bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
    bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)

    all_segmented_images = []
    for i in range(bboxes.shape[0]):
        # Crop each detected person, remove the background, then segment the clothing
        x1, y1, x2, y2 = bboxes[i]
        person_img = img[y1:y2, x1:x2]
        pil_img = Image.fromarray(person_img[:, :, ::-1])  # BGR -> RGB
        img_rm_background = rm_background(pil_img)
        segmented_result = segment_clothing(img_rm_background, clothes)
        all_segmented_images.extend(segmented_result)

    return all_segmented_images


@spaces.GPU
def process_image(input_image):
    try:
        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        results = detect_and_segment_persons(input_image, clothes)
        # results = [refine_edges(image) for image in results]
        return results
    except Exception as e:
        return f"Error occurred: {e}"


# Gradio interface
def gradio_interface(image):
    results = process_image(image)
    if isinstance(results, list):
        return results
    else:
        return "Error: " + results


# Create the Gradio app
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Segmented Results"),
    title="Clothing Segmentation API",
)

interface.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)