|
import gradio as gr |
|
from gradio.components.image_editor import EditorValue |
|
from gradio_imageslider import ImageSlider |
|
from PIL import Image |
|
from typing import cast |
|
import numpy as np |
|
from simple_lama_inpainting import SimpleLama |
|
|
|
|
|
# Single LaMa inpainting model instance, built once when this module is
# imported and reused by every call to process_image below.
simple_lama: SimpleLama = SimpleLama()
|
|
|
def HWC3(x):
    """Convert an image array to a 3-channel (H, W, 3) layout.

    Accepted inputs:
      * (H, W) or (H, W, 1) grayscale -> the single channel is replicated.
      * (H, W, 2) grayscale + alpha   -> composited over white, then replicated.
      * (H, W, 3)                     -> returned unchanged (the same object).
      * (H, W, 4) color + alpha       -> composited over a white background.

    Alpha compositing treats alpha as 8-bit (0-255) and returns uint8.

    Args:
        x: NumPy image array with 2 or 3 dimensions.

    Returns:
        An array of shape (H, W, 3).

    Raises:
        ValueError: if ``x`` is not 2-D/3-D or has a channel count other
            than 1, 2, 3 or 4 (previously such input silently returned None).
    """
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(f"HWC3 expects a 2-D or 3-D array, got shape {x.shape}")

    channels = x.shape[2]
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels in (2, 4):
        # The last channel is alpha; blend the remaining channel(s) over white.
        color = x[:, :, :-1].astype(np.float32)
        alpha = x[:, :, -1:].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        if channels == 2:
            # Grayscale result: replicate to three identical channels.
            y = np.concatenate([y, y, y], axis=2)
        return y
    raise ValueError(f"HWC3 expects 1, 2, 3 or 4 channels, got {channels}")
|
|
|
def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
    *,
    save_path: str | None = "inpainted.png",
) -> Image.Image | None:
    """Inpaint the masked region of ``image`` with the shared LaMa model.

    Args:
        image: Source image, as a PIL image or a path to an image file. It is
            converted to an (H, W, 3) array via HWC3 before inference.
        mask: Inpainting mask, as a PIL image or a path to an image file. It
            is passed to ``simple_lama`` as-is (after loading, if a path).
        progress: Gradio progress tracker; used to report the initial stage.
        save_path: File to write a copy of the result to. Defaults to
            ``"inpainted.png"`` (the original behavior); pass ``None`` to skip
            the write. A fixed path is shared by all calls, so concurrent
            requests overwrite each other's file.

    Returns:
        The result of ``simple_lama``, or ``None`` when either input is None.
    """
    progress(0, desc="Preparing inputs...")
    if image is None or mask is None:
        return None

    # Load path inputs eagerly inside a context manager so the underlying
    # file handle is closed (Image.open alone loads lazily and keeps it open).
    if isinstance(mask, str):
        with Image.open(mask) as opened:
            mask = opened.copy()
    if isinstance(image, str):
        with Image.open(image) as opened:
            image = opened.copy()

    # Normalise to three channels: grayscale is replicated and alpha is
    # composited onto white by HWC3.
    image_np = HWC3(np.array(image))

    result = simple_lama(image_np, mask)
    if save_path is not None:
        result.save(save_path)
    return result
|
|
|
def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
    """Downscale ``img`` so that its shorter side equals ``min_side_length``.

    The aspect ratio is preserved (the longer side is truncated to an int).
    Images whose shorter side is already ``min_side_length`` or smaller are
    returned unchanged, so this function never enlarges an image.

    Args:
        img: Image to resize.
        min_side_length: Target length, in pixels, for the shorter side.

    Returns:
        The resized image, or ``img`` itself when no resize is needed.
    """
    # Guard on the *shorter* side. If only the longer side exceeded the limit,
    # scaling the shorter side up to min_side_length would enlarge the image
    # (500x2000 -> 768x3072; a 10000x1 strip -> millions of pixels wide).
    if min(img.width, img.height) <= min_side_length:
        return img

    aspect_ratio = img.width / img.height
    if img.width < img.height:
        new_size = (min_side_length, int(min_side_length / aspect_ratio))
    else:
        new_size = (int(min_side_length * aspect_ratio), min_side_length)
    return img.resize(new_size)
|
|
|
|
|
async def process(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Gradio handler: inpaint the area painted on a ``gr.ImageMask`` value.

    Reads the background image and the first drawing layer from the editor
    value. Every pixel with non-zero alpha in that layer becomes a mask value
    of 255. Both images are downsized with ``resize_image`` and passed to
    ``process_image``.

    Args:
        image_and_mask: Editor value; its ``"background"`` and ``"layers"``
            entries are read as NumPy arrays (the layer is indexed as RGBA).
        progress: Gradio progress tracker.

    Returns:
        ``(input_image, inpainted_image)`` for the before/after slider, or
        ``None`` after showing an info message when input is missing or
        processing fails.
    """
    import asyncio

    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    image_np = image_and_mask["background"]
    # Treat a missing (None) background like an empty one instead of
    # letting NumPy fail on it.
    if image_np is None or not np.any(image_np):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, image_np)

    layers = image_and_mask["layers"]
    alpha_layer = layers[0] if layers else None
    if alpha_layer is None:
        gr.Info("Please mark the areas you want to remove")
        return None
    alpha_layer = cast(np.ndarray, alpha_layer)
    # Painted pixels have non-zero alpha; they become 255 in the mask.
    mask_np = np.where(alpha_layer[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not np.any(mask_np):
        gr.Info("Please mark the areas you want to remove")
        return None

    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    # Inference is synchronous and slow; run it in a worker thread so this
    # coroutine does not block the event loop while the model runs.
    output = await asyncio.to_thread(process_image, image, mask, progress)

    if output is None:
        gr.Info("Processing failed")
        return None
    # gr.Progress expects a fraction in [0, 1]; 1.0 marks completion.
    progress(1.0, desc="Processing completed")
    return image, output
|
|
|
|
|
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_and_mask = gr.ImageMask(
                label="Upload Image and Draw Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )

        with gr.Column():
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )

    # A ClearButton: clicking it also resets the components listed in
    # `components` (the result slider) in addition to the events chained below.
    process_btn = gr.ClearButton(
        value="Run",
        variant="primary",
        size="lg",
        components=[image_slider],
    )

    # The button-state callbacks are wired with inputs=[], so Gradio calls
    # them with no arguments; they must therefore take no parameters (a
    # `lambda _:` would raise TypeError and the button would never update).
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        inputs=[
            image_and_mask,
        ],
        outputs=[image_slider],
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the app with debug mode, share links and the API page disabled.
    launch_options = {
        "debug": False,
        "share": False,
        "show_api": False,
    }
    demo.launch(**launch_options)