File size: 6,000 Bytes
e8aba21
 
 
5c91952
e8aba21
 
 
 
 
 
5f15b01
4d41a79
e8aba21
 
 
 
 
 
 
98a5805
b3f612f
e8aba21
 
 
 
 
 
98a5805
 
 
 
 
e8aba21
 
 
9d8045e
e8aba21
 
4d41a79
e8aba21
 
 
 
 
 
 
 
 
 
 
5f15b01
e8aba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0444efc
 
e8aba21
 
 
 
 
a9d7705
5c91952
e8aba21
 
 
5f15b01
e8aba21
5f15b01
e8aba21
 
 
7e0f853
e8aba21
 
7e0f853
 
e8aba21
 
 
c9943d2
e8aba21
c9943d2
 
e8aba21
c9943d2
e8aba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d41a79
c9943d2
e8aba21
 
b8e3c24
 
d12d5fa
b8e3c24
 
 
 
1ba3f94
 
 
b8e3c24
d12d5fa
b8e3c24
51a2f5e
e8aba21
 
5a38b1d
e8aba21
 
 
 
 
 
 
 
 
 
 
 
bffbfdb
4d41a79
53ffb8e
 
 
8474f38
e1a8ecf
53ffb8e
4d41a79
50d5c37
4d41a79
 
5f15b01
e8aba21
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import sys
# os.chdir("../")
import gradio as gr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import tempfile
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
from PIL import Image
#sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
import argparse

import os
import matplotlib.pyplot as plt
from pylab import imshow, imsave


import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import cv2
import torch

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
coco_metadata = MetadataCatalog.get("coco_2017_val")

# import PointRend project
from detectron2_repo.projects.PointRend import point_rend


# Page heading and introductory blurb; both are rendered with gr.Markdown
# at the top of the Gradio page built further down in this file.
title = "# PeopleRemover"
description = """
In this space, you can remove the amount of people you want from a picture.
⚠️ This is just a demo version!
"""

def setup_args(parser):
    """Register the LaMa command-line options on ``parser``.

    Two string options are added, each with a default path:

    * ``--lama_config``: path to the LaMa prediction config file.
    * ``--lama_ckpt``: path to the LaMa checkpoint.

    The parser is modified in place; nothing is returned.
    """
    # (flag, default value, help text) for every option this app understands.
    option_specs = (
        (
            "--lama_config",
            "./third_party/lama/configs/prediction/default.yaml",
            "The path to the config file of lama model. Default: the config of big-lama",
        ),
        (
            "--lama_ckpt",
            "pretrained_models/big-lama",
            "The path to the lama checkpoint.",
        ),
    )
    for flag, default_path, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_path, help=help_text)

# Lazily-built PointRend predictor, shared by every get_mask call so the
# config is assembled and the weights are loaded only once per process.
_POINTREND_PREDICTOR = None


def _get_pointrend_predictor():
    """Return the shared PointRend predictor, building it on first use."""
    global _POINTREND_PREDICTOR
    if _POINTREND_PREDICTOR is None:
        cfg = get_cfg()
        # Add PointRend-specific config
        point_rend.add_pointrend_config(cfg)
        # Load a config from file
        cfg.merge_from_file("detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model

        # Set when using CPU
        cfg.MODEL.DEVICE = 'cpu'

        # Use a model from PointRend model zoo: https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models
        cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco/28119989/model_final_ba17b9.pkl"
        _POINTREND_PREDICTOR = DefaultPredictor(cfg)
    return _POINTREND_PREDICTOR


def get_mask(img, num_people_keep, dilate_kernel_size):
    """Build a dilated binary mask covering the people to remove from ``img``.

    Args:
        img: Image array; only its first two dimensions (height, width) are
            used here, and the array is passed to the predictor unchanged.
            NOTE(review): detectron2's DefaultPredictor reads images in the
            channel order given by cfg.INPUT.FORMAT (BGR by default) --
            confirm the order of the arrays handed over by the UI.
        num_people_keep: How many detected people to leave untouched. The
            first ``num_people_keep`` person instances, in the order the
            predictor returns them, are kept (presumably the most confident
            ones -- confirm). Floats are truncated; ``None`` and negative
            values count as 0.
        dilate_kernel_size: Side length of the square dilation kernel.
            Floats are truncated; ``None``, 0 and negative values disable
            dilation.

    Returns:
        A ``uint8`` array of shape (height, width): 255 where a person is
        to be removed (after dilation), 0 elsewhere.
    """
    # UI widgets can deliver floats (or None for an empty field); slicing
    # and np.ones both require plain non-negative integers.
    keep_count = max(0, int(num_people_keep or 0))
    kernel_size = max(0, int(dilate_kernel_size or 0))

    outputs = _get_pointrend_predictor()(img)
    instances = outputs["instances"]

    # Select 'people' instances (class index 0)
    people_instances = instances[instances.pred_classes == 0]

    # Drop the instances of the people we want to keep
    eliminate_instances = people_instances[keep_count:]

    # Paint every remaining instance into one mask. Assigning 255 (rather
    # than summing uint8 masks) keeps overlapping people from wrapping
    # around to a smaller value.
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    for instance_mask in eliminate_instances.pred_masks:
        mask[instance_mask.to("cpu").numpy().astype(bool)] = 255

    # Dilation. An empty kernel does not mean "no dilation" to OpenCV, so
    # a size of 0 skips the call instead.
    if kernel_size == 0:
        return mask
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(mask, kernel, iterations=2)

def get_inpainted_img(img, mask):
    """Inpaint the masked region of ``img`` with the preloaded LaMa model.

    Reads the module-level ``model['lama']`` and ``args.lama_config`` and
    forwards them, together with ``img`` and ``mask``, to
    ``inpaint_img_with_builded_lama``. CUDA is used when torch reports it
    as available, the CPU otherwise.

    Returns whatever ``inpaint_img_with_builded_lama`` returns.
    """
    config_path = args.lama_config

    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    return inpaint_img_with_builded_lama(
        model['lama'],
        img,
        mask,
        config_path,
        device=target_device,
    )


def remove_people(img, num_people_keep, dilate_kernel_size):
    """Remove people from ``img`` and return the inpainted result.

    Runs the two pipeline stages in order -- ``get_mask`` to locate the
    people to erase, then ``get_inpainted_img`` to fill the masked area --
    printing a progress message before and after each stage.
    """
    # Stage 1: segmentation mask of the people to erase.
    print('Obtaining mask...')
    people_mask = get_mask(img, num_people_keep, dilate_kernel_size)
    print('Mask obtained')

    # Stage 2: fill the masked area.
    print('Inpainting with LAMA...')
    inpainted = get_inpainted_img(img, people_mask)
    print('Image Inpainted!')

    return inpainted


# Parse the command-line options (LaMa config and checkpoint paths).
# With no explicit list, argparse reads sys.argv[1:].
parser = argparse.ArgumentParser()
setup_args(parser)
args = parser.parse_args()

# Load the LaMa model once at start-up, on the GPU when one is available.
lama_config, lama_ckpt = args.lama_config, args.lama_ckpt
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Registry of loaded models, keyed by name; get_inpainted_img reads 'lama'.
model = {}
model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
     

# Gradio UI: one row with two equal-width columns -- the inputs and action
# buttons in the first column, the inpainted result in the second.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # State slot that no handler in this file fills; it is only cleared by
    # the Reset button below.
    features = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            # Input picture.
            img = gr.Image(height=300)# value="Input Image" .style(height="200px")
            
            # Parameters forwarded to remove_people.
            num_people_keep = gr.Number(label="Number of people to keep", minimum=0, maximum=100) 
            dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=30, step=1, value=5)
            
            # Action buttons: run the pipeline / clear the page.
            lama = gr.Button(value="Remove people", variant="primary", size="sm")#.style(full_width=True, size="sm")
            clear_button_image = gr.Button(value="Reset", variant="secondary", size="sm")#.style(full_width=True, size="sm")
            
        with gr.Column(scale=1):
            # Read-only output picture with a download button.
            img_out = gr.Image(interactive=False,show_download_button=True)# value="Image with People Removed", type="numpy", .style(height="200px")
    
    #mask = gr.outputs.Image(type="numpy", label="Segmentation Mask")#.style(height="200px")

    # "Remove people" runs the full pipeline and shows its result in img_out.
    lama.click(
        remove_people,
        [img, num_people_keep, dilate_kernel_size],
        [img_out]
    )

    def reset(*args):
        """Return one ``None`` per received value, clearing each bound output."""
        # NOTE: the ``args`` parameter shadows the module-level ``args``
        # namespace inside this function only.
        return [None for _ in args]

    # "Reset" clears the input image, the state slot and the output image.
    clear_button_image.click(
        reset,
        [img, features, img_out],
        [img, features, img_out]
    )

    # Example rows are [image path, people to keep, dilate kernel size], in
    # the same order as ``inputs``. Paths are resolved against the current
    # working directory. cache_examples=True asks Gradio to cache the
    # outputs of ``fn`` for these rows (see the Gradio Examples docs).
    gr.Examples(
        examples=[[os.path.join(os.getcwd(), "examples/002.jpg"), 2, 15], 
                  [os.path.join(os.getcwd(), "examples/013.jpg"), 1, 15], 
                  [os.path.join(os.getcwd(), "examples/014.jpg"), 1, 15], 
                  [os.path.join(os.getcwd(), "examples/015.jpg"), 1, 25],
                  [os.path.join(os.getcwd(), "examples/002.jpg"), 0, 15]],
        inputs=[img, num_people_keep, dilate_kernel_size],
        outputs=img_out,
        fn=remove_people,
        cache_examples=True,
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()