"""Gradio demo that removes user-masked regions of an image via LaMa inpainting."""

from typing import cast

import gradio as gr
import numpy as np
from gradio.components.image_editor import EditorValue
from gradio_imageslider import ImageSlider
from PIL import Image
from simple_lama_inpainting import SimpleLama

# Created once at import time and reused by every call to process_image().
simple_lama = SimpleLama()


def HWC3(x):
    """Return a three-channel version of an image array.

    A 2-D array is treated as a single channel. One-channel input is
    replicated three times, three-channel input is returned unchanged, and
    four-channel input is alpha-composited over a white background (the
    fourth channel is divided by 255, i.e. it is read as 0-255 alpha).

    Raises:
        ValueError: If the channel count is not 1, 3 or 4.
    """
    if x.ndim == 2:
        x = x[:, :, None]
    _, _, channels = x.shape
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        blended = color * alpha + 255.0 * (1.0 - alpha)
        return blended.clip(0, 255).astype(np.uint8)
    # Previously this fell through and returned None, which only surfaced
    # later as an obscure failure inside the inpainting call.
    raise ValueError(f"Unsupported number of channels: {channels}")


def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
) -> Image.Image | None:
    """Inpaint the masked area of ``image`` and return the result.

    Args:
        image: Source picture, either a PIL image or a path to open.
        mask: Mask picture, either a PIL image or a path to open.
        progress: Gradio progress tracker, updated once before work starts.

    Returns:
        The inpainted image, or ``None`` when either input is missing.

    The result is also written to ``inpainted.png`` in the working directory.
    """
    progress(0, desc="Preparing inputs...")
    if image is None or mask is None:
        return None

    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)

    # The model is fed a three-channel array; the mask is passed through as is.
    image = HWC3(np.array(image))
    result = simple_lama(image, mask)
    result.save("inpainted.png")
    return result


def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
    """Rescale ``img`` so that its shorter side equals ``min_side_length``.

    Images whose width and height are both at most ``min_side_length`` are
    returned untouched. Note that an image with only one side above the
    threshold is scaled up, because the shorter side is the one pinned to
    ``min_side_length``.
    """
    if img.width <= min_side_length and img.height <= min_side_length:
        return img

    aspect_ratio = img.width / img.height
    if img.width < img.height:
        # Portrait: width is the shorter side.
        new_height = int(min_side_length / aspect_ratio)
        return img.resize((min_side_length, new_height))

    # Landscape or square: height is the shorter (or equal) side.
    new_width = int(min_side_length * aspect_ratio)
    return img.resize((new_width, min_side_length))


async def process(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Validate the editor value, build the mask and run inpainting.

    Args:
        image_and_mask: Value of the mask editor; its ``"background"`` and
            first ``"layers"`` entry are read as arrays.
        progress: Gradio progress tracker forwarded to ``process_image``.

    Returns:
        ``(resized input image, inpainted image)``, or ``None`` after showing
        an info message when the input is unusable or processing failed.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    image_np = cast(np.ndarray, image_and_mask["background"])
    # An absent or all-zero background means nothing has been uploaded.
    if image_np is None or not np.any(image_np):
        gr.Info("Please upload an image")
        return None

    layers = image_and_mask["layers"]
    if not layers:
        # Guard against an empty layer list instead of raising IndexError.
        gr.Info("Please mark the areas you want to remove")
        return None

    alpha_channel = cast(np.ndarray, layers[0])
    # Any pixel with non-zero alpha in the drawing layer was painted by the
    # user, so it becomes white (255) in the binary mask.
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not np.any(mask_np):
        gr.Info("Please mark the areas you want to remove")
        return None

    # Image and mask go through the same resize so their sizes stay matched.
    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    # NOTE(review): this call is synchronous, so the inpainting runs inline
    # inside this coroutine — confirm that is acceptable for the expected load.
    output = process_image(image, mask, progress)
    if output is None:
        gr.Info("Processing failed")
        return None

    # Float progress values are fractions in [0, 1]; 1.0 means done.
    progress(1.0, desc="Processing completed")
    return image, output


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_and_mask = gr.ImageMask(
                label="Upload Image and Draw Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
        with gr.Column():
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )

    # A ClearButton so that each run first clears the previous result.
    process_btn = gr.ClearButton(
        value="Run",
        variant="primary",
        size="lg",
        components=[image_slider],
    )

    # Disable the button while processing, then restore it afterwards. The
    # callbacks take no parameters because no inputs are wired to them.
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        inputs=[
            image_and_mask,
        ],
        outputs=[image_slider],
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )


if __name__ == "__main__":
    demo.launch(
        debug=False,
        share=False,
        show_api=False,
    )