File size: 3,374 Bytes
e7f01f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from graphbook import steps
from transformers import AutoModelForImageSegmentation
import torchvision.transforms.functional as F
import torch.nn.functional
import torch
import os
import os.path as osp


class RemoveBackground(steps.BatchStep):
    """Batch step that predicts foreground masks for images with a
    segmentation model.

    Each predicted mask is resized back to its source resolution and paired
    with an output path under ``output_dir``; a reference to that path is
    also appended to the note's "masks" list.
    """

    RequiresInput = True
    Outputs = ["out"]
    Parameters = {
        "batch_size": {
            "type": "number",
            "description": "The batch size for background removal.",
            "default": 8,
        },
        "model": {
            "type": "resource",
            "description": "The model to use for background removal.",
        },
        "output_dir": {
            "type": "string",
            "description": "The directory to save the output images.",
            "default": "/tmp/output",
        },
    }
    Category = "BackgroundRemoval"

    def __init__(self, batch_size: int, model: AutoModelForImageSegmentation, output_dir: str):
        super().__init__(batch_size, "image")
        self.model = model
        self.output_dir = output_dir
        # Eagerly create the root output directory so later writes can't
        # fail on a missing top-level folder.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

    def load_fn(self, item: dict) -> torch.Tensor:
        # Decoding is delegated to the module-level helper.
        return load_image_as_tensor(item)

    def dump_fn(self, t: torch.Tensor, output_path: str):
        # Saving is delegated to the module-level helper.
        save_image_to_disk(t, output_path)

    @torch.no_grad()
    def on_item_batch(self, tensors, items, notes):
        """Run the model over a batch and return (mask_tensor, path) pairs."""

        def destination_for(note):
            # Label 0 is routed to a "cat" subfolder, anything else to "dog";
            # the source image's shm_id names the output file.
            subdir = "cat" if note["labels"] == 0 else "dog"
            return osp.join(self.output_dir, subdir, f"{note['image']['shm_id']}.jpg")

        original_sizes = [t.shape[1:] for t in tensors]

        # Scale every image to the model's 1024x1024 input size, then
        # normalize with mean 0.5 / std 1.0 per channel.
        prepared = []
        for image in tensors:
            upscaled = torch.nn.functional.interpolate(
                torch.unsqueeze(image, 0), size=[1024, 1024], mode="bilinear"
            )
            prepared.append(F.normalize(upscaled, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))
        batch = torch.squeeze(torch.stack(prepared).to("cuda"), 1)

        prediction = self.model(batch)[0][0]
        # Min-max normalize the raw predictions to [0, 1] over the whole batch.
        hi = torch.max(prediction)
        lo = torch.min(prediction)
        prediction = (prediction - lo) / (hi - lo)

        # Restore each mask to its source resolution and move it off the GPU.
        restored = [
            torch.squeeze(
                torch.nn.functional.interpolate(
                    torch.unsqueeze(mask, 0), size=size, mode="bilinear"
                ),
                0,
            ).cpu()
            for mask, size in zip(prediction, original_sizes)
        ]

        output_paths = [destination_for(note) for note in notes]
        for note, path in zip(notes, output_paths):
            # Record the saved mask on the note, creating the list when unset.
            mask_refs = note["masks"]
            if mask_refs is None:
                mask_refs = []
            mask_refs.append({"value": path, "type": "image"})
            note["masks"] = mask_refs

        return list(zip(restored, output_paths))


def load_image_as_tensor(item: dict) -> torch.Tensor:
    """Convert the image stored under ``item["value"]`` to a 3-channel tensor.

    Single-channel (grayscale) inputs are tiled across three channels and
    four-channel (presumably RGBA — confirm against caller) inputs have the
    fourth channel dropped, so downstream code always sees 3xHxW.
    """
    tensor = F.to_tensor(item["value"])
    channels = tensor.shape[0]
    if channels == 1:
        return tensor.repeat(3, 1, 1)
    if channels == 4:
        return tensor[:3]
    return tensor


def save_image_to_disk(t: torch.Tensor, output_path: str):
    """Convert tensor ``t`` to a PIL image and write it to ``output_path``.

    Missing parent directories are created on demand. Fixes two issues in
    the original: the local ``dir`` shadowed the builtin of the same name,
    and ``os.makedirs("")`` raised ``FileNotFoundError`` whenever
    ``output_path`` was a bare filename (``osp.dirname`` returns ``""``).
    """
    parent = osp.dirname(output_path)
    if parent:  # dirname is "" for bare filenames; makedirs("") would raise
        os.makedirs(parent, exist_ok=True)
    img = F.to_pil_image(t)
    img.save(output_path)