File size: 1,284 Bytes
252e766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from typing import Tuple

import cv2
import numpy as np
from torch.hub import get_dir
from loguru import logger

from lama_cleaner.plugins.base_plugin import BasePlugin


class RemoveBG(BasePlugin):
    """Background-removal plugin backed by rembg's ``u2net`` model.

    A single rembg session is created in ``__init__`` and reused for every
    call to :meth:`forward`.
    """

    name = "RemoveBG"

    def __init__(self):
        super().__init__()
        from rembg import new_session

        # rembg uses the U2NET_HOME environment variable to decide where model
        # weights are downloaded to / loaded from, so it has to be set before
        # the session is created. This overrides any U2NET_HOME already
        # present in the process environment.
        # TODO: hard-coded for the current deployment; make this configurable
        # for local development.
        model_dir = "/tmp/"
        os.environ["U2NET_HOME"] = model_dir

        # Log the directory rembg will actually use for the model.
        logger.info(f"Load remove model from: {model_dir}")
        self.session = new_session(model_name="u2net")

    def __call__(self, rgb_np_img, files=None, form=None):
        """Remove the background from an RGB image.

        Args:
            rgb_np_img: Image as a numpy array in RGB channel order
                (presumably HxWx3 uint8 -- confirm against the caller).
            files: Unused; accepted to match the plugin call interface.
            form: Unused; accepted to match the plugin call interface.

        Returns:
            The cut-out produced by :meth:`forward` (BGRA channel order).
        """
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        return self.forward(bgr_np_img)

    def forward(self, bgr_np_img) -> np.ndarray:
        """Run rembg on a BGR image and return the cut-out in BGRA order.

        ``rembg.remove`` wraps numpy input with ``PIL.Image.fromarray`` and
        therefore interprets it as RGB. Because the u2net model normalizes
        each channel separately, handing it BGR data degrades the predicted
        mask. The image is converted to RGB for the model, and the RGBA result
        is converted back to BGRA so the output channel layout matches the
        BGR input convention of this method.

        Args:
            bgr_np_img: Image as a numpy array in BGR channel order.

        Returns:
            A 4-channel array in BGRA order whose alpha channel is the
            predicted foreground mask.
        """
        from rembg import remove

        rgb_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
        rgba_output = remove(rgb_np_img, session=self.session)
        # NOTE(review): the result is BGRA. If the caller treats plugin output
        # as RGBA (as the RGB input of __call__ suggests), red and blue would
        # appear swapped -- confirm against the caller.
        return cv2.cvtColor(rgba_output, cv2.COLOR_RGBA2BGRA)

    def check_dep(self):
        """Check that rembg can be imported.

        Returns:
            An installation hint string if ``rembg`` cannot be imported,
            otherwise ``None``.
        """
        try:
            import rembg  # noqa: F401  -- imported only to probe availability
        except ImportError:
            return (
                "RemoveBG is not installed, please install it first. pip install rembg"
            )
        return None