File size: 13,972 Bytes
cdd6c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
"""Core part of LaDeco v2



Example usage:

>>> from core import Ladeco

>>> from PIL import Image

>>> from pathlib import Path

>>>

>>> # predict

>>> ldc = Ladeco()

>>> imgs = (thing for thing in Path("example").glob("*.jpg"))

>>> out = ldc.predict(imgs)

>>>

>>> # output - visualization

>>> segs = out.visualize(level=2)

>>> segs[0].image.show()

>>>

>>> # output - element area

>>> area = out.area()

>>> area[0]

{"fid": "example/.jpg", "l1_nature": 0.673, "l1_man_made": 0.241, ...}

"""
from matplotlib.patches import Rectangle
from pathlib import Path
from PIL import Image
from transformers import AutoModelForUniversalSegmentation, AutoProcessor
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from functools import lru_cache
from matplotlib.figure import Figure
import numpy.typing as npt
from typing import Iterable, NamedTuple, Generator
from tqdm import tqdm


class LadecoVisualization(NamedTuple):
    """One rendered segmentation figure paired with the image it belongs to."""

    # Path of the source image, stored as a string.
    filename: str
    # Matplotlib figure showing the colour-coded segmentation of that image.
    image: Figure


class Ladeco:
    """Landscape-element segmenter built on a pretrained universal segmentation model.

    The model's label table (``config.label2id``) is read once at construction
    and translated into the LaDeco label hierarchy by ``_get_ladeco_labels``.
    """

    def __init__(
        self,
        model_name: str = "shi-labs/oneformer_ade20k_swin_large",
        area_threshold: float = 0.01,
        device: str | None = None,
    ):
        """Load the processor and model and build the LaDeco label table.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
            area_threshold: Area ratios below this value are reported as 0
                by ``LadecoOutput.iarea``.
            device: Torch device string. When ``None``, "cuda" is used if
                available, otherwise "cpu".
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForUniversalSegmentation.from_pretrained(model_name).to(self.device)

        self.area_threshold = area_threshold

        # Label names are stripped because the lookup in _get_ladeco_labels
        # is by exact name.
        self.ade20k_labels = {
            name.strip(): int(idx)
            for name, idx in self.model.config.label2id.items()
        }
        self.ladeco2ade20k: dict[str, tuple[int, ...]] = _get_ladeco_labels(self.ade20k_labels)

    def predict(
        self,
        image_paths: str | Path | Iterable[str | Path],
        show_progress: bool = False,
    ) -> "LadecoOutput":
        """Run semantic segmentation on one image or on several.

        Args:
            image_paths: A single path, or an iterable of paths (it is
                consumed once and materialised into a list).
            show_progress: Show a tqdm progress bar while segmenting.

        Returns:
            A ``LadecoOutput`` holding the paths, one CPU mask per image, the
            LaDeco label table and the area threshold.
        """
        if isinstance(image_paths, (str, Path)):
            imgpaths = [image_paths]
        else:
            imgpaths = list(image_paths)

        # batch inference functionality of OneFormer is broken,
        # so images are pushed through the model one at a time
        masks: list[torch.Tensor] = []
        for img_path in tqdm(imgpaths, desc="Segmenting", disable=not show_progress):
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed right away instead of leaking a
            # handle per input.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")

            samples = self.processor(
                images=img, task_inputs=["semantic"], return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**samples)

            mask = self.processor.post_process_semantic_segmentation(outputs)[0]
            # Keep results on the CPU: downstream code combines the mask with
            # CPU tensors, and accelerator memory is not held for the batch.
            masks.append(mask.cpu())

        return LadecoOutput(imgpaths, masks, self.ladeco2ade20k, self.area_threshold)


class LadecoOutput:
    """Per-image segmentation masks plus LaDeco-level reporting helpers.

    Every mask is compared against the class indices listed in
    ``ladeco2ade``; label names carry their hierarchy level as a prefix
    (``l1_`` ... ``l4_``), which is how a level is selected.
    """

    def __init__(
        self,
        filenames: list[str | Path],
        masks: list[torch.Tensor],
        ladeco2ade: dict[str, tuple[int, ...]],
        threshold: float,
    ):
        """Store the results of one prediction run.

        Args:
            filenames: Source image paths, parallel to ``masks``.
            masks: One integer class-index mask per image.
            ladeco2ade: LaDeco label name -> class indices it is made of.
            threshold: Area ratios below this value are reported as 0.
        """
        self.filenames = filenames
        self.masks = masks
        self.ladeco2ade: dict[str, tuple[int, ...]] = ladeco2ade
        # Reverse lookup. An index listed under several labels keeps the
        # label that appears last in ladeco2ade.
        self.ade2ladeco: dict[int, str] = {
            idx: label
            for label, indices in self.ladeco2ade.items()
            for idx in indices
        }
        self.threshold = threshold

    def visualize(self, level: int) -> list[LadecoVisualization]:
        """Return every visualization of ``ivisualize(level)`` as a list."""
        return list(self.ivisualize(level))

    def ivisualize(self, level: int) -> Generator[LadecoVisualization, None, None]:
        """Lazily render one colour-coded segmentation figure per image.

        Pixels whose class is not part of any label of ``level`` stay black.
        Each figure is created through pyplot and is not closed here.

        Args:
            level: LaDeco hierarchy level to colour by.

        Raises:
            RuntimeError: If no label exists for ``level`` (from ``color_map``).
        """
        # Scale every label colour to 8-bit once, not once per class index.
        palette = {
            name: torch.tensor(rgb * 255, dtype=torch.uint8)
            for name, rgb in self.color_map(level).items()
        }

        for fname, mask in zip(self.filenames, self.masks):
            # The canvas lives on the CPU; a mask left on another device
            # cannot be used to index it.
            mask = mask.cpu()
            vis = torch.zeros(tuple(mask.shape) + (3,), dtype=torch.uint8)  # (H, W, RGB)
            for name, color in palette.items():
                for idx in self.ladeco2ade[name]:
                    vis[mask == idx] = color

            with Image.open(fname) as img:
                target_size = img.size
            # Nearest-neighbour keeps label colours exact. A smoothing filter
            # would blend neighbouring regions into in-between colours that
            # may coincide with other entries of the legend.
            vis = Image.fromarray(vis.numpy(), mode="RGB").resize(
                target_size, resample=Image.NEAREST
            )

            fig, ax = plt.subplots()
            ax.imshow(vis)
            ax.axis('off')

            yield LadecoVisualization(filename=str(fname), image=fig)

    def area(self) -> list[dict[str, float | str]]:
        """Return every record of ``iarea()`` as a list."""
        return list(self.iarea())

    def iarea(self) -> Generator[dict[str, float | str], None, None]:
        """Lazily compute the area share of every LaDeco label per image.

        Each record holds ``"fid"`` (the file name), one ratio per LaDeco
        label (rounded to 4 decimals; values below the threshold become 0),
        ``"others"`` (the share not covered by ``l1_nature`` or
        ``l1_man_made`` after thresholding) and ``"LC_NFI"``
        (nature / (nature + man_made), with a small epsilon in the
        denominator to avoid dividing by zero).

        Raises:
            KeyError: If ``l1_nature`` or ``l1_man_made`` is not a known label.
        """
        # Size the histogram from the mapping instead of assuming a fixed
        # number of classes, so any index the mapping refers to is valid.
        n_bins = max(self.ade2ladeco, default=-1) + 1

        for filename, mask in zip(self.filenames, self.masks):
            total = mask.numel()
            # A single pass over the mask yields the pixel count of every class.
            flat = mask.detach().reshape(-1).to(device="cpu", dtype=torch.int64)
            counts = torch.bincount(flat, minlength=n_bins)

            ldc_ratios: dict[str, float] = {}
            for label, ade_indices in self.ladeco2ade.items():
                # Sum integer counts first so the ratio is rounded only once.
                ratio = round(counts[list(ade_indices)].sum().item() / total, 4)
                ldc_ratios[label] = 0 if ratio < self.threshold else ratio

            nature = ldc_ratios["l1_nature"]
            man_made = ldc_ratios["l1_man_made"]
            others = round(1 - nature - man_made, 4)
            nfi = round(nature / (nature + man_made + 1e-6), 4)

            yield {
                "fid": str(filename), **ldc_ratios, "others": others, "LC_NFI": nfi,
            }

    def color_map(self, level: int) -> dict[str, npt.NDArray[np.float64]]:
        """Return ``{'label_name': (R, G, B), ...}`` with channels in [0, 1].

        Colours are sampled from viridis, one per label of ``level``, in the
        order the labels appear in ``ladeco2ade``.

        Raises:
            RuntimeError: If no label name starts with ``l{level}``.
        """
        labels = [
            name for name in self.ladeco2ade if name.startswith(f"l{level}")
        ]
        if not labels:
            raise RuntimeError(
                f"LaDeco only has 4 levels in 1, 2, 3, 4. You assigned {level}."
            )
        # [:, :-1]: discard alpha channel
        colormap = mpl.colormaps["viridis"].resampled(len(labels)).colors[:, :-1]
        return dict(zip(labels, colormap))

    def color_legend(self, level: int) -> Figure:
        """Draw a legend figure mapping each label of ``level`` to its colour.

        Entries fill the grid column by column.

        Raises:
            RuntimeError: If no label exists for ``level`` (from ``color_map``).
        """
        colors = self.color_map(level)

        # Number of legend columns per level; anything else uses one column.
        ncols = {1: 1, 2: 1, 3: 2, 4: 5}.get(level, 1)

        # Layout constants, in pixels at the dpi set below.
        cell_width = 212
        cell_height = 22
        swatch_width = 48
        margin = 12

        nrows = math.ceil(len(colors) / ncols)

        width = cell_width * ncols + 2 * margin
        height = cell_height * nrows + 2 * margin
        dpi = 72

        fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
        fig.subplots_adjust(margin/width, margin/height,
                            (width-margin)/width, (height-margin*2)/height)
        ax.set_xlim(0, cell_width * ncols)
        # Inverted y-axis so that row 0 is drawn at the top.
        ax.set_ylim(cell_height * (nrows-0.5), -cell_height/2.)
        ax.yaxis.set_visible(False)
        ax.xaxis.set_visible(False)
        ax.set_axis_off()
        # Set once; it does not depend on the individual entries.
        ax.set_title(f"LaDeco Color Legend - Level {level}")

        for i, name in enumerate(colors):
            row = i % nrows
            col = i // nrows
            y = row * cell_height

            swatch_start_x = cell_width * col
            text_pos_x = cell_width * col + swatch_width + 7

            ax.text(text_pos_x, y, name, fontsize=14,
                    horizontalalignment='left',
                    verticalalignment='center')

            ax.add_patch(
                Rectangle(xy=(swatch_start_x, y-9), width=swatch_width,
                          height=18, facecolor=colors[name], edgecolor='0.7')
            )

        return fig
    

def _get_ladeco_labels(ade20k: dict[str, int]) -> dict[str, tuple[int]]:
    labels =  {
        # level 4 labels
        # under l3_architecture
        "l4_hovel": (ade20k["hovel, hut, hutch, shack, shanty"],),
        "l4_building": (ade20k["building"], ade20k["house"]),
        "l4_skyscraper": (ade20k["skyscraper"],),
        "l4_tower": (ade20k["tower"],),
        # under l3_archi_parts
        "l4_step": (ade20k["step, stair"],),
        "l4_canopy": (ade20k["awning, sunshade, sunblind"], ade20k["canopy"]),
        "l4_arcade": (ade20k["arcade machine"],),
        "l4_door": (ade20k["door"],),
        "l4_window": (ade20k["window"],),
        "l4_wall": (ade20k["wall"],),
        # under l3_roadway
        "l4_stairway": (ade20k["stairway, staircase"],),
        "l4_sidewalk": (ade20k["sidewalk, pavement"],),
        "l4_road": (ade20k["road, route"],),
        # under l3_furniture
        "l4_sculpture": (ade20k["sculpture"],),
        "l4_flag": (ade20k["flag"],),
        "l4_can": (ade20k["trash can"],),
        "l4_chair": (ade20k["chair"],),
        "l4_pot": (ade20k["pot"],),
        "l4_booth": (ade20k["booth"],),
        "l4_streetlight": (ade20k["street lamp"],),
        "l4_bench": (ade20k["bench"],),
        "l4_fence": (ade20k["fence"],),
        "l4_table": (ade20k["table"],),
        # under l3_vehicle
        "l4_bike": (ade20k["bicycle"],),
        "l4_motorbike": (ade20k["minibike, motorbike"],),
        "l4_van": (ade20k["van"],),
        "l4_truck": (ade20k["truck"],),
        "l4_bus": (ade20k["bus"],),
        "l4_car": (ade20k["car"],),
        # under l3_sign
        "l4_traffic_sign": (ade20k["traffic light"],),
        "l4_poster": (ade20k["poster, posting, placard, notice, bill, card"],),
        "l4_signboard": (ade20k["signboard, sign"],),
        # under l3_vert_land
        "l4_rock": (ade20k["rock, stone"],),
        "l4_hill": (ade20k["hill"],),
        "l4_mountain": (ade20k["mountain, mount"],),
        # under l3_hori_land
        "l4_ground": (ade20k["earth, ground"], ade20k["land, ground, soil"]),
        "l4_field": (ade20k["field"],),
        "l4_sand": (ade20k["sand"],),
        "l4_dirt": (ade20k["dirt track"],),
        "l4_path": (ade20k["path"],),
        # under l3_flower
        "l4_flower": (ade20k["flower"],),
        # under l3_grass
        "l4_grass": (ade20k["grass"],),
        # under l3_shrub
        "l4_flora": (ade20k["plant"],),
        # under l3_arbor
        "l4_tree": (ade20k["tree"],),
        "l4_palm": (ade20k["palm, palm tree"],),
        # under l3_hori_water
        "l4_lake": (ade20k["lake"],),
        "l4_pool": (ade20k["pool"],),
        "l4_river": (ade20k["river"],),
        "l4_sea": (ade20k["sea"],),
        "l4_water": (ade20k["water"],),
        # under l3_vert_water
        "l4_fountain": (ade20k["fountain"],),
        "l4_waterfall": (ade20k["falls"],),
        # under l3_human
        "l4_person": (ade20k["person"],),
        # under l3_animal
        "l4_animal": (ade20k["animal"],),
        # under l3_sky
        "l4_sky": (ade20k["sky"],),
    }
    labels = labels | {
        # level 3 labels
        # under l2_landform
        "l3_hori_land": labels["l4_ground"] + labels["l4_field"] + labels["l4_sand"] + labels["l4_dirt"] + labels["l4_path"],
        "l3_vert_land": labels["l4_mountain"] + labels["l4_hill"] + labels["l4_rock"],
        # under l2_vegetation
        "l3_woody_plant": labels["l4_tree"] + labels["l4_palm"] + labels["l4_flora"],
        "l3_herb_plant": labels["l4_grass"],
        "l3_flower": labels["l4_flower"],
        # under l2_water
        "l3_hori_water": labels["l4_water"] + labels["l4_sea"] + labels["l4_river"] + labels["l4_pool"] + labels["l4_lake"],
        "l3_vert_water": labels["l4_fountain"] + labels["l4_waterfall"],
        # under l2_bio
        "l3_human": labels["l4_person"],
        "l3_animal": labels["l4_animal"],
        # under l2_sky
        "l3_sky": labels["l4_sky"],
        # under l2_archi
        "l3_architecture": labels["l4_building"] + labels["l4_hovel"] + labels["l4_tower"] + labels["l4_skyscraper"],
        "l3_archi_parts": labels["l4_wall"] + labels["l4_window"] + labels["l4_door"] + labels["l4_arcade"] + labels["l4_canopy"] + labels["l4_step"],
        # under l2_street
        "l3_roadway": labels["l4_road"] + labels["l4_sidewalk"] + labels["l4_stairway"],
        "l3_furniture": labels["l4_table"] + labels["l4_chair"] + labels["l4_fence"] + labels["l4_bench"] + labels["l4_streetlight"] + labels["l4_booth"] + labels["l4_pot"] + labels["l4_can"] + labels["l4_flag"] + labels["l4_sculpture"],
        "l3_vehicle": labels["l4_car"] + labels["l4_bus"] + labels["l4_truck"] + labels["l4_van"] + labels["l4_motorbike"] + labels["l4_bike"],
        "l3_sign": labels["l4_signboard"] + labels["l4_poster"] + labels["l4_traffic_sign"],
    }
    labels = labels | {
        # level 2 labels
        # under l1_nature
        "l2_landform": labels["l3_hori_land"] + labels["l3_vert_land"],
        "l2_vegetation": labels["l3_woody_plant"] + labels["l3_herb_plant"] + labels["l3_flower"],
        "l2_water": labels["l3_hori_water"] + labels["l3_vert_water"],
        "l2_bio": labels["l3_human"] + labels["l3_animal"],
        "l2_sky": labels["l3_sky"],
        # under l1_man_made
        "l2_archi": labels["l3_architecture"] + labels["l3_archi_parts"],
        "l2_street": labels["l3_roadway"] + labels["l3_furniture"] + labels["l3_vehicle"] + labels["l3_sign"],
    }
    labels = labels | {
        # level 1 labels
        "l1_nature": labels["l2_landform"] + labels["l2_vegetation"] + labels["l2_water"] + labels["l2_bio"] + labels["l2_sky"],
        "l1_man_made": labels["l2_archi"] + labels["l2_street"],
    }
    return labels


if __name__ == "__main__":
    # Manual smoke run: segment a single sample image with the default model.
    image = Path("images", "canyon_3011_00002354.jpg")
    ldc = Ladeco()
    out = ldc.predict(image)