# Spaces:
# Sleeping
import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor
# The image pre-processor and the segmentation network are loaded from the
# same Hub checkpoint, once at import time; predict() reuses these module-level
# objects on every call.
_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"
processor = SegformerImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForSemanticSegmentation.from_pretrained(_CHECKPOINT)
# Text handed to gr.Interface as its title, description and article.
# NOTE(review): "π" looks like an emoji that was mis-encoded when this file was
# copied; confirm the intended character before changing it.
title = "Background remover π"
description = (
    " Image segmentation model which removes the background"
    " and optionally adds a white border."
)
article = 'Inference done on "mattmdjaga/segformer_b2_clothes" model'
folder_path = "Images" | |
example_list = [] | |
if os.path.exists(folder_path) and os.path.isdir(folder_path): | |
file_paths = [os.path.join(folder_path, file_name) for file_name in os.listdir(folder_path)] | |
for file_path in file_paths: | |
example_list.append(['Large',file_path]) | |
def _foreground_mask(image):
    """Return an "L"-mode mask that is 255 wherever the model predicts a non-zero label.

    *image* is an RGB PIL image. Label 0 is treated as background. The logits
    are upsampled to the image's (height, width) before the per-pixel argmax,
    so the mask has exactly the image's size.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
            mode="bilinear",
            align_corners=False,
        )
        # Argmax over dim 1, keeping the first (only) item of the batch.
        pred_seg = upsampled_logits.argmax(dim=1)[0]
    non_background = (pred_seg != 0).numpy().astype('uint8') * 255
    return Image.fromarray(non_background, 'L')


def _add_white_stroke(rgba, mask, stroke_radius):
    """Composite *rgba* over a white outline about *stroke_radius* pixels wide.

    *mask* is the binary (0/255) alpha of *rgba*. It is dilated with a square
    MaxFilter of size ``2 * stroke_radius + 1``: PIL rank filters need an odd
    size, and a size-1 filter would be a no-op. The result is softened with
    SMOOTH and used as the alpha of a solid white layer.
    """
    stroke_alpha = mask.filter(ImageFilter.MaxFilter(2 * stroke_radius + 1))
    stroke_alpha = stroke_alpha.filter(ImageFilter.SMOOTH)
    stroke_image = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    stroke_image.putalpha(stroke_alpha)
    return Image.alpha_composite(stroke_image, rgba)


def predict(border_size, image):
    """Remove the background from *image* and optionally add a white border.

    Args:
        border_size: One of the dropdown choices 'None', 'Small', 'Medium' or
            'Large'. ``None`` (no selection) or any unrecognised value means
            no border.
        image: PIL image from the Gradio input, or ``None`` when the user
            submitted without an image.

    Returns:
        An RGBA PIL image in which pixels predicted as background are fully
        transparent (optionally outlined in white), or ``None`` if no image
        was given.
    """
    if image is None:
        return None
    # Border radius in pixels; 0 disables the stroke.
    sizes = {'Large': 5, 'Medium': 3, 'Small': 1, 'None': 0}
    stroke_radius = sizes.get(border_size, 0)

    image = image.convert('RGB')
    mask = _foreground_mask(image)

    # Copy the original pixels onto a fully transparent canvas wherever the
    # mask is 255. Because the mask is strictly 0/255, the canvas alpha ends up
    # identical to the mask.
    composite_image = Image.new('RGBA', image.size, color=(0, 0, 0, 0))
    composite_image.paste(image.convert('RGBA'), mask=mask)

    if stroke_radius:
        return _add_white_stroke(composite_image, mask, stroke_radius)
    return composite_image
# Web UI. The input order (border size, then image) must match predict's
# parameters and the rows built for example_list.
iface = gr.Interface(
    fn=predict,
    inputs=[
        # Pre-select 'None' so predict always receives one of its known choices.
        gr.Dropdown(['None', 'Small', 'Medium', 'Large'], value='None', label='Select Border Size'),
        gr.Image(type='pil', label='Select Image.'),
    ],
    outputs=gr.Image(type='pil', label='Output with background removed (sorta?)'),
    title=title,
    description=description,
    article=article,
    examples=example_list,
)

# Start the server only when run as a script (as Spaces does with `python app.py`),
# not when the module is merely imported.
if __name__ == "__main__":
    iface.launch()