import time

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import normalize

from briarmbg import BriaRMBG
# Load the BRIA RMBG-1.4 background-removal model and move it to the best available device.
bgrm = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bgrm.to(device)
print("device:", device)
def resize_image(image):
    # The model expects a fixed 1024x1024 RGB input.
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image
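# Note: the fixed-size resize ignores aspect ratio; process() below resizes the
# predicted mask back to the original frame resolution before compositing.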
def process(image):
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size  # PIL reports (width, height)
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    # inference (no_grad avoids accumulating gradients frame after frame)
    with torch.no_grad():
        result = bgrm(im_tensor)

    # post-process: resize the predicted mask back to the original resolution
    # and min-max normalize it to [0, 1]
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # mask to PIL
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # composite the original image over a solid green background, using the mask as alpha
    new_im = Image.new("RGBA", pil_im.size, (0, 255, 0, 255))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
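# Note: the composite above is an opaque green-screen render rather than a truly
# transparent one, since the MP4 written below cannot carry an alpha channel.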
def process_video(video, progress=gr.Progress()):
    cap = cv2.VideoCapture(video)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # total frame count for the progress bar
    writer = None
    tmpname = 'output.mp4'
    processed_frames = 0
    start_time = time.time()
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Stop a few seconds before the 20-minute mark (presumably the Space's GPU time limit)
        # and return whatever has been encoded so far.
        if time.time() - start_time >= 20 * 60 - 5:
            print("GPU Timeout is coming")
            cap.release()
            if writer is not None:
                writer.release()
            return tmpname
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(frame).convert('RGB')
        if writer is None:
            # Create the writer once the first frame's size is known; VideoWriter expects (width, height).
            writer = cv2.VideoWriter(tmpname, cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), img.size)
        processed_frames += 1
        print(f"Processing frame {processed_frames}")
        progress(processed_frames / total_frames, desc=f"Processing frame {processed_frames}/{total_frames}")
        out = process(np.array(img))
        # Drop the alpha channel and convert back to BGR before writing the frame.
        writer.write(cv2.cvtColor(np.array(out.convert('RGB')), cv2.COLOR_RGB2BGR))
    cap.release()
    if writer is not None:
        writer.release()
    return tmpname
title = "Video Background Removal Tool"
description = """Please note that if your video is long (has a high frame count), processing may be cut off by the GPU timeout. In that case, consider trying Fast mode."""
examples = [['./input.mp4']]
iface = gr.Interface(
    fn=process_video,
    inputs=["video"],
    outputs="video",
    examples=examples,
    title=title,
    description=description,
)
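# Note: depending on the Gradio version, gr.Progress updates may require the queue to be
# enabled, e.g. iface.queue().launch(); this is an assumption, not part of the original code.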
iface.launch()