Kuldip2411 commited on
Commit
8478c4c
·
verified ·
1 Parent(s): 344b784

Upload 13 files

Browse files
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio.components.image_editor import EditorValue
3
+ from gradio_imageslider import ImageSlider
4
+ from PIL import Image
5
+ from typing import cast
6
+ import numpy as np
7
+ from simple_lama_inpainting import SimpleLama
8
+
9
+
10
+ simple_lama = SimpleLama()
11
+
12
+ def HWC3(x):
13
+ if x.ndim == 2:
14
+ x = x[:, :, None]
15
+ H, W, C = x.shape
16
+ if C == 3:
17
+ return x
18
+ if C == 1:
19
+ return np.concatenate([x, x, x], axis=2)
20
+ if C == 4:
21
+ color = x[:, :, 0:3].astype(np.float32)
22
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
23
+ y = color * alpha + 255.0 * (1.0 - alpha)
24
+ y = y.clip(0, 255).astype(np.uint8)
25
+ return y
26
+
27
+ def process_image(
28
+ image: Image.Image | str | None,
29
+ mask: Image.Image | str | None,
30
+ progress: gr.Progress = gr.Progress(),
31
+ ) -> Image.Image | None:
32
+ progress(0, desc="Preparing inputs...")
33
+ if image is None or mask is None:
34
+ return None
35
+
36
+ if isinstance(mask, str):
37
+ mask = Image.open(mask)
38
+ if isinstance(image, str):
39
+ image = Image.open(image)
40
+ image = np.array(image)
41
+ image = HWC3(image)
42
+
43
+ result = simple_lama(image, mask)
44
+ result.save("inpainted.png")
45
+ return result
46
+
47
+ def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
48
+ if img.width <= min_side_length and img.height <= min_side_length:
49
+ return img
50
+
51
+ aspect_ratio = img.width / img.height
52
+ if img.width < img.height:
53
+ new_height = int(min_side_length / aspect_ratio)
54
+ return img.resize((min_side_length, new_height))
55
+
56
+ new_width = int(min_side_length * aspect_ratio)
57
+ return img.resize((new_width, min_side_length))
58
+
59
+
60
+ async def process(
61
+ image_and_mask: EditorValue | None,
62
+ progress: gr.Progress = gr.Progress(),
63
+ ) -> tuple[Image.Image, Image.Image] | None:
64
+ if not image_and_mask:
65
+ gr.Info("Please upload an image and draw a mask")
66
+ return None
67
+
68
+
69
+ image_np = image_and_mask["background"]
70
+ image_np = cast(np.ndarray, image_np)
71
+
72
+ if np.sum(image_np) == 0:
73
+ gr.Info("Please upload an image")
74
+ return None
75
+
76
+ alpha_channel = image_and_mask["layers"][0]
77
+ alpha_channel = cast(np.ndarray, alpha_channel)
78
+ mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
79
+
80
+ if np.sum(mask_np) == 0:
81
+ gr.Info("Please mark the areas you want to remove")
82
+ return None
83
+
84
+ mask = Image.fromarray(mask_np)
85
+ mask = resize_image(mask)
86
+
87
+ image = Image.fromarray(image_np)
88
+ image = resize_image(image)
89
+
90
+ output = process_image(
91
+ image,
92
+ mask,
93
+ progress,
94
+ )
95
+
96
+ if output is None:
97
+ gr.Info("Processing failed")
98
+ return None
99
+ progress(100, desc="Processing completed")
100
+ return image, output
101
+
102
+
103
+ with gr.Blocks() as demo:
104
+ with gr.Row():
105
+ with gr.Column():
106
+ image_and_mask = gr.ImageMask(
107
+ label="Upload Image and Draw Mask",
108
+ layers=False,
109
+ show_fullscreen_button=False,
110
+ sources=["upload"],
111
+ show_download_button=False,
112
+ interactive=True,
113
+ height="full",
114
+ width="full",
115
+ brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
116
+ transforms=[],
117
+ )
118
+
119
+
120
+ with gr.Column():
121
+ image_slider = ImageSlider(
122
+ label="Result",
123
+ interactive=False,
124
+ )
125
+
126
+ process_btn = gr.ClearButton(
127
+ value="Run",
128
+ variant="primary",
129
+ size="lg",
130
+ components=[image_slider],
131
+ )
132
+
133
+ process_btn.click(
134
+ fn=lambda _: gr.update(interactive=False, value="Processing..."),
135
+ inputs=[],
136
+ outputs=[process_btn],
137
+ api_name=False,
138
+ ).then(
139
+ fn=process,
140
+ inputs=[
141
+ image_and_mask,
142
+ ],
143
+ outputs=[image_slider],
144
+ api_name=False,
145
+ ).then(
146
+ fn=lambda _: gr.update(interactive=True, value="Run"),
147
+ inputs=[],
148
+ outputs=[process_btn],
149
+ api_name=False,
150
+ )
151
+
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch(
155
+ debug=False,
156
+ share=False,
157
+ show_api=False,
158
+ )
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==4.44.0
2
+ gradio_imageslider==0.0.20
3
+ torch==2.4.1
4
+ opencv-python==4.10.0
simple_lama_inpainting/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from simple_lama_inpainting.models.model import SimpleLama
2
+
3
+ __all__ = ['SimpleLama',]
simple_lama_inpainting/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (249 Bytes). View file
 
simple_lama_inpainting/cli.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from simple_lama_inpainting.models.model import SimpleLama
2
+ from PIL import Image
3
+ from pathlib import Path
4
+ import fire
5
+
6
+
7
+ def main(image_path: str, mask_path: str, out_path: str | None = None):
8
+ """Apply lama inpainting using given image and mask.
9
+
10
+ Args:
11
+ img_path (str): Path to input image (RGB)
12
+ mask_path (str): Path to input mask (Binary 1-CH Image.
13
+ Pixels with value 255 will be inpainted)
14
+ out_path (str, optional): Optional output imaga path.
15
+ If not provided it will be saved to the same
16
+ path as input image.
17
+ Defaults to None.
18
+ """
19
+ image_path = Path(image_path)
20
+ mask_path = Path(mask_path)
21
+
22
+ img = Image.open(image_path).convert("RGB")
23
+ mask = Image.open(mask_path).convert("L")
24
+
25
+ assert img.mode == "RGB" and mask.mode == "L"
26
+
27
+ lama = SimpleLama()
28
+ result = lama(img, mask)
29
+ if out_path is None:
30
+ out_path = image_path.with_stem(image_path.stem + "_out")
31
+
32
+ Path.mkdir(Path(out_path).parent, exist_ok=True, parents=True)
33
+ result.save(out_path)
34
+ print(f"Inpainted image is saved to {out_path}")
35
+
36
+
37
+ def lama_cli():
38
+ fire.Fire(main)
39
+
40
+
41
+ if __name__ == "__main__":
42
+ fire.Fire(main)
simple_lama_inpainting/models/__init__.py ADDED
File without changes
simple_lama_inpainting/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes). View file
 
simple_lama_inpainting/models/__pycache__/model.cpython-310.pyc ADDED
Binary file (1.67 kB). View file
 
simple_lama_inpainting/models/model.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from simple_lama_inpainting.utils.util import prepare_img_and_mask, download_model
6
+
7
+ LAMA_MODEL_URL = os.environ.get(
8
+ "LAMA_MODEL_URL",
9
+ "https://github.com/enesmsahin/simple-lama-inpainting/releases/download/v0.1.0/big-lama.pt", # noqa
10
+ )
11
+
12
+
13
+ class SimpleLama:
14
+ def __init__(
15
+ self,
16
+ device: torch.device = torch.device(
17
+ "cuda" if torch.cuda.is_available() else "cpu"
18
+ ),
19
+ ) -> None:
20
+ if os.environ.get("LAMA_MODEL"):
21
+ model_path = os.environ.get("LAMA_MODEL")
22
+ if not os.path.exists(model_path):
23
+ raise FileNotFoundError(
24
+ f"lama torchscript model not found: {model_path}"
25
+ )
26
+ else:
27
+ model_path = download_model(LAMA_MODEL_URL)
28
+
29
+ self.model = torch.jit.load(model_path, map_location=device)
30
+ self.model.eval()
31
+ self.model.to(device)
32
+ self.device = device
33
+
34
+ def __call__(self, image: Image.Image | np.ndarray, mask: Image.Image | np.ndarray):
35
+ image, mask = prepare_img_and_mask(image, mask, self.device)
36
+
37
+ with torch.inference_mode():
38
+ inpainted = self.model(image, mask)
39
+
40
+ cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
41
+ cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)
42
+
43
+ cur_res = Image.fromarray(cur_res)
44
+ return cur_res
simple_lama_inpainting/utils/__init__.py ADDED
File without changes
simple_lama_inpainting/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes). View file
 
simple_lama_inpainting/utils/__pycache__/util.cpython-310.pyc ADDED
Binary file (2.73 kB). View file
 
simple_lama_inpainting/utils/util.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+ from torch.hub import download_url_to_file, get_dir
8
+ from urllib.parse import urlparse
9
+
10
+
11
+ # Source https://github.com/advimman/lama
12
+ def get_image(image):
13
+ if isinstance(image, Image.Image):
14
+ img = np.array(image)
15
+ elif isinstance(image, np.ndarray):
16
+ img = image.copy()
17
+ else:
18
+ raise Exception("Input image should be either PIL Image or numpy array!")
19
+
20
+ if img.ndim == 3:
21
+ img = np.transpose(img, (2, 0, 1)) # chw
22
+ elif img.ndim == 2:
23
+ img = img[np.newaxis, ...]
24
+
25
+ assert img.ndim == 3
26
+
27
+ img = img.astype(np.float32) / 255
28
+ return img
29
+
30
+
31
+ def ceil_modulo(x, mod):
32
+ if x % mod == 0:
33
+ return x
34
+ return (x // mod + 1) * mod
35
+
36
+
37
+ def scale_image(img, factor, interpolation=cv2.INTER_AREA):
38
+ if img.shape[0] == 1:
39
+ img = img[0]
40
+ else:
41
+ img = np.transpose(img, (1, 2, 0))
42
+
43
+ img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
44
+
45
+ if img.ndim == 2:
46
+ img = img[None, ...]
47
+ else:
48
+ img = np.transpose(img, (2, 0, 1))
49
+ return img
50
+
51
+
52
+ def pad_img_to_modulo(img, mod):
53
+ channels, height, width = img.shape
54
+ out_height = ceil_modulo(height, mod)
55
+ out_width = ceil_modulo(width, mod)
56
+ return np.pad(
57
+ img,
58
+ ((0, 0), (0, out_height - height), (0, out_width - width)),
59
+ mode="symmetric",
60
+ )
61
+
62
+
63
+ def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
64
+ out_image = get_image(image)
65
+ out_mask = get_image(mask)
66
+
67
+ if scale_factor is not None:
68
+ out_image = scale_image(out_image, scale_factor)
69
+ out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
70
+
71
+ if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
72
+ out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
73
+ out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
74
+
75
+ out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
76
+ out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
77
+
78
+ out_mask = (out_mask > 0) * 1
79
+
80
+ return out_image, out_mask
81
+
82
+
83
+ # Source: https://github.com/Sanster/lama-cleaner/blob/6cfc7c30f1d6428c02e21d153048381923498cac/lama_cleaner/helper.py # noqa
84
+ def get_cache_path_by_url(url):
85
+ parts = urlparse(url)
86
+ hub_dir = get_dir()
87
+ model_dir = os.path.join(hub_dir, "checkpoints")
88
+ if not os.path.isdir(model_dir):
89
+ os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
90
+ filename = os.path.basename(parts.path)
91
+ cached_file = os.path.join(model_dir, filename)
92
+ return cached_file
93
+
94
+
95
+ def download_model(url):
96
+ cached_file = get_cache_path_by_url(url)
97
+ if not os.path.exists(cached_file):
98
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
99
+ hash_prefix = None
100
+ download_url_to_file(url, cached_file, hash_prefix, progress=True)
101
+ return cached_file