File size: 5,292 Bytes
e8aba21
 
 
 
5c91952
e8aba21
 
 
 
 
 
 
 
5f15b01
e8aba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f15b01
e8aba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c91952
e8aba21
 
 
5f15b01
e8aba21
5f15b01
e8aba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f15b01
e8aba21
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import sys
# sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
# os.chdir("../")
import gradio as gr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import tempfile
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
#from utils import load_img_to_array, save_array_to_img, dilate_mask, \
#    show_mask, show_points
from PIL import Image
sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
import argparse

import os
import matplotlib.pyplot as plt
from pylab import imshow, imsave


import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import cv2
import torch

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
coco_metadata = MetadataCatalog.get("coco_2017_val")

# import PointRend project
from detectron2.projects import point_rend


# Heading text and short blurb for the Gradio demo page.
title = "PeopleRemover"
description = (
    "\n"
    "In this space, you can remove the amount of people you want from a picture.\n"
    "⚠️ This is just a demo version!\n"
)

def setup_args(parser):
    """Register the LaMa command-line options on *parser*.

    Adds ``--lama_config`` (path to the LaMa prediction config file) and
    ``--lama_ckpt`` (path to the LaMa checkpoint directory). Both are
    string options with defaults, so neither is required.

    Args:
        parser: An ``argparse.ArgumentParser`` to mutate in place.
    """
    option_specs = (
        (
            "--lama_config",
            "./third_party/lama/configs/prediction/default.yaml",
            "The path to the config file of lama model. "
            "Default: the config of big-lama",
        ),
        (
            "--lama_ckpt",
            "pretrained_models/big-lama",
            "The path to the lama checkpoint.",
        ),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

# Lazily-built PointRend predictor shared by every get_mask() call.
_POINTREND_PREDICTOR = None


def _get_pointrend_predictor():
    """Build the PointRend instance-segmentation predictor once and return it.

    Constructing a ``DefaultPredictor`` loads the model weights, so the
    instance is cached in a module-level variable instead of being rebuilt
    on every request.
    """
    global _POINTREND_PREDICTOR
    if _POINTREND_PREDICTOR is None:
        cfg = get_cfg()
        # Add PointRend-specific config
        point_rend.add_pointrend_config(cfg)
        # Load a config from file
        cfg.merge_from_file("detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model

        # Set when using CPU
        cfg.MODEL.DEVICE = 'cpu'

        # Use a model from PointRend model zoo: https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models
        cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco/28119989/model_final_ba17b9.pkl"
        _POINTREND_PREDICTOR = DefaultPredictor(cfg)
    return _POINTREND_PREDICTOR


def get_mask(img, num_people_keep, dilate_kernel_size):
    """Build a mask covering the detected people that should be removed.

    Runs PointRend on *img* and selects the detections with class id 0
    ("person" in COCO). The first ``num_people_keep`` of them, in the
    predictor's output order, are left alone. The rest are painted into
    the mask.

    Args:
        img: Image array of shape (H, W, C). Assumed RGB, as produced by
            ``gr.Image`` (TODO confirm for other callers). Only the first
            three channels are used.
        num_people_keep: How many detected people to keep. Floats are
            truncated; ``None`` and negative values are treated as 0.
        dilate_kernel_size: Side length of the square kernel used to grow
            the mask (two dilation iterations). Values <= 0 disable
            dilation.

    Returns:
        uint8 array of shape (H, W, 4): the single-channel mask converted
        with ``cv2.COLOR_GRAY2RGBA``. Pixels to remove are 255 in the
        colour channels, and the alpha channel is 255 everywhere.
        NOTE(review): confirm that ``inpaint_img_with_builded_lama``
        accepts a 4-channel mask; LaMa masks are usually single-channel.

    Raises:
        ValueError: if *img* is None (e.g. no image was uploaded).
    """
    if img is None:
        raise ValueError("No input image was provided.")

    height, width = img.shape[:2]

    # DefaultPredictor expects BGR input. Reverse the first three channels
    # (RGB -> BGR), which also drops any alpha channel.
    bgr_img = np.ascontiguousarray(img[:, :, 2::-1])
    outputs = _get_pointrend_predictor()(bgr_img)

    # Select 'people' instances (class id 0).
    instances = outputs["instances"]
    people_instances = instances[instances.pred_classes == 0]

    # Gradio's Number component may deliver a float or None, and a negative
    # bound would slice from the end, so normalise to a non-negative int.
    keep = max(0, int(num_people_keep or 0))
    eliminate_instances = people_instances[keep:]

    # Union of the instance masks. A boolean OR avoids the uint8 wrap-around
    # that summing overlapping masks caused (255 + 255 -> 254).
    pred_masks = eliminate_instances.pred_masks.to("cpu").numpy()
    full_mask = np.zeros((height, width), dtype=np.uint8)
    if len(pred_masks):
        full_mask[pred_masks.any(axis=0)] = 255

    mask = cv2.cvtColor(full_mask, cv2.COLOR_GRAY2RGBA).astype(np.uint8)

    # Per the OpenCV docs, an empty kernel makes cv2.dilate fall back to a
    # 3x3 element, so size 0 would still dilate. Skip dilation explicitly.
    kernel_size = int(dilate_kernel_size or 0)
    if kernel_size <= 0:
        return mask

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(mask, kernel, iterations=2)

def get_inpainted_img(img, mask):
    """Fill the masked region of *img* with the pre-built LaMa model.

    Args:
        img: Image array to inpaint.
        mask: Mask marking the pixels to fill.

    Returns:
        A one-element list containing the inpainted image.
    """
    config_path = args.lama_config
    # Use the GPU whenever PyTorch can see one.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    inpainted = inpaint_img_with_builded_lama(
        model['lama'], img, mask, config_path, device=run_device)
    return [inpainted]


def remove_people(img, num_people_keep, dilate_kernel_size):
    """Mask out unwanted people in *img* and inpaint the masked region.

    Builds the mask with ``get_mask`` and passes it to
    ``get_inpainted_img``. The return value is whatever
    ``get_inpainted_img`` returns (a one-element list holding the image).
    """
    return get_inpainted_img(
        img,
        get_mask(img, num_people_keep, dilate_kernel_size),
    )


# Parse the command-line options registered by setup_args.
parser = argparse.ArgumentParser()
setup_args(parser)
args = parser.parse_args(sys.argv[1:])

# Load the LaMa model once at import time; get_inpainted_img reads it
# from model['lama'] on every request.
lama_config, lama_ckpt = args.lama_config, args.lama_ckpt
device = "cuda" if torch.cuda.is_available() else "cpu"
model = {'lama': build_lama_model(lama_config, lama_ckpt, device=device)}
     

with gr.Blocks() as demo:
    # Page heading and blurb defined above.
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    # Not read by any handler; only cleared by the Reset button.
    features = gr.State(None)

    # precision=0 makes Gradio pass an int, which get_mask uses as a slice bound.
    num_people_keep = gr.Number(label="Number of people to keep", value=0, precision=0, minimum=0, maximum=100)
    dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=30, step=1, value=5)

    lama = gr.Button("Inpaint Image", variant="primary").style(full_width=True, size="sm")
    clear_button_image = gr.Button(value="Reset", label="Reset", variant="secondary").style(full_width=True, size="sm")

    img = gr.Image(label="Input Image").style(height="200px")

    #mask = gr.outputs.Image(type="numpy", label="Segmentation Mask").style(height="200px")

    img_out = gr.Image(
                type="numpy", label="Image with People Removed").style(height="200px")

    def _run_inpaint(image, keep, kernel_size):
        """Run the full removal pipeline and return the single output image."""
        # remove_people returns a one-element list. With a single output
        # component, Gradio would hand that list (not the image) to img_out,
        # so unwrap it here.
        return remove_people(image, keep, kernel_size)[0]

    # The button supplies three inputs; get_inpainted_img takes (img, mask),
    # so the click must go through the full pipeline instead.
    lama.click(
        _run_inpaint,
        [img, num_people_keep, dilate_kernel_size],
        [img_out]
    )

    def reset(*values):
        """Return one ``None`` per input so every listed component is cleared."""
        return [None for _ in values]

    clear_button_image.click(
        reset,
        [img, features, img_out],
        [img, features, img_out]
    )

if __name__ == "__main__":
    demo.launch()