Upload 13 files
Browse files- app.py +158 -0
- requirements.txt +4 -0
- simple_lama_inpainting/__init__.py +3 -0
- simple_lama_inpainting/__pycache__/__init__.cpython-310.pyc +0 -0
- simple_lama_inpainting/cli.py +42 -0
- simple_lama_inpainting/models/__init__.py +0 -0
- simple_lama_inpainting/models/__pycache__/__init__.cpython-310.pyc +0 -0
- simple_lama_inpainting/models/__pycache__/model.cpython-310.pyc +0 -0
- simple_lama_inpainting/models/model.py +44 -0
- simple_lama_inpainting/utils/__init__.py +0 -0
- simple_lama_inpainting/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- simple_lama_inpainting/utils/__pycache__/util.cpython-310.pyc +0 -0
- simple_lama_inpainting/utils/util.py +101 -0
app.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from gradio.components.image_editor import EditorValue
|
3 |
+
from gradio_imageslider import ImageSlider
|
4 |
+
from PIL import Image
|
5 |
+
from typing import cast
|
6 |
+
import numpy as np
|
7 |
+
from simple_lama_inpainting import SimpleLama
|
8 |
+
|
9 |
+
|
10 |
+
simple_lama = SimpleLama()
|
11 |
+
|
12 |
+
def HWC3(x):
|
13 |
+
if x.ndim == 2:
|
14 |
+
x = x[:, :, None]
|
15 |
+
H, W, C = x.shape
|
16 |
+
if C == 3:
|
17 |
+
return x
|
18 |
+
if C == 1:
|
19 |
+
return np.concatenate([x, x, x], axis=2)
|
20 |
+
if C == 4:
|
21 |
+
color = x[:, :, 0:3].astype(np.float32)
|
22 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
23 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
24 |
+
y = y.clip(0, 255).astype(np.uint8)
|
25 |
+
return y
|
26 |
+
|
27 |
+
def process_image(
|
28 |
+
image: Image.Image | str | None,
|
29 |
+
mask: Image.Image | str | None,
|
30 |
+
progress: gr.Progress = gr.Progress(),
|
31 |
+
) -> Image.Image | None:
|
32 |
+
progress(0, desc="Preparing inputs...")
|
33 |
+
if image is None or mask is None:
|
34 |
+
return None
|
35 |
+
|
36 |
+
if isinstance(mask, str):
|
37 |
+
mask = Image.open(mask)
|
38 |
+
if isinstance(image, str):
|
39 |
+
image = Image.open(image)
|
40 |
+
image = np.array(image)
|
41 |
+
image = HWC3(image)
|
42 |
+
|
43 |
+
result = simple_lama(image, mask)
|
44 |
+
result.save("inpainted.png")
|
45 |
+
return result
|
46 |
+
|
47 |
+
def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
|
48 |
+
if img.width <= min_side_length and img.height <= min_side_length:
|
49 |
+
return img
|
50 |
+
|
51 |
+
aspect_ratio = img.width / img.height
|
52 |
+
if img.width < img.height:
|
53 |
+
new_height = int(min_side_length / aspect_ratio)
|
54 |
+
return img.resize((min_side_length, new_height))
|
55 |
+
|
56 |
+
new_width = int(min_side_length * aspect_ratio)
|
57 |
+
return img.resize((new_width, min_side_length))
|
58 |
+
|
59 |
+
|
60 |
+
async def process(
|
61 |
+
image_and_mask: EditorValue | None,
|
62 |
+
progress: gr.Progress = gr.Progress(),
|
63 |
+
) -> tuple[Image.Image, Image.Image] | None:
|
64 |
+
if not image_and_mask:
|
65 |
+
gr.Info("Please upload an image and draw a mask")
|
66 |
+
return None
|
67 |
+
|
68 |
+
|
69 |
+
image_np = image_and_mask["background"]
|
70 |
+
image_np = cast(np.ndarray, image_np)
|
71 |
+
|
72 |
+
if np.sum(image_np) == 0:
|
73 |
+
gr.Info("Please upload an image")
|
74 |
+
return None
|
75 |
+
|
76 |
+
alpha_channel = image_and_mask["layers"][0]
|
77 |
+
alpha_channel = cast(np.ndarray, alpha_channel)
|
78 |
+
mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
|
79 |
+
|
80 |
+
if np.sum(mask_np) == 0:
|
81 |
+
gr.Info("Please mark the areas you want to remove")
|
82 |
+
return None
|
83 |
+
|
84 |
+
mask = Image.fromarray(mask_np)
|
85 |
+
mask = resize_image(mask)
|
86 |
+
|
87 |
+
image = Image.fromarray(image_np)
|
88 |
+
image = resize_image(image)
|
89 |
+
|
90 |
+
output = process_image(
|
91 |
+
image,
|
92 |
+
mask,
|
93 |
+
progress,
|
94 |
+
)
|
95 |
+
|
96 |
+
if output is None:
|
97 |
+
gr.Info("Processing failed")
|
98 |
+
return None
|
99 |
+
progress(100, desc="Processing completed")
|
100 |
+
return image, output
|
101 |
+
|
102 |
+
|
103 |
+
with gr.Blocks() as demo:
|
104 |
+
with gr.Row():
|
105 |
+
with gr.Column():
|
106 |
+
image_and_mask = gr.ImageMask(
|
107 |
+
label="Upload Image and Draw Mask",
|
108 |
+
layers=False,
|
109 |
+
show_fullscreen_button=False,
|
110 |
+
sources=["upload"],
|
111 |
+
show_download_button=False,
|
112 |
+
interactive=True,
|
113 |
+
height="full",
|
114 |
+
width="full",
|
115 |
+
brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
|
116 |
+
transforms=[],
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
with gr.Column():
|
121 |
+
image_slider = ImageSlider(
|
122 |
+
label="Result",
|
123 |
+
interactive=False,
|
124 |
+
)
|
125 |
+
|
126 |
+
process_btn = gr.ClearButton(
|
127 |
+
value="Run",
|
128 |
+
variant="primary",
|
129 |
+
size="lg",
|
130 |
+
components=[image_slider],
|
131 |
+
)
|
132 |
+
|
133 |
+
process_btn.click(
|
134 |
+
fn=lambda _: gr.update(interactive=False, value="Processing..."),
|
135 |
+
inputs=[],
|
136 |
+
outputs=[process_btn],
|
137 |
+
api_name=False,
|
138 |
+
).then(
|
139 |
+
fn=process,
|
140 |
+
inputs=[
|
141 |
+
image_and_mask,
|
142 |
+
],
|
143 |
+
outputs=[image_slider],
|
144 |
+
api_name=False,
|
145 |
+
).then(
|
146 |
+
fn=lambda _: gr.update(interactive=True, value="Run"),
|
147 |
+
inputs=[],
|
148 |
+
outputs=[process_btn],
|
149 |
+
api_name=False,
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
demo.launch(
|
155 |
+
debug=False,
|
156 |
+
share=False,
|
157 |
+
show_api=False,
|
158 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==4.44.0
|
2 |
+
gradio_imageslider==0.0.20
|
3 |
+
torch==2.4.1
|
4 |
+
opencv-python==4.10.0
|
simple_lama_inpainting/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from simple_lama_inpainting.models.model import SimpleLama
|
2 |
+
|
3 |
+
__all__ = ['SimpleLama',]
|
simple_lama_inpainting/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (249 Bytes). View file
|
|
simple_lama_inpainting/cli.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from simple_lama_inpainting.models.model import SimpleLama
|
2 |
+
from PIL import Image
|
3 |
+
from pathlib import Path
|
4 |
+
import fire
|
5 |
+
|
6 |
+
|
7 |
+
def main(image_path: str, mask_path: str, out_path: str | None = None):
|
8 |
+
"""Apply lama inpainting using given image and mask.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
img_path (str): Path to input image (RGB)
|
12 |
+
mask_path (str): Path to input mask (Binary 1-CH Image.
|
13 |
+
Pixels with value 255 will be inpainted)
|
14 |
+
out_path (str, optional): Optional output imaga path.
|
15 |
+
If not provided it will be saved to the same
|
16 |
+
path as input image.
|
17 |
+
Defaults to None.
|
18 |
+
"""
|
19 |
+
image_path = Path(image_path)
|
20 |
+
mask_path = Path(mask_path)
|
21 |
+
|
22 |
+
img = Image.open(image_path).convert("RGB")
|
23 |
+
mask = Image.open(mask_path).convert("L")
|
24 |
+
|
25 |
+
assert img.mode == "RGB" and mask.mode == "L"
|
26 |
+
|
27 |
+
lama = SimpleLama()
|
28 |
+
result = lama(img, mask)
|
29 |
+
if out_path is None:
|
30 |
+
out_path = image_path.with_stem(image_path.stem + "_out")
|
31 |
+
|
32 |
+
Path.mkdir(Path(out_path).parent, exist_ok=True, parents=True)
|
33 |
+
result.save(out_path)
|
34 |
+
print(f"Inpainted image is saved to {out_path}")
|
35 |
+
|
36 |
+
|
37 |
+
def lama_cli():
|
38 |
+
fire.Fire(main)
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
fire.Fire(main)
|
simple_lama_inpainting/models/__init__.py
ADDED
File without changes
|
simple_lama_inpainting/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (164 Bytes). View file
|
|
simple_lama_inpainting/models/__pycache__/model.cpython-310.pyc
ADDED
Binary file (1.67 kB). View file
|
|
simple_lama_inpainting/models/model.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from simple_lama_inpainting.utils.util import prepare_img_and_mask, download_model
|
6 |
+
|
7 |
+
LAMA_MODEL_URL = os.environ.get(
|
8 |
+
"LAMA_MODEL_URL",
|
9 |
+
"https://github.com/enesmsahin/simple-lama-inpainting/releases/download/v0.1.0/big-lama.pt", # noqa
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
class SimpleLama:
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
device: torch.device = torch.device(
|
17 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
),
|
19 |
+
) -> None:
|
20 |
+
if os.environ.get("LAMA_MODEL"):
|
21 |
+
model_path = os.environ.get("LAMA_MODEL")
|
22 |
+
if not os.path.exists(model_path):
|
23 |
+
raise FileNotFoundError(
|
24 |
+
f"lama torchscript model not found: {model_path}"
|
25 |
+
)
|
26 |
+
else:
|
27 |
+
model_path = download_model(LAMA_MODEL_URL)
|
28 |
+
|
29 |
+
self.model = torch.jit.load(model_path, map_location=device)
|
30 |
+
self.model.eval()
|
31 |
+
self.model.to(device)
|
32 |
+
self.device = device
|
33 |
+
|
34 |
+
def __call__(self, image: Image.Image | np.ndarray, mask: Image.Image | np.ndarray):
|
35 |
+
image, mask = prepare_img_and_mask(image, mask, self.device)
|
36 |
+
|
37 |
+
with torch.inference_mode():
|
38 |
+
inpainted = self.model(image, mask)
|
39 |
+
|
40 |
+
cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
|
41 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)
|
42 |
+
|
43 |
+
cur_res = Image.fromarray(cur_res)
|
44 |
+
return cur_res
|
simple_lama_inpainting/utils/__init__.py
ADDED
File without changes
|
simple_lama_inpainting/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (163 Bytes). View file
|
|
simple_lama_inpainting/utils/__pycache__/util.cpython-310.pyc
ADDED
Binary file (2.73 kB). View file
|
|
simple_lama_inpainting/utils/util.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
from torch.hub import download_url_to_file, get_dir
|
8 |
+
from urllib.parse import urlparse
|
9 |
+
|
10 |
+
|
11 |
+
# Source https://github.com/advimman/lama
|
12 |
+
def get_image(image):
|
13 |
+
if isinstance(image, Image.Image):
|
14 |
+
img = np.array(image)
|
15 |
+
elif isinstance(image, np.ndarray):
|
16 |
+
img = image.copy()
|
17 |
+
else:
|
18 |
+
raise Exception("Input image should be either PIL Image or numpy array!")
|
19 |
+
|
20 |
+
if img.ndim == 3:
|
21 |
+
img = np.transpose(img, (2, 0, 1)) # chw
|
22 |
+
elif img.ndim == 2:
|
23 |
+
img = img[np.newaxis, ...]
|
24 |
+
|
25 |
+
assert img.ndim == 3
|
26 |
+
|
27 |
+
img = img.astype(np.float32) / 255
|
28 |
+
return img
|
29 |
+
|
30 |
+
|
31 |
+
def ceil_modulo(x, mod):
|
32 |
+
if x % mod == 0:
|
33 |
+
return x
|
34 |
+
return (x // mod + 1) * mod
|
35 |
+
|
36 |
+
|
37 |
+
def scale_image(img, factor, interpolation=cv2.INTER_AREA):
|
38 |
+
if img.shape[0] == 1:
|
39 |
+
img = img[0]
|
40 |
+
else:
|
41 |
+
img = np.transpose(img, (1, 2, 0))
|
42 |
+
|
43 |
+
img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
|
44 |
+
|
45 |
+
if img.ndim == 2:
|
46 |
+
img = img[None, ...]
|
47 |
+
else:
|
48 |
+
img = np.transpose(img, (2, 0, 1))
|
49 |
+
return img
|
50 |
+
|
51 |
+
|
52 |
+
def pad_img_to_modulo(img, mod):
|
53 |
+
channels, height, width = img.shape
|
54 |
+
out_height = ceil_modulo(height, mod)
|
55 |
+
out_width = ceil_modulo(width, mod)
|
56 |
+
return np.pad(
|
57 |
+
img,
|
58 |
+
((0, 0), (0, out_height - height), (0, out_width - width)),
|
59 |
+
mode="symmetric",
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
|
64 |
+
out_image = get_image(image)
|
65 |
+
out_mask = get_image(mask)
|
66 |
+
|
67 |
+
if scale_factor is not None:
|
68 |
+
out_image = scale_image(out_image, scale_factor)
|
69 |
+
out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
|
70 |
+
|
71 |
+
if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
|
72 |
+
out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
|
73 |
+
out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
|
74 |
+
|
75 |
+
out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
|
76 |
+
out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
|
77 |
+
|
78 |
+
out_mask = (out_mask > 0) * 1
|
79 |
+
|
80 |
+
return out_image, out_mask
|
81 |
+
|
82 |
+
|
83 |
+
# Source: https://github.com/Sanster/lama-cleaner/blob/6cfc7c30f1d6428c02e21d153048381923498cac/lama_cleaner/helper.py # noqa
|
84 |
+
def get_cache_path_by_url(url):
|
85 |
+
parts = urlparse(url)
|
86 |
+
hub_dir = get_dir()
|
87 |
+
model_dir = os.path.join(hub_dir, "checkpoints")
|
88 |
+
if not os.path.isdir(model_dir):
|
89 |
+
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
|
90 |
+
filename = os.path.basename(parts.path)
|
91 |
+
cached_file = os.path.join(model_dir, filename)
|
92 |
+
return cached_file
|
93 |
+
|
94 |
+
|
95 |
+
def download_model(url):
|
96 |
+
cached_file = get_cache_path_by_url(url)
|
97 |
+
if not os.path.exists(cached_file):
|
98 |
+
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
99 |
+
hash_prefix = None
|
100 |
+
download_url_to_file(url, cached_file, hash_prefix, progress=True)
|
101 |
+
return cached_file
|