File size: 3,484 Bytes
7f87578
d8988f8
 
7f87578
d8988f8
7f87578
d8988f8
 
7f87578
d8988f8
17f6f2d
d8988f8
 
 
7f87578
d8988f8
 
 
 
 
 
 
17f6f2d
7f87578
d8988f8
 
 
 
 
 
 
 
 
 
 
 
7f87578
d8988f8
 
 
 
 
7f87578
d8988f8
 
 
 
 
 
 
 
 
 
 
 
 
 
7f87578
d8988f8
 
 
 
 
 
7f87578
d8988f8
 
7f87578
d8988f8
 
7f87578
d8988f8
 
 
 
 
 
7f87578
d8988f8
 
 
7f87578
d8988f8
 
 
 
 
7f87578
d8988f8
 
 
 
 
7f87578
d8988f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from gradio_imageslider import ImageSlider
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# "high" lets float32 matmuls use faster reduced-precision internals where the
# backend supports them (see the PyTorch docs for the exact trade-off).
torch.set_float32_matmul_precision("high")

# Prefer the GPU, but fall back to the CPU so importing this module does not
# crash on hosts without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load BiRefNet model for background removal
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# model repository; keep the repository id pinned to a source you trust.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)

# Preprocessing applied to every image before it is fed to the model:
# resize to a fixed 1024x1024, convert to a tensor, then normalize each
# channel (the constants are the widely used ImageNet mean/std values).
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def load_img(image, output_type="numpy"):
    """Load an image as RGB and return it as a PIL image or a NumPy array.

    Args:
        image: A file path or file object accepted by ``Image.open``, an
            already-decoded ``PIL.Image.Image``, or a NumPy array accepted by
            ``Image.fromarray``.
        output_type: ``"pil"`` returns a ``PIL.Image.Image``; any other value
            (default ``"numpy"``) returns a ``numpy.ndarray``.

    Returns:
        The image converted to RGB mode, in the requested container type.
    """
    if isinstance(image, Image.Image):
        # Already decoded: Image.open() only accepts paths / file objects.
        rgb = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        rgb = Image.fromarray(image).convert("RGB")
    else:
        # convert() returns an independent image, so the source file can be
        # released as soon as the conversion is done.
        with Image.open(image) as source:
            rgb = source.convert("RGB")

    if output_type == "pil":
        return rgb
    return np.array(rgb)

def add_text_to_image(image, text, position, color, font_size):
    """Draw ``text`` onto ``image`` and return the result as a NumPy array.

    Args:
        image: NumPy array accepted by ``Image.fromarray``.
        text: String to render; ``None`` is treated as an empty string.
        position: ``(x, y)`` anchor passed to ``ImageDraw.text``.
        color: Fill color passed to ``ImageDraw.text``.
        font_size: Requested font size; rounded to the nearest integer.

    Returns:
        A new ``numpy.ndarray`` with the text drawn on it.
    """
    img = Image.fromarray(image)
    draw = ImageDraw.Draw(img)

    # UI sliders may deliver floats, and older Pillow releases reject
    # non-integer font sizes.
    size = int(round(font_size))

    # "arial.ttf" is typically only installed on Windows; try a common
    # alternative before falling back to Pillow's built-in font.
    font = None
    for candidate in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            font = ImageFont.truetype(candidate, size)
            break
        except OSError:
            continue
    if font is None:
        try:
            # Newer Pillow releases can scale the built-in font.
            font = ImageFont.load_default(size)
        except TypeError:
            # Older releases take no size argument (fixed-size bitmap font).
            font = ImageFont.load_default()

    draw.text(position, text or "", fill=color, font=font)
    return np.array(img)

def inpaint_image(image, mask, inpaint_radius):
    """Inpaint the masked region of an RGB image with OpenCV's Telea method.

    Args:
        image: RGB image as a NumPy array.
        mask: Mask as a NumPy array; may be single-channel (2-D or HxWx1),
            RGB, or RGBA. Non-zero pixels mark the area to inpaint.
        inpaint_radius: Neighborhood radius passed to ``cv2.inpaint``.

    Returns:
        The inpainted image as an RGB NumPy array.
    """
    img_cv = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # cv2.inpaint needs a single-channel mask; pick the conversion that
    # matches the number of channels actually supplied.
    if mask.ndim == 2:
        mask_cv = mask
    elif mask.shape[2] == 1:
        mask_cv = np.ascontiguousarray(mask[:, :, 0])
    elif mask.shape[2] == 4:
        mask_cv = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
    else:
        mask_cv = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)

    # The mask must have the same height/width as the image. Nearest-neighbor
    # keeps the mask values unblended when rescaling.
    if mask_cv.shape[:2] != img_cv.shape[:2]:
        height, width = img_cv.shape[:2]
        mask_cv = cv2.resize(
            mask_cv, (width, height), interpolation=cv2.INTER_NEAREST
        )

    result = cv2.inpaint(img_cv, mask_cv, inpaint_radius, cv2.INPAINT_TELEA)
    return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)

def background_removal(image):
    """Remove the background of ``image`` using the BiRefNet model.

    Args:
        image: A NumPy array, a ``PIL.Image.Image``, or anything
            ``load_img`` can open (e.g. a file path). ``None`` is passed
            through unchanged.

    Returns:
        ``None`` when no image was supplied; otherwise a tuple
        ``(cutout, origin)`` where ``cutout`` is the RGB image with the
        predicted mask installed as its alpha channel and ``origin`` is the
        untouched RGB image.
    """
    if image is None:
        return None

    # Normalize every supported input to an RGB PIL image. Arrays and PIL
    # images are handled here because Image.open() only accepts paths/files.
    if isinstance(image, Image.Image):
        origin = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        origin = Image.fromarray(image).convert("RGB")
    else:
        origin = load_img(image, output_type="pil")
    im = origin.copy()

    # Send the batch to whichever device the model's weights live on.
    device = next(birefnet.parameters()).device
    input_images = transform_image(im).unsqueeze(0).to(device)

    with torch.no_grad():
        # The last model output is used as the mask logits.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Scale the predicted mask back to the input resolution and use it as
    # the alpha channel of the cutout.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(im.size)
    im.putalpha(mask)
    return (im, origin)

def update_image(image, text, color, font_size, mask_image, inpaint_radius):
    """Overlay text on an image and optionally inpaint a masked region.

    Args:
        image: Image as a NumPy array, or ``None`` when nothing was uploaded.
        text: Text drawn at the fixed position ``(50, 50)``.
        color: Fill color for the text.
        font_size: Font size for the text.
        mask_image: Optional mask; when not ``None`` the image is inpainted
            with it after the text has been drawn.
        inpaint_radius: Radius forwarded to ``inpaint_image``.

    Returns:
        The edited image as a NumPy array, or ``None`` when ``image`` is
        ``None``.
    """
    # The UI passes None until the user uploads an image; there is nothing
    # to draw on in that case.
    if image is None:
        return None

    img_with_text = add_text_to_image(image, text, (50, 50), color, font_size)
    if mask_image is not None:
        mask = np.array(mask_image)
        img_with_text = inpaint_image(img_with_text, mask, inpaint_radius)
    return img_with_text

def fn(image):
    """Callback wired to the background-removal interface.

    Forwards ``image`` to ``background_removal`` and hands back whatever it
    returns.
    """
    outcome = background_removal(image)
    return outcome

def _edit_design(image, text, color, font_size, mask_image, inpaint_radius):
    """Run ``update_image`` and pair the input with the result.

    The output component is an ``ImageSlider``, which displays a pair of
    images, so the uploaded image and the edited image are returned together.
    """
    edited = update_image(
        image, text, color, font_size, mask_image, inpaint_radius
    )
    if edited is None:
        return None
    return (image, edited)


slider1 = ImageSlider(label="Original Image", type="pil")
slider2 = ImageSlider(label="Processed Image", type="pil")

# Each Interface gets its own input component: one component instance cannot
# be rendered in two Interfaces that are later combined in the same app.
image_input = gr.Image(label="Upload an image for background removal")
design_image_input = gr.Image(type="numpy", label="Upload an image to edit")
text_input = gr.Textbox(label="Enter Text to Add", placeholder="Your text here...")
color_input = gr.ColorPicker(label="Text Color")
font_size_input = gr.Slider(minimum=10, maximum=100, label="Font Size")
# No `optional` argument: current Gradio releases do not accept it, and an
# empty image input is already delivered to the callback as None.
mask_input = gr.Image(type="numpy", label="Upload Mask Image (for Inpainting)")
inpaint_radius_input = gr.Slider(minimum=1, maximum=50, value=3, label="Inpaint Radius")

bg_removal_interface = gr.Interface(
    fn, inputs=image_input, outputs=slider1, examples=["chameleon.jpg"]
)

design_editing_interface = gr.Interface(
    fn=_edit_design,
    inputs=[
        design_image_input,
        text_input,
        color_input,
        font_size_input,
        mask_input,
        inpaint_radius_input,
    ],
    outputs=slider2,
)

demo = gr.TabbedInterface(
    [bg_removal_interface, design_editing_interface],
    ["Background Removal", "Design Editing"],
    title="Advanced Image Editor",
)

if __name__ == "__main__":
    demo.launch()