# Author: rsamf — commit 3ac5852 ("Adding documentation"), 878 bytes.
from graphbook import param, resource
from transformers import AutoModelForImageSegmentation
@resource("BackgroundRemoval/RMBGModel")
@param(
    "model_name",
    "string",
    description="The name of the RMBG model.",
    default="briaai/RMBG-1.4",
)
@param(
    "use_cuda",
    "boolean",
    description="Whether to use CUDA acceleration.",
    default=True,
)
def rmbg_model(self):
    """
    Build the background-removal segmentation model from the Hugging Face hub.
    Unless overridden, the checkpoint is [briaai/RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4).
    Args:
        model_name (str): Hub identifier of the RMBG checkpoint to load.
        use_cuda (bool): When true, the loaded model is moved to the "cuda" device.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # hub repository named by `model_name`; only point this at trusted repos.
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        self.model_name,
        trust_remote_code=True,
    )
    # Only relocate the weights when GPU execution was requested; otherwise the
    # model is handed back exactly as `from_pretrained` produced it.
    return segmenter.to("cuda") if self.use_cuda else segmenter