File size: 878 Bytes
e7f01f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ac5852
 
 
 
 
 
 
 
e7f01f9
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from graphbook import param, resource
from transformers import AutoModelForImageSegmentation


@resource("BackgroundRemoval/RMBGModel")
@param(
    "model_name",
    "string",
    description="The name of the RMBG model.",
    default="briaai/RMBG-1.4",
)
@param(
    "use_cuda",
    "boolean",
    description="Whether to use CUDA acceleration.",
    default=True,
)
def rmbg_model(self):
    """
    Loads the background removal model from the Hugging Face model hub.
    By default, we are using [briaai/RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4).

    The values below are declared with ``@param`` and read from ``self``.
    If ``use_cuda`` is set but no CUDA device is available, a
    ``RuntimeWarning`` is emitted and the model is returned without being
    moved, instead of failing inside ``model.to("cuda")``.

    Args:
        model_name (str): The name of the RMBG model.
        use_cuda (bool): Whether to use CUDA acceleration.

    Returns:
        The model returned by ``AutoModelForImageSegmentation.from_pretrained``,
        moved to ``"cuda"`` when requested and available.
    """
    # Local imports keep this block self-contained; torch is already required
    # at runtime for the ``.to("cuda")`` call below.
    import warnings

    import torch

    # NOTE(security): trust_remote_code=True lets the model repository run its
    # own Python code on this machine. Only point model_name at trusted repos.
    model = AutoModelForImageSegmentation.from_pretrained(
        self.model_name, trust_remote_code=True
    )

    if not self.use_cuda:
        return model

    if not torch.cuda.is_available():
        # use_cuda defaults to True, so CPU-only hosts reach this branch with
        # an untouched configuration; degrade gracefully rather than crash.
        warnings.warn(
            "use_cuda is enabled but CUDA is not available; "
            "leaving the RMBG model on its current device.",
            RuntimeWarning,
            stacklevel=2,
        )
        return model

    return model.to("cuda")