Spaces:
Running
Running
File size: 5,407 Bytes
e8aba21 5c91952 e8aba21 5f15b01 e8aba21 897fc9e 98356f1 e8aba21 98356f1 e8aba21 5f15b01 e8aba21 5c91952 e8aba21 5f15b01 e8aba21 5f15b01 e8aba21 5f15b01 e8aba21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import os
import sys
# sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
# os.chdir("../")
import gradio as gr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import tempfile
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
#from utils import load_img_to_array, save_array_to_img, dilate_mask, \
# show_mask, show_points
from PIL import Image
sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
import argparse
import os
import matplotlib.pyplot as plt
from pylab import imshow, imsave
import detectron2_repo as detectron2
from detectron2_repo.detectron2.utils.logger import setup_logger
setup_logger()
import numpy as np
import cv2
import torch
from detectron2_repo.detectron2 import model_zoo
from detectron2_repo.detectron2.engine import DefaultPredictor
from detectron2_repo.detectron2.config import get_cfg
from detectron2_repo.detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2_repo.detectron2.data import MetadataCatalog
coco_metadata = MetadataCatalog.get("coco_2017_val")
# import PointRend project
from detectron2.projects import point_rend
# UI text for the demo. The warning sign below was mojibake ("β οΈ") and is
# restored to the intended U+26A0 U+FE0F emoji.
title = "PeopleRemover"
description = """
In this space, you can remove the amount of people you want from a picture.
⚠️ This is just a demo version!
"""
def setup_args(parser):
    """Register the LaMa command-line options on ``parser``.

    Adds two string options, in this order:
      --lama_config  path to the LaMa prediction config file
      --lama_ckpt    path to the LaMa checkpoint directory

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.
    """
    option_specs = (
        (
            "--lama_config",
            "./third_party/lama/configs/prediction/default.yaml",
            "The path to the config file of lama model. "
            "Default: the config of big-lama",
        ),
        (
            "--lama_ckpt",
            "pretrained_models/big-lama",
            "The path to the lama checkpoint.",
        ),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
# Cache for the PointRend predictor. Building a DefaultPredictor loads the
# model weights, so it is done once instead of on every request.
_PREDICTOR_CACHE = {}


def _get_person_predictor():
    """Return the PointRend instance-segmentation predictor, building it on first use."""
    predictor = _PREDICTOR_CACHE.get("pointrend")
    if predictor is None:
        cfg = get_cfg()
        # Add PointRend-specific config
        point_rend.add_pointrend_config(cfg)
        # Load a config from file
        cfg.merge_from_file("detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
        # Set when using CPU
        cfg.MODEL.DEVICE = 'cpu'
        # Use a model from PointRend model zoo: https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models
        cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco/28119989/model_final_ba17b9.pkl"
        predictor = DefaultPredictor(cfg)
        _PREDICTOR_CACHE["pointrend"] = predictor
    return predictor


def get_mask(img, num_people_keep, dilate_kernel_size):
    """Build a dilated inpainting mask that covers the people to remove.

    Person instances (class 0) are detected with PointRend on the CPU at a
    0.5 score threshold. The first ``num_people_keep`` of them, in the order
    the predictor returns them, are kept. All other person pixels are masked.

    Args:
        img: image array of shape (H, W, 3). Presumably RGB, as delivered by
            ``gr.Image(type="numpy")``. It is flipped to BGR for detectron2.
        num_people_keep: number of detected people to leave untouched.
            ``None`` and floats (as sent by ``gr.Number``) are accepted and
            truncated to a non-negative int.
        dilate_kernel_size: side length of the square dilation kernel
            (applied with 2 iterations). Values <= 0 disable dilation.

    Returns:
        uint8 array of shape (H, W, 4) from ``cv2.COLOR_GRAY2RGBA``. The colour
        channels are 255 on pixels to inpaint and 0 elsewhere.
        NOTE(review): OpenCV fills the alpha channel of GRAY2RGBA with the
        max value (255 everywhere). Confirm that ``inpaint_img_with_builded_lama``
        accepts this shape and does not read the alpha channel.
    """
    predictor = _get_person_predictor()
    # DefaultPredictor expects a BGR image. Flip the channel axis of the
    # (presumably RGB) Gradio image before inference.
    outputs = predictor(np.ascontiguousarray(img[:, :, ::-1]))

    # Select 'people' instances
    instances = outputs["instances"]
    people_instances = instances[instances.pred_classes == 0]
    # Instances slicing needs an int; gr.Number may hand us a float or None.
    keep = max(0, int(num_people_keep or 0))
    # Eliminate the instances of the people we want to keep
    eliminate_instances = people_instances[keep:]

    # Generate mask: union of the masks of every person to remove.
    height, width = img.shape[:2]
    full_mask = np.zeros((height, width), dtype=np.uint8)
    for instance_mask in eliminate_instances.pred_masks:
        # Assign instead of summing. Adding 255-valued uint8 masks wraps around
        # where people overlap.
        full_mask[instance_mask.to("cpu").numpy() > 0] = 255
    full_mask = full_mask.reshape((height, width, 1))
    mask = cv2.cvtColor(full_mask, cv2.COLOR_GRAY2RGBA).astype(np.uint8)

    # Dilation. cv2.dilate treats an empty kernel as 3x3, so a size of 0 would
    # still dilate. Skip dilation explicitly instead.
    kernel_size = int(dilate_kernel_size or 0)
    if kernel_size <= 0:
        return mask
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(mask, kernel, iterations=2)
def get_inpainted_img(img, mask):
    """Inpaint the masked region of ``img`` with the prebuilt LaMa model.

    Reads the module-level ``model['lama']`` and ``args.lama_config``. Runs on
    CUDA when ``torch.cuda.is_available()``, otherwise on the CPU.

    Returns:
        A one-element list holding the result of
        ``inpaint_img_with_builded_lama``.
    """
    run_on = "cuda" if torch.cuda.is_available() else "cpu"
    inpainted = inpaint_img_with_builded_lama(
        model['lama'], img, mask, args.lama_config, device=run_on)
    return [inpainted]
def remove_people(img, num_people_keep, dilate_kernel_size):
    """Mask the people to remove with ``get_mask``, then inpaint with ``get_inpainted_img``.

    Returns:
        Exactly what ``get_inpainted_img`` returns.
    """
    return get_inpainted_img(
        img, get_mask(img, num_people_keep, dilate_kernel_size))
# Parse CLI options (registered by setup_args) and build the LaMa model once at
# import time. get_inpainted_img reads `args` and `model['lama']` from module scope.
parser = argparse.ArgumentParser()
setup_args(parser)
args = parser.parse_args()

lama_config, lama_ckpt = args.lama_config, args.lama_ckpt
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = {"lama": build_lama_model(lama_config, lama_ckpt, device=device)}
with gr.Blocks() as demo:
    # Extra session state; only touched by the Reset button.
    features = gr.State(None)
    # precision=0 makes Gradio deliver an int, which is needed to slice instances.
    num_people_keep = gr.Number(label="Number of people to keep", minimum=0, maximum=100, value=0, precision=0)
    dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=30, step=1, value=5)
    lama = gr.Button("Inpaint Image", variant="primary").style(full_width=True, size="sm")
    clear_button_image = gr.Button(value="Reset", label="Reset", variant="secondary").style(full_width=True, size="sm")
    img = gr.Image(label="Input Image").style(height="200px")
    img_out = gr.outputs.Image(
        type="numpy", label="Image with People Removed").style(height="200px")

    def run_remove_people(image, keep, kernel_size):
        """Click handler: remove people from ``image`` and return one image.

        The three UI inputs must go through ``remove_people``, which builds
        the mask first. ``get_inpainted_img`` only takes (img, mask).
        ``remove_people`` returns a one-element list, and ``img_out`` is a
        single Image component, so the entry is unwrapped. Returns ``None``
        (clearing the output) when no image has been uploaded.
        """
        if image is None:
            return None
        return remove_people(image, keep, kernel_size)[0]

    lama.click(
        run_remove_people,
        [img, num_people_keep, dilate_kernel_size],
        [img_out]
    )

    def reset(*values):
        """Return ``None`` for every component passed in, clearing them all."""
        return [None for _ in values]

    clear_button_image.click(
        reset,
        [img, features, img_out],
        [img, features, img_out]
    )
# Launch the Gradio app only when the file is run as a script.
# (A stray trailing "|" after launch() was a syntax error.)
if __name__ == "__main__":
    demo.launch()