Spaces:
Sleeping
Sleeping
import os | |
import sys | |
# sys.path.append(os.path.abspath(os.path.dirname(os.getcwd()))) | |
# os.chdir("../") | |
import gradio as gr | |
import numpy as np | |
from pathlib import Path | |
from matplotlib import pyplot as plt | |
import torch | |
import tempfile | |
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama | |
#from utils import load_img_to_array, save_array_to_img, dilate_mask, \ | |
# show_mask, show_points | |
from PIL import Image | |
sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything")) | |
import argparse | |
import os | |
import matplotlib.pyplot as plt | |
from pylab import imshow, imsave | |
import detectron2 | |
from detectron2.utils.logger import setup_logger | |
setup_logger() | |
import numpy as np | |
import cv2 | |
import torch | |
from detectron2 import model_zoo | |
from detectron2.engine import DefaultPredictor | |
from detectron2.config import get_cfg | |
from detectron2.utils.visualizer import Visualizer, ColorMode | |
from detectron2.data import MetadataCatalog | |
coco_metadata = MetadataCatalog.get("coco_2017_val") | |
# import PointRend project | |
from detectron2_repo.projects.PointRend import point_rend | |
# User-facing text for the Gradio demo.
# NOTE: the warning sign below was previously mis-encoded as "β οΈ"
# (UTF-8 bytes of "⚠️" decoded with a Greek code page).
title = "PeopleRemover"
description = """
In this space, you can remove as many people as you want from a picture.
⚠️ This is just a demo version!
"""
def setup_args(parser):
    """Register the LaMa command-line options on ``parser`` (in place).

    Adds two optional string options:

    * ``--lama_config`` -- path to the LaMa prediction config file.
    * ``--lama_ckpt``   -- path to the LaMa checkpoint.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        None; the parser is mutated.
    """
    # (flag, default, help) triples, registered in this order.
    option_table = (
        (
            "--lama_config",
            "./third_party/lama/configs/prediction/default.yaml",
            "The path to the config file of lama model. "
            "Default: the config of big-lama",
        ),
        (
            "--lama_ckpt",
            "pretrained_models/big-lama",
            "The path to the lama checkpoint.",
        ),
    )
    for flag, fallback, help_text in option_table:
        parser.add_argument(flag, type=str, default=fallback, help=help_text)
# Lazily-built PointRend predictor shared by every get_mask() call.
# Building it loads the model weights, so it must not happen per request.
_POINTREND_PREDICTOR = None


def _get_pointrend_predictor():
    """Return the PointRend instance-segmentation predictor, building it once."""
    global _POINTREND_PREDICTOR
    if _POINTREND_PREDICTOR is None:
        cfg = get_cfg()
        # Add PointRend-specific config keys before merging the YAML file.
        point_rend.add_pointrend_config(cfg)
        cfg.merge_from_file("detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # detection score threshold
        # Inference is pinned to CPU.
        cfg.MODEL.DEVICE = 'cpu'
        # Weights from the PointRend model zoo:
        # https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models
        cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco/28119989/model_final_ba17b9.pkl"
        _POINTREND_PREDICTOR = DefaultPredictor(cfg)
    return _POINTREND_PREDICTOR


def get_mask(img, num_people_keep, dilate_kernel_size):
    """Build a mask covering the people that should be removed from ``img``.

    Runs PointRend on ``img`` and keeps only class-0 ('people') detections.
    It leaves the first ``num_people_keep`` of them (in the detector's output
    order) out of the mask and marks every remaining person.

    Args:
        img: H x W x 3 image array. Assumed RGB, as delivered by Gradio's
            numpy image component (NOTE(review): confirm for other callers).
        num_people_keep: number of detected people to keep. Floats (Gradio
            ``gr.Number``) and None are coerced to a non-negative int.
        dilate_kernel_size: side of the square dilation kernel. Values <= 0
            (or None) disable dilation.

    Returns:
        uint8 array of shape (H, W): 255 on pixels to inpaint, 0 elsewhere.
    """
    height, width = img.shape[:2]
    # Tensor slicing rejects float indices, so normalise the UI value here.
    keep = max(int(num_people_keep or 0), 0)
    kernel_size = int(dilate_kernel_size or 0)

    predictor = _get_pointrend_predictor()
    # DefaultPredictor expects BGR input; flip the RGB channels first.
    outputs = predictor(np.ascontiguousarray(img[:, :, ::-1]))
    instances = outputs["instances"]
    # Select 'people' instances, then drop the ones we want to keep.
    people_instances = instances[instances.pred_classes == 0]
    eliminate_instances = people_instances[keep:]

    # Union of all instance masks. A logical OR avoids the uint8 wrap-around
    # that summing 255-valued masks produced where people overlap.
    instance_masks = eliminate_instances.pred_masks.to("cpu").numpy()
    mask = np.zeros((height, width), dtype=np.uint8)
    if len(instance_masks):
        mask[instance_masks.any(axis=0)] = 255

    if kernel_size <= 0:
        # A 0x0 kernel would not mean "no dilation" (OpenCV substitutes a
        # default 3x3 element for an empty kernel), so skip dilation instead.
        return mask
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(mask, kernel, iterations=2)
def get_inpainted_img(img, mask):
    """Fill the masked region of ``img`` using the prebuilt LaMa model.

    The model comes from the module-level ``model['lama']`` and the config
    path from ``args.lama_config``. Inference runs on CUDA when available,
    otherwise on CPU.

    Args:
        img: image array, forwarded as-is to ``inpaint_img_with_builded_lama``.
        mask: mask array (as produced by ``get_mask``), forwarded as-is.

    Returns:
        Whatever ``inpaint_img_with_builded_lama`` returns (the inpainted image).
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    return inpaint_img_with_builded_lama(
        model['lama'],
        img,
        mask,
        args.lama_config,
        device=target_device,
    )
def remove_people(img, num_people_keep, dilate_kernel_size):
    """Gradio handler: erase detected people from ``img`` and return the result.

    Args:
        img: input image from the Gradio image component (None if nothing
            has been uploaded yet).
        num_people_keep: number of detected people to leave in the picture.
        dilate_kernel_size: dilation kernel size used to grow the mask.

    Returns:
        The inpainted image. If no pixels are masked (no people to remove),
        the input image is returned unchanged.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a stack trace from the detector.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    print('Obtaining mask...')
    mask = get_mask(img, num_people_keep, dilate_kernel_size)
    print('Mask obtained')
    if not np.any(mask):
        # Nothing to inpaint: skip the expensive LaMa pass entirely.
        print('No people to remove; returning the original image.')
        return img
    print('Inpainting with LAMA...')
    out = get_inpainted_img(img, mask)
    print('Image Inpainted!')
    return out
# --- Command-line options ---------------------------------------------------
parser = argparse.ArgumentParser()
setup_args(parser)
args = parser.parse_args(sys.argv[1:])

# --- Model construction (runs once, at import time) -------------------------
# Request handlers such as get_inpainted_img() look the LaMa model up in
# the module-level `model` dict.
lama_config, lama_ckpt = args.lama_config, args.lama_ckpt
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = {'lama': build_lama_model(lama_config, lama_ckpt, device=device)}
# Gradio UI: input column (image + controls) and output column (result image).
with gr.Blocks(title=title) as demo:
    gr.Markdown(description)
    # Placeholder state; it is only ever cleared by the Reset button.
    features = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(height=300)
            # precision=0 makes Gradio deliver an int (a float breaks the
            # instance slicing in get_mask); value=0 avoids a None input.
            num_people_keep = gr.Number(
                label="Number of people to keep",
                minimum=0,
                maximum=100,
                precision=0,
                value=0,
            )
            dilate_kernel_size = gr.Slider(
                label="Dilate Kernel Size", minimum=0, maximum=30, step=1, value=5
            )
            lama = gr.Button(value="Remove people", variant="primary", size="sm")
            clear_button_image = gr.Button(value="Reset", variant="secondary", size="sm")
        with gr.Column(scale=1):
            img_out = gr.Image(interactive=False, show_download_button=True)

    lama.click(
        remove_people,
        [img, num_people_keep, dilate_kernel_size],
        [img_out],
    )

    def reset(*components):
        """Return one None per wired component so each of them is cleared."""
        # Parameter renamed from `args` so it no longer shadows the global CLI args.
        return [None for _ in components]

    clear_button_image.click(
        reset,
        [img, features, img_out],
        [img, features, img_out],
    )
def _main():
    """Start the Gradio server for the demo built above."""
    demo.launch()


if __name__ == "__main__":
    _main()