File size: 3,793 Bytes
e7f01f9
 
 
 
 
 
 
 
 
 
3ac5852
 
 
 
 
 
 
 
e7f01f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from graphbook import steps
from transformers import AutoModelForImageSegmentation
import torchvision.transforms.functional as F
import torch.nn.functional
import torch
import os
import os.path as osp


class RemoveBackground(steps.BatchStep):
    """
    Executes background removal on a batch of images using the provided model.
    
    Args:
        batch_size (int): The batch size (number of images) given to the background removal model at a time
        model (resource): The model to use for background removal. Use the resource `BackgroundRemoval/RMBGModel` to load the model
        output_dir (str): The directory to save the output images to
    """
    RequiresInput = True
    Outputs = ["out"]
    Parameters = {
        "batch_size": {
            "type": "number",
            "description": "The batch size for background removal.",
            "default": 8,
        },
        "model": {
            "type": "resource",
            "description": "The model to use for background removal.",
        },
        "output_dir": {
            "type": "string",
            "description": "The directory to save the output images.",
            "default": "/tmp/output",
        },
    }
    Category = "BackgroundRemoval"

    def __init__(self, batch_size: int, model: AutoModelForImageSegmentation, output_dir: str):
        super().__init__(batch_size, "image")
        self.model = model
        self.output_dir = output_dir
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

    def load_fn(self, item: dict) -> torch.Tensor:
        return load_image_as_tensor(item)

    def dump_fn(self, t: torch.Tensor, output_path: str):
        save_image_to_disk(t, output_path)

    @torch.no_grad()
    def on_item_batch(self, tensors, items, notes):
        def get_output_path(note):
            label = ""
            if note["labels"] == 0:
                label = "cat"
            else:
                label = "dog"
            return osp.join(self.output_dir, label, f"{note['image']['shm_id']}.jpg")

        og_sizes = [t.shape[1:] for t in tensors]

        images = [
            F.normalize(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(image, 0), size=[1024, 1024], mode="bilinear"
                ),
                [0.5, 0.5, 0.5],
                [1.0, 1.0, 1.0],
            )
            for image in tensors
        ]
        images = torch.stack(images).to("cuda")
        images = torch.squeeze(images, 1)
        tup = self.model(images)
        result = tup[0][0]
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        resized = [
            torch.squeeze(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(image, 0), size=og_size, mode="bilinear"
                ),
                0,
            ).cpu()
            for image, og_size in zip(result, og_sizes)
        ]
        paths = [
            get_output_path(note) for note in notes
        ]
        removed_bg = list(zip(resized, paths))
        for path, note in zip(paths, notes):
            masks = note["masks"]
            if masks is None:
                masks = []
            masks.append({"value": path, "type": "image"})
            note["masks"] = masks

        return removed_bg


def load_image_as_tensor(item: dict) -> torch.Tensor:
    """Convert the image stored under ``item["value"]`` to a 3-channel tensor.

    Grayscale inputs (1 channel) are replicated across three channels;
    RGBA inputs (4 channels) have the alpha channel discarded.

    Args:
        item: An item dict whose ``"value"`` holds the image to convert.

    Returns:
        A CHW tensor with exactly three channels.
    """
    tensor = F.to_tensor(item["value"])
    channels = tensor.shape[0]
    if channels == 1:
        # Grayscale -> broadcast the single channel into R, G, and B.
        tensor = tensor.repeat(3, 1, 1)
    elif channels == 4:
        # RGBA -> keep only the color channels.
        tensor = tensor[:3]
    return tensor


def save_image_to_disk(t: torch.Tensor, output_path: str) -> None:
    """Save an image tensor to *output_path*, creating parent dirs as needed.

    Args:
        t: Image tensor in a layout accepted by
            ``torchvision.transforms.functional.to_pil_image`` (e.g. CHW).
        output_path: Destination file path, including name and extension.
    """
    # Renamed from `dir` to avoid shadowing the builtin; guard the empty
    # string case — os.makedirs("") raises FileNotFoundError when the path
    # has no directory component.
    parent_dir = osp.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    img = F.to_pil_image(t)
    img.save(output_path)