# File size: 4,391 Bytes
# 8478c4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import asyncio
from typing import cast

import gradio as gr
import numpy as np
from gradio.components.image_editor import EditorValue
from gradio_imageslider import ImageSlider
from PIL import Image
from simple_lama_inpainting import SimpleLama


# Single LaMa inpainting model instance, created once at import time and reused
# by process_image() for every call.
# NOTE(review): construction presumably loads (and may download) model weights,
# which would make importing this module slow — confirm.
simple_lama = SimpleLama()

def HWC3(x):
    """Return the image array *x* as an ``(H, W, 3)`` RGB array.

    Accepted inputs are a 2-D grayscale array or a 3-D array with:

    * 1 channel (gray), replicated to three channels;
    * 2 channels (gray + alpha), composited over white, then replicated;
    * 3 channels (RGB), returned unchanged (the same object);
    * 4 channels (RGBA), composited over white.

    Alpha is composited over a white background, so fully transparent
    pixels become white rather than black.

    Args:
        x: Image array. Values are assumed to be in ``[0, 255]`` (``uint8``)
            -- NOTE(review): other dtypes are not validated.

    Returns:
        An ``(H, W, 3)`` array. Alpha-composited results are ``uint8``.

    Raises:
        ValueError: If *x* is not 2-D/3-D or has an unsupported channel count.
    """
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(f"HWC3: expected a 2-D or 3-D array, got ndim={x.ndim}")

    channels = x.shape[2]
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels in (2, 4):
        # The last channel is alpha; blend the colour channels over white.
        color = x[:, :, :-1].astype(np.float32)
        alpha = x[:, :, -1:].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        if channels == 2:
            # Gray + alpha: replicate the composited gray to RGB.
            y = np.concatenate([y, y, y], axis=2)
        return y
    raise ValueError(
        f"HWC3: unsupported channel count {channels}; expected 1, 2, 3 or 4"
    )

def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
    *,
    save_path: str | None = "inpainted.png",
) -> Image.Image | None:
    """Inpaint the region of *image* selected by *mask* using the LaMa model.

    Args:
        image: Source image, as a PIL image or a file path opened with PIL.
        mask: Inpainting mask (PIL image or file path). The UI builds it with
            255 where the user painted and 0 elsewhere. NOTE(review): it
            presumably must match *image*'s size -- confirm against
            simple_lama_inpainting.
        progress: Gradio progress tracker; only the initial 0 step is
            reported here.
        save_path: File to also write the result to. Defaults to the
            historical ``"inpainted.png"`` in the working directory. Pass
            ``None`` to skip the disk write; a fixed path is overwritten by
            every request, so concurrent users clobber each other's file.

    Returns:
        The inpainted image, or ``None`` when *image* or *mask* is missing.
    """
    progress(0, desc="Preparing inputs...")
    if image is None or mask is None:
        return None

    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)

    # np.array() on palette ("P") or bilevel ("1") images yields indices or
    # booleans rather than colours; normalise such modes to RGBA first.
    # HWC3 then composites any alpha over white.
    if image.mode not in ("RGB", "RGBA", "L"):
        image = image.convert("RGBA")
    # NOTE(review): the model is assumed to want a single-channel mask, so any
    # other mode (e.g. an RGB mask loaded from disk) is converted to grayscale.
    if mask.mode != "L":
        mask = mask.convert("L")

    image_rgb = HWC3(np.array(image))
    result = simple_lama(image_rgb, mask)
    if save_path is not None:
        result.save(save_path)
    return result

def resize_image(img: "Image.Image", min_side_length: int = 768) -> "Image.Image":
    """Downscale *img* so that its shorter side is *min_side_length* pixels.

    Images whose shorter side is already at most *min_side_length* are
    returned unchanged (the same object). The image is never upscaled.

    Previously an image with one small and one large side (e.g. 500x1000)
    skipped the early return and was enlarged to 768x1536.

    Args:
        img: Image to resize. Only ``width``, ``height`` and ``resize`` are used.
        min_side_length: Target length, in pixels, for the shorter side.

    Returns:
        *img* itself, or a resized copy with the aspect ratio preserved.
        The longer side is truncated to an int.
    """
    if min(img.width, img.height) <= min_side_length:
        return img

    aspect_ratio = img.width / img.height
    if img.width < img.height:
        # Portrait: width is the shorter side.
        new_height = int(min_side_length / aspect_ratio)
        return img.resize((min_side_length, new_height))

    # Landscape or square: height is the shorter side.
    new_width = int(min_side_length * aspect_ratio)
    return img.resize((new_width, min_side_length))


async def process(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Gradio handler: inpaint the area painted in the ImageMask editor.

    The mask is built from the first drawing layer: every non-transparent
    pixel is marked for removal. Image and mask are downscaled identically
    with resize_image() before inference.

    Args:
        image_and_mask: Editor value with ``"background"`` (the uploaded image)
            and ``"layers"`` (the drawing layers), or ``None``.
        progress: Gradio progress tracker, forwarded to process_image().

    Returns:
        ``(resized_input, inpainted)`` for the before/after slider. Returns
        ``None``, after showing an info toast, when input is missing or
        inpainting fails.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    image_np = image_and_mask["background"]
    # Guard against a missing background as well as an all-zero canvas.
    if image_np is None or not np.any(image_np):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, image_np)

    layers = image_and_mask["layers"]
    if not layers or layers[0] is None:
        gr.Info("Please mark the areas you want to remove")
        return None
    # Assumes the layer is RGBA (H, W, 4) -- TODO confirm editor image_mode.
    layer = cast(np.ndarray, layers[0])
    mask_np = np.where(layer[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None

    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    # Model inference blocks; run it in a worker thread so it does not stall
    # Gradio's event loop (and every other session) while this coroutine waits.
    output = await asyncio.to_thread(process_image, image, mask, progress)

    if output is None:
        gr.Info("Processing failed")
        return None
    # gr.Progress takes a completion fraction in [0, 1], not a percentage.
    progress(1.0, desc="Processing completed")
    return image, output


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: upload an image and paint the area to remove.
        with gr.Column():
            image_and_mask = gr.ImageMask(
                label="Upload Image and Draw Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )

        # Right column: before/after result slider and the Run button.
        with gr.Column():
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )

            # A ClearButton, so each click also empties the previous result.
            process_btn = gr.ClearButton(
                value="Run",
                variant="primary",
                size="lg",
                components=[image_slider],
            )

            # Disable the button, run inpainting, then re-enable the button.
            # The toggle steps have no inputs, so their callbacks must take no
            # arguments: Gradio calls them with zero positional arguments.
            process_btn.click(
                fn=lambda: gr.update(interactive=False, value="Processing..."),
                inputs=[],
                outputs=[process_btn],
                api_name=False,
            ).then(
                fn=process,
                inputs=[
                    image_and_mask,
                ],
                outputs=[image_slider],
                api_name=False,
            ).then(
                fn=lambda: gr.update(interactive=True, value="Run"),
                inputs=[],
                outputs=[process_btn],
                api_name=False,
            )


if __name__ == "__main__":
    # Launch the UI with debug, share and show_api all switched off.
    launch_options = {
        "debug": False,
        "share": False,
        "show_api": False,
    }
    demo.launch(**launch_options)