snicolau committed on
Commit
e8aba21
•
1 Parent(s): 5f15b01

Upload 2 files

Files changed (2)
  1. app.py +158 -22
  2. lama_inpaint.py +205 -0
app.py CHANGED
@@ -1,29 +1,165 @@
  import gradio as gr
- import requests
  from PIL import Image
- from io import BytesIO
-
- # Function to upload the image to the Hugging Face model
- def upload_image(image):
-     # Send the image to Hugging Face
-     response = requests.post(
-         "https://api.deepai.org/api/analyze-image",
-         files={"image": image},
-         headers={"api-key": "YOUR_HUGGING_FACE_API_KEY"}
      )

-     # Parse the response and get the result
-     result = response.json()
-     prediction = result.get("output", "Error")

-     return prediction

- # Gradio interface
- iface = gr.Interface(
-     fn=upload_image,
-     inputs=gr.Image(type="pil", label="Upload Image"),
-     outputs="text"
- )

- # Launch the Gradio app
- iface.launch()
+ import os
+ import sys
+ # sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
+ # os.chdir("../")
  import gradio as gr
+ import numpy as np
+ import cv2
+ from pathlib import Path
+ from matplotlib import pyplot as plt
+ import torch
+ import tempfile
+ import argparse
+ from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
+ #from utils import load_img_to_array, save_array_to_img, dilate_mask, \
+ #    show_mask, show_points
  from PIL import Image
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
+
+ import detectron2
+ from detectron2.utils.logger import setup_logger
+ setup_logger()
+
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer, ColorMode
+ from detectron2.data import MetadataCatalog
+ coco_metadata = MetadataCatalog.get("coco_2017_val")
+
+ # import PointRend project
+ from detectron2.projects import point_rend
+
+
+ title = "PeopleRemover"
+ description = """
+ In this space, you can remove as many people as you want from a picture.
+ ⚠️ This is just a demo version!
+ """
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--lama_config", type=str,
+         default="./third_party/lama/configs/prediction/default.yaml",
+         help="The path to the config file of the lama model. "
+              "Default: the config of big-lama",
      )
+     parser.add_argument(
+         "--lama_ckpt", type=str,
+         default="pretrained_models/big-lama",
+         help="The path to the lama checkpoint.",
+     )
+
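+ # Pipeline overview: PointRend (detectron2) segments every person in the image,
+ # the masks of the people to be removed are merged into one binary mask, and
+ # LaMa inpaints the masked region.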
+ def get_mask(img, num_people_keep, dilate_kernel_size):
+
+     cfg = get_cfg()
+     # Add PointRend-specific config
+     point_rend.add_pointrend_config(cfg)
+     # Load a config from file
+     cfg.merge_from_file("detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml")
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set detection threshold for this model
+
+     # Set when using CPU
+     cfg.MODEL.DEVICE = 'cpu'
+
+     # Use a model from the PointRend model zoo: https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models
+     cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco/28119989/model_final_ba17b9.pkl"
+     predictor = DefaultPredictor(cfg)
+     # DefaultPredictor expects a BGR image; Gradio supplies RGB, so flip the channels
+     outputs = predictor(img[:, :, ::-1])
+
+     # Select 'person' instances (class 0 in COCO)
+     people_instances = outputs["instances"][outputs["instances"].pred_classes == 0]
+
+     # Keep the first num_people_keep detections; everything else gets masked out
+     eliminate_instances = people_instances[num_people_keep:]
+
+     # Merge the instance masks into a single 2-D binary mask
+     # (the LaMa helpers assert a single-channel mask, so no GRAY2RGBA conversion)
+     full_mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
+     for instance_mask in eliminate_instances.pred_masks:
+         full_mask[instance_mask.to("cpu").numpy()] = 255

+     # Dilation, so the inpainted region covers the instance borders
+     kernel = np.ones((int(dilate_kernel_size), int(dilate_kernel_size)), np.uint8)
+     mask_dilation = cv2.dilate(full_mask, kernel, iterations=2)

+     return mask_dilation

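+ # Thin wrapper around the LaMa model built once at startup; returns a
+ # single-element list to match the single Gradio output component.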
+ def get_inpainted_img(img, mask):
+     lama_config = args.lama_config
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     out = []
+     img_inpainted = inpaint_img_with_builded_lama(
+         model['lama'], img, mask, lama_config, device=device)
+     out.append(img_inpainted)
+     return out
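+
+
+ # Top-level callback for the "Inpaint Image" button: build the mask, then inpaint.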
+ def remove_people(img, num_people_keep, dilate_kernel_size):
+
+     # gr.Number hands the callback a float; slicing the instances needs an int
+     mask = get_mask(img, int(num_people_keep), dilate_kernel_size)
+
+     out = get_inpainted_img(img, mask)
+
+     return out
+
+
+ # get args
+ parser = argparse.ArgumentParser()
+ setup_args(parser)
+ args = parser.parse_args(sys.argv[1:])
+ # build models
+ model = {}
+
+ # build the lama model once at startup so every request reuses the loaded weights
+ lama_config = args.lama_config
+ lama_ckpt = args.lama_ckpt
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
+
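+
+ # NOTE: .style() is Gradio 3.x API; Gradio 4 removed it in favor of
+ # constructor kwargs (e.g. gr.Image(height=200)).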
+ with gr.Blocks() as demo:
+     features = gr.State(None)
+
+     num_people_keep = gr.Number(label="Number of people to keep", value=0, minimum=0, maximum=100)
+     dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=30, step=1, value=5)
+
+     lama = gr.Button("Inpaint Image", variant="primary").style(full_width=True, size="sm")
+     clear_button_image = gr.Button(value="Reset", variant="secondary").style(full_width=True, size="sm")
+
+     img = gr.Image(label="Input Image").style(height="200px")
+
+     #mask = gr.Image(type="numpy", label="Segmentation Mask").style(height="200px")
+
+     img_out = gr.Image(
+         type="numpy", label="Image with People Removed").style(height="200px")
+
+     # remove_people (segmentation + inpainting) is the callback; it matches
+     # the three inputs wired below, which get_inpainted_img does not
+     lama.click(
+         remove_people,
+         [img, num_people_keep, dilate_kernel_size],
+         [img_out]
+     )
+
+     def reset(*args):
+         return [None for _ in args]
+
+     clear_button_image.click(
+         reset,
+         [img, features, img_out],
+         [img, features, img_out]
+     )

+ if __name__ == "__main__":
+     demo.launch()
lama_inpaint.py ADDED
@@ -0,0 +1,205 @@
+ import os
+ import sys
+ import numpy as np
+ import torch
+ import yaml
+ import glob
+ import argparse
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from pathlib import Path
+
+ # Keep the numerical libraries single-threaded
+ os.environ['OMP_NUM_THREADS'] = '1'
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
+ os.environ['MKL_NUM_THREADS'] = '1'
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
+
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "lama"))
+
+ from saicinpainting.evaluation.utils import move_to_device
+ from saicinpainting.training.trainers import load_checkpoint
+ from saicinpainting.evaluation.data import pad_tensor_to_modulo
+
+ from utils import load_img_to_array, save_array_to_img
+
+
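+ # Self-contained variant: loads the LaMa config and checkpoint on every call.
+ # For repeated use, build the model once with build_lama_model() and call
+ # inpaint_img_with_builded_lama() instead.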
+ @torch.no_grad()
+ def inpaint_img_with_lama(
+         img: np.ndarray,
+         mask: np.ndarray,
+         config_p: str,
+         ckpt_p: str,
+         mod=8,
+         device="cuda"
+ ):
+     assert len(mask.shape) == 2
+     if np.max(mask) == 1:
+         mask = mask * 255
+     img = torch.from_numpy(img).float().div(255.)
+     mask = torch.from_numpy(mask).float()
+     predict_config = OmegaConf.load(config_p)
+     predict_config.model.path = ckpt_p
+     # device = torch.device(predict_config.device)
+     device = torch.device(device)
+
+     train_config_path = os.path.join(
+         predict_config.model.path, 'config.yaml')
+
+     with open(train_config_path, 'r') as f:
+         train_config = OmegaConf.create(yaml.safe_load(f))
+
+     train_config.training_model.predict_only = True
+     train_config.visualizer.kind = 'noop'
+
+     checkpoint_path = os.path.join(
+         predict_config.model.path, 'models',
+         predict_config.model.checkpoint
+     )
+     model = load_checkpoint(
+         train_config, checkpoint_path, strict=False, map_location=device)
+     model.freeze()
+     if not predict_config.get('refine', False):
+         model.to(device)
+
+     batch = {}
+     batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
+     batch['mask'] = mask[None, None]
+     unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
+     # LaMa expects spatial dims divisible by mod, so pad here and unpad afterwards
+     batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
+     batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
+     batch = move_to_device(batch, device)
+     batch['mask'] = (batch['mask'] > 0) * 1
+
+     batch = model(batch)
+     cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+     cur_res = cur_res.detach().cpu().numpy()
+
+     if unpad_to_size is not None:
+         orig_height, orig_width = unpad_to_size
+         cur_res = cur_res[:orig_height, :orig_width]
+
+     cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
+     return cur_res
+
+
+ def build_lama_model(
+         config_p: str,
+         ckpt_p: str,
+         device="cuda"
+ ):
+     predict_config = OmegaConf.load(config_p)
+     predict_config.model.path = ckpt_p
+     # device = torch.device(predict_config.device)
+     device = torch.device(device)
+
+     train_config_path = os.path.join(
+         predict_config.model.path, 'config.yaml')
+
+     with open(train_config_path, 'r') as f:
+         train_config = OmegaConf.create(yaml.safe_load(f))
+
+     train_config.training_model.predict_only = True
+     train_config.visualizer.kind = 'noop'
+
+     checkpoint_path = os.path.join(
+         predict_config.model.path, 'models',
+         predict_config.model.checkpoint
+     )
+     model = load_checkpoint(
+         train_config, checkpoint_path, strict=False, map_location=device)
+     model.freeze()
+     if not predict_config.get('refine', False):
+         model.to(device)
+
+     return model
+
+
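+ # Same forward pass as inpaint_img_with_lama, but reuses a model already
+ # loaded via build_lama_model(), so the checkpoint is not re-read per call.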
+ @torch.no_grad()
+ def inpaint_img_with_builded_lama(
+         model,
+         img: np.ndarray,
+         mask: np.ndarray,
+         config_p: str,
+         mod=8,
+         device="cuda"
+ ):
+     assert len(mask.shape) == 2
+     if np.max(mask) == 1:
+         mask = mask * 255
+     img = torch.from_numpy(img).float().div(255.)
+     mask = torch.from_numpy(mask).float()
+     predict_config = OmegaConf.load(config_p)
+
+     batch = {}
+     batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
+     batch['mask'] = mask[None, None]
+     unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
+     batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
+     batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
+     batch = move_to_device(batch, device)
+     batch['mask'] = (batch['mask'] > 0) * 1
+
+     batch = model(batch)
+     cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+     cur_res = cur_res.detach().cpu().numpy()
+
+     if unpad_to_size is not None:
+         orig_height, orig_width = unpad_to_size
+         cur_res = cur_res[:orig_height, :orig_width]
+
+     cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
+     return cur_res
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input image",
+     )
+     parser.add_argument(
+         "--input_mask_glob", type=str, required=True,
+         help="Glob pattern for input masks",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--lama_config", type=str,
+         default="./third_party/lama/configs/prediction/default.yaml",
+         help="The path to the config file of the lama model. "
+              "Default: the config of big-lama",
+     )
+     parser.add_argument(
+         "--lama_ckpt", type=str, required=True,
+         help="The path to the lama checkpoint.",
+     )
+
+
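+ # CLI entry point: inpaints the input image once per mask matching the glob.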
+ if __name__ == "__main__":
+     """Example usage:
+     python lama_inpaint.py \
+         --input_img FA_demo/FA1_dog.png \
+         --input_mask_glob "results/FA1_dog/mask*.png" \
+         --output_dir results \
+         --lama_config lama/configs/prediction/default.yaml \
+         --lama_ckpt big-lama
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     img_stem = Path(args.input_img).stem
+     mask_ps = sorted(glob.glob(args.input_mask_glob))
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     img = load_img_to_array(args.input_img)
+     for mask_p in mask_ps:
+         mask = load_img_to_array(mask_p)
+         img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
+         img_inpainted = inpaint_img_with_lama(
+             img, mask, args.lama_config, args.lama_ckpt, device=device)
+         save_array_to_img(img_inpainted, img_inpainted_p)