Spaces:
Runtime error
Runtime error
import os | |
from typing import Tuple | |
import cv2 | |
import numpy as np | |
from torch.hub import get_dir | |
from loguru import logger | |
from lama_cleaner.plugins.base_plugin import BasePlugin | |
class RemoveBG(BasePlugin):
    """Plugin that runs rembg's ``remove`` with a "u2net" session."""

    name = "RemoveBG"

    def __init__(self):
        """Set ``U2NET_HOME`` and create the rembg session reused by every call.

        Note that this overwrites the process-wide ``U2NET_HOME`` environment
        variable.
        """
        super().__init__()
        # rembg is imported inside the method rather than at module level;
        # check_dep() reports the case where it is not installed.
        from rembg import new_session

        # TODO Update for local development
        # Earlier variants pointed this at the torch hub "checkpoints"
        # directory or at the current working directory.
        # Presumably rembg reads U2NET_HOME to locate the model files --
        # confirm against the rembg documentation.
        model_dir = "/tmp/"
        os.environ["U2NET_HOME"] = model_dir
        # Log the directory that was actually configured above.
        logger.info(f"Load remove model from: {model_dir}")
        self.session = new_session(model_name="u2net")

    def __call__(self, rgb_np_img, files=None, form=None):
        """Convert the RGB input to BGR channel order and pass it to ``forward``.

        ``files`` and ``form`` are accepted but not used by this plugin.
        """
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        return self.forward(bgr_np_img)

    def forward(self, bgr_np_img) -> np.ndarray:
        """Run rembg's ``remove`` on the image with the stored session.

        The result of ``remove`` is returned unchanged.
        """
        from rembg import remove

        # NOTE(review): no colour conversion is applied to the result here --
        # confirm that callers expect the channel order rembg produces.
        return remove(bgr_np_img, session=self.session)

    def check_dep(self):
        """Return an install hint if rembg cannot be imported, else ``None``."""
        try:
            import rembg  # noqa: F401  (imported only to probe availability)
        except ImportError:
            return (
                "RemoveBG is not installed, please install it first. pip install rembg"
            )
        return None