Saad0KH's picture
Update SegCloth.py
94bd664 verified
raw
history blame
2.82 kB
import base64
import functools
from io import BytesIO

import numpy as np
from PIL import Image, ImageChops, ImageOps
from transformers import pipeline
from transparent_background import Remover
# Hugging Face checkpoint backing the clothing segmentation pipeline.
_SEGMENTATION_MODEL_ID = "mattmdjaga/segformer_b2_clothes"

# Build the segmentation pipeline once, at import time; segment_clothing()
# calls this shared instance for every request.
segmenter = pipeline(model=_SEGMENTATION_MODEL_ID)
@functools.lru_cache(maxsize=1)
def _get_remover():
    """Return a single shared ``Remover``, constructing it on first use.

    The original code built a fresh ``Remover()`` on every call; constructing
    it presumably loads the background-removal network (see the
    transparent-background docs), so it is created once and reused.
    """
    return Remover()


# NOTE: a `@spaces.GPU` decorator (Hugging Face ZeroGPU) was previously
# sketched here and is disabled; re-enable it if running on a ZeroGPU Space.
def remove_background(image):
    """Remove the background of *image* with transparent_background's Remover.

    Args:
        image: a PIL ``Image.Image``, or a NumPy array accepted by
            ``Image.fromarray``.

    Returns:
        The result of ``Remover.process`` for the image (used as a PIL image
        by callers, which call ``.save`` on it).

    Raises:
        TypeError: if *image* is neither a PIL image nor a NumPy array.
    """
    # Validate/convert first so unsupported input fails fast, without loading
    # the model.
    if isinstance(image, Image.Image):
        pil_image = image
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        raise TypeError("Unsupported image type")
    return _get_remover().process(pil_image)
def encode_image_to_base64(image):
    """Serialize *image* as PNG and return the bytes as a base64 UTF-8 string.

    Args:
        image: any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 text of the PNG-encoded image.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        # Grab the bytes before the buffer is closed by the context manager.
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
def segment_clothing(
    img,
    clothes=(
        "Hat", "Upper-clothes", "Skirt", "Pants", "Dress",
        "Belt", "Left-shoe", "Right-shoe", "Scarf",
    ),
    margin=10,
):
    """Segment garments in *img* and return each one as a base64-encoded PNG.

    *img* is run through the module-level ``segmenter`` pipeline. For every
    segment whose ``'label'`` is in *clothes*:

    1. The pixels under the segment mask are copied onto a transparent canvas.
    2. The canvas is cropped to the mask's bounding box, padded by *margin*
       pixels and clamped to the image bounds.
    3. The crop is passed through ``remove_background``.
    4. The result is encoded with ``encode_image_to_base64``.

    Args:
        img: input image; must be a PIL image (``convert``, ``size``,
            ``width`` and ``height`` are used).
        clothes: labels to keep; any container supporting ``in``. The default
            is a tuple so it cannot be mutated between calls.
        margin: padding in pixels added on each side of the mask bounding box.

    Returns:
        A list of base64 PNG strings, in the order the segmenter returned the
        matching segments. Segments whose mask is entirely zero are skipped.
    """
    segments = segmenter(img)

    # Loop-invariant: the RGBA source is identical for every segment.
    rgba_source = img.convert("RGBA")
    width, height = img.size

    result_images = []
    for segment in segments:
        if segment['label'] not in clothes:
            continue

        # Normalise the segment mask to an 8-bit grayscale ("L") PIL image.
        mask_image = Image.fromarray(np.array(segment['mask'])).convert("L")

        # Check for an empty mask first so no compositing work is wasted on it.
        bbox = mask_image.getbbox()
        if not bbox:
            continue

        # Zero every channel (alpha included) outside the mask, then paste
        # through the mask onto a fully transparent canvas.
        mask_rgba = Image.merge("RGBA", [mask_image] * 4)
        segmented_part = ImageChops.multiply(rgba_source, mask_rgba)
        canvas = Image.new("RGBA", img.size, (0, 0, 0, 0))
        canvas.paste(segmented_part, mask=mask_image)

        # Pad the mask bounding box by `margin` and clamp it to the image.
        left, top, right, bottom = bbox
        crop_box = (
            max(0, left - margin),
            max(0, top - margin),
            min(width, right + margin),
            min(height, bottom + margin),
        )
        cropped_image = canvas.crop(crop_box)

        image_rm_background = remove_background(cropped_image)
        result_images.append(encode_image_to_base64(image_rm_background))

    return result_images