from graphbook import steps
from transformers import AutoModelForImageSegmentation
import torchvision.transforms.functional as F
import torch.nn.functional
import torch
import os
import os.path as osp


class RemoveBackground(steps.BatchStep):
    """
    Executes background removal on a batch of images using the provided model.

    Args:
        batch_size (int): The batch size (number of images) given to the background removal model at a time
        model (resource): The model to use for background removal. Use the resource `BackgroundRemoval/RMBGModel` to load the model
        output_dir (str): The directory to save the output images to
    """

    RequiresInput = True
    Outputs = ["out"]
    Parameters = {
        "batch_size": {
            "type": "number",
            "description": "The batch size for background removal.",
            "default": 8,
        },
        "model": {
            "type": "resource",
            "description": "The model to use for background removal.",
        },
        "output_dir": {
            "type": "string",
            "description": "The directory to save the output images.",
            "default": "/tmp/output",
        },
    }
    Category = "BackgroundRemoval"

    def __init__(self, batch_size: int, model: AutoModelForImageSegmentation, output_dir: str):
        super().__init__(batch_size, "image")
        self.model = model
        self.output_dir = output_dir
        # Pre-create the output root so per-item dump paths only need
        # their label subdirectory created later.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

    def load_fn(self, item: dict) -> torch.Tensor:
        """Load a single input item into a 3xHxW float tensor (worker side)."""
        return load_image_as_tensor(item)

    def dump_fn(self, t: torch.Tensor, output_path: str):
        """Persist a single output mask tensor to disk (worker side)."""
        save_image_to_disk(t, output_path)

    @torch.no_grad()
    def on_item_batch(self, tensors, items, notes):
        """Run the segmentation model on a batch and schedule mask dumps.

        Args:
            tensors: List of 3xHxW image tensors produced by ``load_fn``.
            items: The raw batch items (unused here, part of the BatchStep
                contract).
            notes: Per-item note dicts; each is expected to carry ``labels``
                (0 => "cat", otherwise "dog"), an ``image`` dict with a
                ``shm_id``, and a ``masks`` entry (may be None).
                NOTE(review): this schema is assumed from usage here —
                confirm against the upstream pipeline.

        Returns:
            List of (mask_tensor, output_path) pairs handed to ``dump_fn``.
        """

        def get_output_path(note):
            # Binary label routing: 0 is "cat", anything else is "dog".
            label = "cat" if note["labels"] == 0 else "dog"
            return osp.join(self.output_dir, label, f"{note['image']['shm_id']}.jpg")

        # Remember original spatial sizes so masks can be resized back.
        og_sizes = [t.shape[1:] for t in tensors]

        # Model expects 1024x1024 inputs normalized with mean 0.5 / std 1.0.
        images = [
            F.normalize(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(image, 0), size=[1024, 1024], mode="bilinear"
                ),
                [0.5, 0.5, 0.5],
                [1.0, 1.0, 1.0],
            )
            for image in tensors
        ]
        images = torch.stack(images).to("cuda")
        images = torch.squeeze(images, 1)  # (B, 1, 3, H, W) -> (B, 3, H, W)

        tup = self.model(images)
        result = tup[0][0]

        # Min-max normalize the raw mask logits into [0, 1].
        ma = torch.max(result)
        mi = torch.min(result)
        scale = ma - mi
        if scale > 0:
            result = (result - mi) / scale
        else:
            # Constant model output: avoid 0/0 NaNs; emit an all-zero mask.
            result = torch.zeros_like(result)

        # Resize each mask back to its source image's size and move to CPU
        # so it can be serialized by the dump workers.
        resized = [
            torch.squeeze(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(image, 0), size=og_size, mode="bilinear"
                ),
                0,
            ).cpu()
            for image, og_size in zip(result, og_sizes)
        ]

        paths = [get_output_path(note) for note in notes]
        removed_bg = list(zip(resized, paths))

        # Record each dumped mask path back onto its note.
        for path, note in zip(paths, notes):
            masks = note["masks"]
            if masks is None:
                masks = []
            masks.append({"value": path, "type": "image"})
            note["masks"] = masks

        return removed_bg


def load_image_as_tensor(item: dict) -> torch.Tensor:
    """Convert an item's PIL image (``item["value"]``) to a 3xHxW tensor.

    Grayscale images are broadcast to 3 channels; RGBA images drop alpha.
    """
    im = item["value"]
    image = F.to_tensor(im)
    if image.shape[0] == 1:
        image = image.repeat(3, 1, 1)
    elif image.shape[0] == 4:
        image = image[:3]
    return image


def save_image_to_disk(t: torch.Tensor, output_path: str):
    """Write tensor ``t`` to ``output_path`` as an image, creating parent dirs."""
    # Renamed from `dir` to avoid shadowing the builtin.
    out_dir = osp.dirname(output_path)
    os.makedirs(out_dir, exist_ok=True)
    img = F.to_pil_image(t)
    img.save(output_path)