rsamf's picture
Adding initial files
e7f01f9
raw
history blame
574 Bytes
from graphbook import param, resource
from transformers import AutoModelForImageSegmentation
@resource("BackgroundRemoval/RMBGModel")
@param(
    "model_name", "string",
    description="The name of the RMBG model.",
    default="briaai/RMBG-1.4",
)
@param(
    "use_cuda", "boolean",
    description="Whether to use CUDA acceleration.",
    default=True,
)
def rmbg_model(self):
    """Load the RMBG image-segmentation model, optionally placed on CUDA.

    Reads two attributes from ``self``, named after the ``@param``
    declarations above:

    * ``model_name`` -- identifier handed to ``from_pretrained``.
    * ``use_cuda`` -- when truthy, the loaded model is moved to ``"cuda"``.

    Returns the result of ``.to("cuda")`` when ``use_cuda`` is truthy,
    otherwise the model exactly as ``from_pretrained`` returned it.
    """
    # trust_remote_code=True allows transformers to run the modelling code
    # published in the model repository, so model_name should only point at
    # repositories that are trusted.
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        self.model_name,
        trust_remote_code=True,
    )

    if not self.use_cuda:
        return segmenter

    # NOTE(review): CUDA availability is not checked before the move --
    # presumably this raises on a machine without CUDA; confirm.
    return segmenter.to("cuda")