from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image, ImageFilter
import torch.nn as nn
import os
import gradio as gr

processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

title = "Background remover 👀"
description = "Image segmentation model that removes the background and optionally adds a white border."
article = 'Inference runs on the "mattmdjaga/segformer_b2_clothes" model.'


# Build the Gradio examples from every file in the local "Images" folder, if it exists.
folder_path = "Images"
example_list = []
if os.path.isdir(folder_path):
    example_list = [
        ['Large', os.path.join(folder_path, file_name)]
        for file_name in os.listdir(folder_path)
    ]

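def predict(border_size, image):
    """Remove the background from a PIL image and optionally add a white border."""
    # Border options mapped to the MaxFilter kernel size in pixels (0 disables the border).
    sizes = {'Large': 5, 'Medium': 3, 'Small': 1, 'None': 0}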
    image = image.convert('RGB')
    inputs = processor(images=image, return_tensors="pt")

    outputs = model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    # Per-pixel class index; class 0 is treated as background.
    pred_seg = upsampled_logits.argmax(dim=1)[0]

    non_background_mask = pred_seg != 0

    # Convert the boolean tensor mask to an 8-bit grayscale PIL image (255 = keep, 0 = background)
    non_background_pil_mask = Image.fromarray(non_background_mask.numpy().astype('uint8') * 255, 'L')

    # Paste the original pixels onto a fully transparent canvas, keeping only the masked area
    composite_image = Image.new('RGBA', image.size, color=(0, 0, 0, 0))
    composite_image.paste(image.convert('RGBA'), mask=non_background_pil_mask)

    # Treat a missing or unknown dropdown selection as "no border" instead of raising KeyError.
    stroke_radius = sizes.get(border_size, 0)
    if stroke_radius == 0:
        return composite_image

    # Build a white layer whose alpha is the dilated, smoothed subject mask,
    # then composite the subject on top of it.
    stroke_image = Image.new("RGBA", composite_image.size, (255, 255, 255, 255))
    img_alpha = composite_image.getchannel("A").point(lambda x: 255 if x > 0 else 0)
    # MaxFilter expects an odd kernel size.
    stroke_alpha = img_alpha.filter(ImageFilter.MaxFilter(stroke_radius))
    stroke_alpha = stroke_alpha.filter(ImageFilter.SMOOTH)
    stroke_image.putalpha(stroke_alpha)
    return Image.alpha_composite(stroke_image, composite_image)
    
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(['None', 'Small', 'Medium', 'Large'], label='Select Border Size'),
        gr.Image(type='pil', label='Select Image.'),
    ],
    outputs=gr.Image(type='pil', label='Output with background removed (sorta?)'),
    title=title,
    description=description,
    article=article,
    examples=example_list,
)
iface.launch()