# NOTE(review): banner text below ("Spaces: Sleeping / Sleeping") was captured
# from the hosting page during extraction and is not part of the module.
# Spaces: Sleeping / Sleeping
# Standard library first, then third-party, per PEP 8 import grouping.
import os
import os.path as osp

import torch
import torch.nn.functional
import torchvision.transforms.functional as F
from graphbook import steps
from transformers import AutoModelForImageSegmentation
class RemoveBackground(steps.BatchStep):
    """
    Executes background removal on a batch of images using the provided model.

    Args:
        batch_size (int): The batch size (number of images) given to the background removal model at a time
        model (resource): The model to use for background removal. Use the resource `BackgroundRemoval/RMBGModel` to load the model
        output_dir (str): The directory to save the output images to
    """

    RequiresInput = True
    Outputs = ["out"]
    Parameters = {
        "batch_size": {
            "type": "number",
            "description": "The batch size for background removal.",
            "default": 8,
        },
        "model": {
            "type": "resource",
            "description": "The model to use for background removal.",
        },
        "output_dir": {
            "type": "string",
            "description": "The directory to save the output images.",
            "default": "/tmp/output",
        },
    }
    Category = "BackgroundRemoval"

    def __init__(self, batch_size: int, model: AutoModelForImageSegmentation, output_dir: str):
        super().__init__(batch_size, "image")
        self.model = model
        self.output_dir = output_dir
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

    def load_fn(self, item: dict) -> torch.Tensor:
        """Worker-side loader: convert a raw item into a 3xHxW image tensor."""
        return load_image_as_tensor(item)

    def dump_fn(self, t: torch.Tensor, output_path: str):
        """Worker-side dumper: persist a mask tensor to disk as an image."""
        save_image_to_disk(t, output_path)

    def on_item_batch(self, tensors, items, notes):
        """Run the segmentation model on one batch and queue masks for saving.

        Args:
            tensors: list of 3xHxW image tensors (sizes may differ per image).
            items: raw items the tensors were loaded from (unused here).
            notes: per-item metadata; each note's ``masks`` list gets an entry
                pointing at the mask's output path.

        Returns:
            list of (mask_tensor, output_path) pairs handed to ``dump_fn``.
        """
        def get_output_path(note):
            # Label 0 maps to "cat", anything else to "dog".
            # NOTE(review): assumes labels are only ever 0/1 — confirm upstream.
            label = "cat" if note["labels"] == 0 else "dog"
            return osp.join(self.output_dir, label, f"{note['image']['shm_id']}.jpg")

        # Remember each image's original (H, W) so masks can be resized back.
        og_sizes = [t.shape[1:] for t in tensors]
        # Run on whatever device the model lives on. This was hard-coded to
        # "cuda", which crashed on CPU-only hosts; using the model's device is
        # identical behavior whenever the model is on the GPU.
        device = next(self.model.parameters()).device
        # Resize every image to the model's 1024x1024 input and normalize with
        # mean 0.5 / std 1.0 per channel, then batch them.
        batch = torch.stack(
            [
                F.normalize(
                    torch.nn.functional.interpolate(
                        torch.unsqueeze(image, 0), size=[1024, 1024], mode="bilinear"
                    ).squeeze(0),
                    [0.5, 0.5, 0.5],
                    [1.0, 1.0, 1.0],
                )
                for image in tensors
            ]
        ).to(device)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            tup = self.model(batch)
        result = tup[0][0]
        # Min-max normalize the raw model output to [0, 1] across the batch.
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        # Resize each mask back to its source image's original size.
        resized = [
            torch.squeeze(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(mask, 0), size=og_size, mode="bilinear"
                ),
                0,
            ).cpu()
            for mask, og_size in zip(result, og_sizes)
        ]
        paths = [get_output_path(note) for note in notes]
        # Record each mask's destination on its note so downstream steps can
        # find it; notes may carry masks=None, which we treat as an empty list.
        for path, note in zip(paths, notes):
            masks = note["masks"]
            if masks is None:
                masks = []
            masks.append({"value": path, "type": "image"})
            note["masks"] = masks
        return list(zip(resized, paths))
def load_image_as_tensor(item: dict) -> torch.Tensor:
    """Convert an item's image payload into a 3-channel CxHxW float tensor.

    Grayscale inputs are tiled to three channels and RGBA inputs have the
    alpha channel dropped, so the result always has exactly 3 channels
    (for inputs that arrive with 1, 3, or 4 channels).
    """
    tensor = F.to_tensor(item["value"])
    channels = tensor.shape[0]
    if channels == 1:
        # Grayscale: replicate the single channel into R, G, and B.
        tensor = tensor.repeat(3, 1, 1)
    elif channels == 4:
        # RGBA: keep only the RGB channels.
        tensor = tensor[:3]
    return tensor
def save_image_to_disk(t: torch.Tensor, output_path: str):
    """Save an image tensor to ``output_path``, creating parent dirs if needed.

    Args:
        t: image tensor convertible by ``to_pil_image`` (e.g. CxHxW floats).
        output_path: destination file path; format inferred from extension.
    """
    # Renamed `dir` (shadowed the builtin). Also guard against a bare
    # filename: dirname("") is "" and os.makedirs("") raises FileNotFoundError.
    directory = osp.dirname(output_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    img = F.to_pil_image(t)
    img.save(output_path)