Spaces:
Running
Running
File size: 6,000 Bytes
e8aba21 5c91952 e8aba21 5f15b01 4d41a79 e8aba21 98a5805 b3f612f e8aba21 98a5805 e8aba21 9d8045e e8aba21 4d41a79 e8aba21 5f15b01 e8aba21 0444efc e8aba21 a9d7705 5c91952 e8aba21 5f15b01 e8aba21 5f15b01 e8aba21 7e0f853 e8aba21 7e0f853 e8aba21 c9943d2 e8aba21 c9943d2 e8aba21 c9943d2 e8aba21 4d41a79 c9943d2 e8aba21 b8e3c24 d12d5fa b8e3c24 1ba3f94 b8e3c24 d12d5fa b8e3c24 51a2f5e e8aba21 5a38b1d e8aba21 bffbfdb 4d41a79 53ffb8e 8474f38 e1a8ecf 53ffb8e 4d41a79 50d5c37 4d41a79 5f15b01 e8aba21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import os
import sys
# os.chdir("../")
import gradio as gr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import tempfile
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
from PIL import Image
#sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
import argparse
import os
import matplotlib.pyplot as plt
from pylab import imshow, imsave
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
import numpy as np
import cv2
import torch
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
coco_metadata = MetadataCatalog.get("coco_2017_val")
# import PointRend project
from detectron2_repo.projects.PointRend import point_rend
# Markdown heading passed to gr.Markdown when the UI is built.
title = "# PeopleRemover"
# Markdown blurb passed to gr.Markdown right after the heading.
# The warning sign was previously stored as mojibake; it is restored here.
description = """
In this space, you can remove the amount of people you want from a picture.
⚠️ This is just a demo version!
"""
def setup_args(parser):
    """Register the LaMa command-line options on ``parser``.

    Adds two string options, ``--lama_config`` and ``--lama_ckpt``, each with
    a default value. The parser is modified in place; nothing is returned.
    """
    # (flag, default value, help text) for every option this app understands.
    lama_options = (
        (
            "--lama_config",
            "./third_party/lama/configs/prediction/default.yaml",
            "The path to the config file of lama model. "
            "Default: the config of big-lama",
        ),
        (
            "--lama_ckpt",
            "pretrained_models/big-lama",
            "The path to the lama checkpoint.",
        ),
    )
    for flag, default_value, help_text in lama_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
# Lazily-built PointRend predictor shared by every get_mask() call.
_POINTREND_PREDICTOR = None


def _get_pointrend_predictor():
    """Return the shared PointRend predictor, building it on first use.

    Constructing a DefaultPredictor loads the model weights, so it is done
    once and reused instead of being repeated for every request.
    """
    global _POINTREND_PREDICTOR
    if _POINTREND_PREDICTOR is None:
        cfg = get_cfg()
        # Add PointRend-specific config
        point_rend.add_pointrend_config(cfg)
        # Load a config from file
        cfg.merge_from_file("detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
        # Set when using CPU
        cfg.MODEL.DEVICE = 'cpu'
        # Use a model from PointRend model zoo: https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models
        cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco/28119989/model_final_ba17b9.pkl"
        _POINTREND_PREDICTOR = DefaultPredictor(cfg)
    return _POINTREND_PREDICTOR


def get_mask(img, num_people_keep, dilate_kernel_size):
    """Build a binary mask covering the people that should be removed.

    Args:
        img: Image array; ``img.shape[0]`` and ``img.shape[1]`` are used as
            the mask height and width.
            NOTE(review): DefaultPredictor interprets the channel order
            according to ``cfg.INPUT.FORMAT``; confirm it matches what the
            caller supplies.
        num_people_keep: Number of detected people to leave untouched. The
            first ``num_people_keep`` person instances, in the order the
            predictor returns them (presumably by descending score - confirm),
            are kept. ``None`` and negative values are treated as 0;
            fractional values are truncated.
        dilate_kernel_size: Side length of the square dilation kernel.
            ``None`` or a value <= 0 skips dilation.

    Returns:
        A ``uint8`` array of shape ``(height, width)`` that is 255 where a
        person to remove was segmented (after dilation) and 0 elsewhere.
    """
    outputs = _get_pointrend_predictor()(img)
    instances = outputs["instances"]
    # Select 'people' instances (class index 0)
    people_instances = instances[instances.pred_classes == 0]

    # UI widgets may deliver None or floats; slicing needs a non-negative int.
    keep_count = max(int(num_people_keep or 0), 0)
    # Eliminate the instances of the people we want to keep
    eliminate_instances = people_instances[keep_count:]

    # Generate mask. Pixels are set (not added) so overlapping instances
    # cannot wrap around in uint8 arithmetic.
    height, width = img.shape[0], img.shape[1]
    full_mask = np.zeros((height, width), dtype=np.uint8)
    for instance_mask in eliminate_instances.pred_masks:
        full_mask[instance_mask.to("cpu").numpy().astype(bool)] = 255

    # Dilation: grow the mask so the inpainting also covers the borders of
    # each person. A zero-sized kernel is not a usable structuring element.
    kernel_size = max(int(dilate_kernel_size or 0), 0)
    if kernel_size == 0:
        return full_mask
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(full_mask, kernel, iterations=2)
def get_inpainted_img(img, mask):
    """Run the pre-built LaMa model on ``img`` with ``mask``.

    Reads the config path from the module-level ``args`` and the network
    from ``model['lama']``, selects "cuda" when available (otherwise "cpu"),
    and returns whatever ``inpaint_img_with_builded_lama`` produces.
    """
    config_path = args.lama_config
    if torch.cuda.is_available():
        run_device = "cuda"
    else:
        run_device = "cpu"
    return inpaint_img_with_builded_lama(
        model['lama'], img, mask, config_path, device=run_device)
def remove_people(img, num_people_keep, dilate_kernel_size):
    """Remove people from ``img`` and return the inpainted result.

    The work is done in two steps: ``get_mask`` segments the people to remove
    and ``get_inpainted_img`` fills the masked area.

    Args:
        img: Input image, or ``None`` when no image was supplied.
        num_people_keep: Number of people to leave in the picture; forwarded
            to ``get_mask``.
        dilate_kernel_size: Dilation kernel size; forwarded to ``get_mask``.

    Returns:
        The value returned by ``get_inpainted_img``, or ``None`` when ``img``
        is ``None``.
    """
    # The handler can be triggered before an image is uploaded; bail out
    # instead of failing inside the segmentation model.
    if img is None:
        print('No input image provided; nothing to do.')
        return None
    print('Obtaining mask...')
    mask = get_mask(img, num_people_keep, dilate_kernel_size)
    print('Mask obtained')
    print('Inpainting with LAMA...')
    out = get_inpainted_img(img, mask)
    print('Image Inpainted!')
    return out
# Parse the command-line options (LaMa config and checkpoint locations).
parser = argparse.ArgumentParser()
setup_args(parser)
args = parser.parse_args(sys.argv[1:])

# Resolve the LaMa settings and prefer the GPU when one is available.
lama_config, lama_ckpt = args.lama_config, args.lama_ckpt
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Registry of loaded models; the LaMa network is built once, at import time.
model = {'lama': build_lama_model(lama_config, lama_ckpt, device=device)}
# Gradio UI: inputs on the left, the inpainted result on the right.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    features = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(height=300)
            # precision=0 makes the component deliver an int, which the
            # instance slicing in get_mask requires; value=0 gives the field
            # an explicit default instead of an empty box.
            num_people_keep = gr.Number(label="Number of people to keep", value=0, precision=0, minimum=0, maximum=100)
            dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=30, step=1, value=5)
            lama = gr.Button(value="Remove people", variant="primary", size="sm")
            clear_button_image = gr.Button(value="Reset", variant="secondary", size="sm")
        with gr.Column(scale=1):
            img_out = gr.Image(interactive=False, show_download_button=True)

    # Run the full segmentation + inpainting pipeline on click.
    lama.click(
        remove_people,
        [img, num_people_keep, dilate_kernel_size],
        [img_out]
    )

    def reset(*values):
        """Return one ``None`` per received value, clearing each bound component."""
        # Parameter is not called ``args`` so it cannot be confused with the
        # module-level namespace of parsed command-line options.
        return [None for _ in values]

    clear_button_image.click(
        reset,
        [img, features, img_out],
        [img, features, img_out]
    )

    # Each example row is: image path, number of people to keep, kernel size.
    gr.Examples(
        examples=[[os.path.join(os.getcwd(), "examples/002.jpg"), 2, 15],
                  [os.path.join(os.getcwd(), "examples/013.jpg"), 1, 15],
                  [os.path.join(os.getcwd(), "examples/014.jpg"), 1, 15],
                  [os.path.join(os.getcwd(), "examples/015.jpg"), 1, 25],
                  [os.path.join(os.getcwd(), "examples/002.jpg"), 0, 15]],
        inputs=[img, num_people_keep, dilate_kernel_size],
        outputs=img_out,
        fn=remove_people,
        cache_examples=True,
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()