# Commit 3ac5852 ("Adding documentation") by rsamf
from graphbook import steps
from transformers import AutoModelForImageSegmentation
import torchvision.transforms.functional as F
import torch.nn.functional
import torch
import os
import os.path as osp
class RemoveBackground(steps.BatchStep):
    """
    Executes background removal on a batch of images using the provided model.

    Args:
        batch_size (int): The batch size (number of images) given to the background removal model at a time
        model (resource): The model to use for background removal. Use the resource `BackgroundRemoval/RMBGModel` to load the model
        output_dir (str): The directory to save the output images to
    """

    # Graphbook step configuration: the step consumes input notes and emits
    # them on a single "out" slot.
    RequiresInput = True
    Outputs = ["out"]
    # Parameter schema surfaced in the Graphbook UI; defaults mirror __init__.
    Parameters = {
        "batch_size": {
            "type": "number",
            "description": "The batch size for background removal.",
            "default": 8,
        },
        "model": {
            "type": "resource",
            "description": "The model to use for background removal.",
        },
        "output_dir": {
            "type": "string",
            "description": "The directory to save the output images.",
            "default": "/tmp/output",
        },
    }
    Category = "BackgroundRemoval"

    def __init__(self, batch_size: int, model: AutoModelForImageSegmentation, output_dir: str):
        # "image" names the item key that BatchStep batches over.
        super().__init__(batch_size, "image")
        self.model = model
        self.output_dir = output_dir
        # Eagerly create the output root so per-image saves only need the
        # label subdirectory.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

    def load_fn(self, item: dict) -> torch.Tensor:
        """Decode one item's image payload into a 3-channel CHW tensor."""
        return load_image_as_tensor(item)

    def dump_fn(self, t: torch.Tensor, output_path: str):
        """Write one result tensor to disk as an image file at output_path."""
        save_image_to_disk(t, output_path)

    @torch.no_grad()
    def on_item_batch(self, tensors, items, notes):
        """Run background removal over one batch.

        Each input tensor is resized to 1024x1024, normalized (mean 0.5,
        std 1.0 per channel), and fed through the segmentation model on CUDA.
        The resulting masks are min-max normalized, resized back to each
        image's original size, and the destination path for each mask is
        appended to the note's "masks" list.

        Returns:
            list[tuple[torch.Tensor, str]]: (mask tensor, output path) pairs,
            presumably consumed by dump_fn to write files — confirm against
            BatchStep's contract.

        NOTE(review): requires a CUDA device — the batch is moved to "cuda"
        unconditionally.
        """
        def get_output_path(note):
            # NOTE(review): assumes a binary cat/dog labeling where label 0
            # means "cat" — dataset-specific; verify against the pipeline
            # producing `labels`.
            label = ""
            if note["labels"] == 0:
                label = "cat"
            else:
                label = "dog"
            # Files are named after the image's shared-memory id.
            return osp.join(self.output_dir, label, f"{note['image']['shm_id']}.jpg")

        # Remember each image's original (H, W) so masks can be resized back.
        og_sizes = [t.shape[1:] for t in tensors]
        # Upsample every image to the model's expected 1024x1024 input and
        # normalize channels (mean 0.5, std 1.0).
        images = [
            F.normalize(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(image, 0), size=[1024, 1024], mode="bilinear"
                ),
                [0.5, 0.5, 0.5],
                [1.0, 1.0, 1.0],
            )
            for image in tensors
        ]
        images = torch.stack(images).to("cuda")
        # Drop the per-image batch dim added by unsqueeze above.
        images = torch.squeeze(images, 1)
        tup = self.model(images)
        # NOTE(review): tup[0][0] picks the first output head's first stage —
        # model-specific indexing; confirm against the RMBG model's outputs.
        result = tup[0][0]
        # Min-max normalize to [0, 1]. NOTE(review): min/max are taken over
        # the whole batch, not per image — confirm this is intended.
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        # Resize each mask back to its source image's original size and move
        # it off the GPU.
        resized = [
            torch.squeeze(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(image, 0), size=og_size, mode="bilinear"
                ),
                0,
            ).cpu()
            for image, og_size in zip(result, og_sizes)
        ]
        paths = [
            get_output_path(note) for note in notes
        ]
        removed_bg = list(zip(resized, paths))
        # Record each mask's output path on its note, creating the "masks"
        # list when absent.
        for path, note in zip(paths, notes):
            masks = note["masks"]
            if masks is None:
                masks = []
            masks.append({"value": path, "type": "image"})
            note["masks"] = masks
        return removed_bg
def load_image_as_tensor(item: dict) -> torch.Tensor:
    """Decode an item's image payload into a 3-channel CHW tensor.

    Grayscale (1-channel) inputs are tiled to three channels; RGBA
    (4-channel) inputs have their alpha channel dropped. Anything else is
    returned unchanged.
    """
    tensor = F.to_tensor(item["value"])
    channels = tensor.shape[0]
    if channels == 1:
        # Replicate the single grayscale channel into RGB.
        return tensor.repeat(3, 1, 1)
    if channels == 4:
        # Keep RGB, discard alpha.
        return tensor[:3]
    return tensor
def save_image_to_disk(t: torch.Tensor, output_path: str):
    """Save an image tensor to disk, creating parent directories as needed.

    Args:
        t: Image tensor accepted by torchvision's ``to_pil_image``.
        output_path: Destination file path; the image format is inferred
            from the extension by PIL.
    """
    # Renamed from `dir` to avoid shadowing the builtin; guard against a
    # bare filename, where dirname() is "" and os.makedirs("") would raise
    # FileNotFoundError.
    out_dir = osp.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    img = F.to_pil_image(t)
    img.save(output_path)