# SegCloth.py by Saad0KH (commit a05ee1f, "Update SegCloth.py", 1.13 kB).
from transformers import pipeline
from PIL import Image
import numpy as np
from io import BytesIO
import base64
# Hugging Face model id for the clothing-segmentation checkpoint.
_SEGFORMER_CLOTHES_MODEL = "mattmdjaga/segformer_b2_clothes"

# Build the segmentation pipeline once, at import time; segment_clothing()
# reuses this shared instance on every call.
segmenter = pipeline(model=_SEGFORMER_CLOTHES_MODEL)
def encode_image_to_base64(image):
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
# Segment labels kept by default. They are compared verbatim against each
# segment's 'label', so they must match the segmenter's label names exactly.
_DEFAULT_CLOTHES = (
    "Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt",
    "Left-shoe", "Right-shoe", "Scarf",
)


def segment_clothing(img, clothes=_DEFAULT_CLOTHES):
    """Cut each requested clothing item out of *img* as a base64 PNG.

    Runs the module-level ``segmenter`` on *img*. It keeps only the segments
    whose ``'label'`` is in *clothes*. For each kept segment it encodes a copy
    of the image whose alpha channel is that segment's mask.

    Args:
        img: Image to segment; presumably a PIL image, since ``copy`` and
            ``putalpha`` are called on it. It is NOT modified.
        clothes: Collection of segment labels to keep (membership-tested
            with ``in``). Defaults to a tuple of common garment labels.

    Returns:
        List of ``(label, base64_png)`` tuples, in the order the segmenter
        returned the matching segments. Returns an empty list if nothing matched.
    """
    segments = segmenter(img)
    matches = [(s['mask'], s['label']) for s in segments if s['label'] in clothes]
    if not matches:
        return []

    # putalpha() edits an image in place (converting e.g. RGB to RGBA), so
    # work on a private copy to leave the caller's image untouched. A single
    # copy suffices: each putalpha() call replaces the whole alpha band, so
    # masks do not accumulate from one iteration to the next.
    cutout = img.copy()
    result_images = []
    for mask, clothing_type in matches:
        # Round-trip through NumPy to normalize the mask into a PIL image.
        alpha = Image.fromarray(np.array(mask))
        cutout.putalpha(alpha)
        result_images.append((clothing_type, encode_image_to_base64(cutout)))
    return result_images