File size: 5,499 Bytes
e6d4b46
 
 
 
 
 
 
 
 
 
 
099ac14
e6d4b46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics.functional.classification import binary_average_precision
from tqdm import tqdm

from constants import *
from DenseAV.denseav.shared import unnorm, remove_axes


def prep_heatmap(sims, masks, h, w):
    """Pool similarities over masked time steps and resize to ``(h, w)``.

    Args:
        sims: Similarity tensor indexed as ``(b, h, w, t)`` (per the einsum
            subscripts below).
        masks: ``(b, t)`` tensor selecting which time steps contribute; it is
            cast to float32 and used as averaging weights.
        h: Output heatmap height.
        w: Output heatmap width.

    Returns:
        A ``(b, h, w)`` tensor of bilinearly resized heatmaps.

    Note:
        The min/max rescaling uses the extrema of the *whole batch*, not of
        each sample. No guard exists for an all-zero mask row or a constant
        heatmap; both lead to a division by zero (NaN/inf values).
    """
    weights = masks.to(torch.float32)
    # Number of active time steps per sample, shaped to broadcast over (h, w).
    active_steps = weights.sum(-1).reshape(-1, 1, 1)
    pooled = torch.einsum("bhwt,bt->bhw", sims, weights) / active_steps

    # Rescale so the smallest value in the batch is 0 and the largest is 1.
    shifted = pooled - pooled.min()
    scaled = shifted / shifted.max()

    # interpolate expects a channel axis, so add one and drop it afterwards.
    resized = F.interpolate(scaled.unsqueeze(1), (h, w), mode="bilinear")
    return resized.squeeze(1)


def iou(prediction, target):
    """Return the intersection-over-union of two binarised tensors.

    ``prediction`` counts as positive where it is strictly greater than 0.0
    and ``target`` where it is strictly greater than 0.5.

    Returns:
        A Python float. When neither tensor has any positive element the
        union is empty and the score is defined as 1.0.
    """
    pred_mask = prediction > 0.0
    true_mask = target > 0.5

    union = (pred_mask | true_mask).sum().float()
    if union == 0:
        # Both masks are empty: treat as a perfect match.
        return 1.0

    overlap = (pred_mask & true_mask).sum().float()
    return (overlap / union).item()  # plain Python scalar, not a tensor


def multi_iou(prediction, target, k=20):
    """Return the best IoU over ``k`` evenly spaced binarisation thresholds.

    The prediction is thresholded (strictly greater than) at ``k`` values
    spread linearly between its minimum and maximum, each hard prediction is
    scored against ``target > 0.5``, and the highest IoU is returned.

    Args:
        prediction: Scores as a tensor or anything ``torch.as_tensor``
            accepts. Any number of dimensions is supported.
        target: Ground truth broadcastable to ``prediction``'s shape; values
            above 0.5 are treated as positive.
        k: Number of thresholds to try.

    Returns:
        A 0-dim float tensor holding the best IoU. A threshold whose union is
        empty scores 0.

    Raises:
        RuntimeError: If ``prediction`` is empty (its min/max are undefined).
    """
    # as_tensor reuses an existing tensor instead of copying it, which also
    # avoids the UserWarning torch.tensor() emits when handed a tensor.
    prediction = torch.as_tensor(prediction)
    target = torch.as_tensor(target, device=prediction.device) > 0.5

    # Work on flat vectors so inputs of any rank are handled the same way.
    scores = prediction.reshape(-1)
    truth = torch.broadcast_to(target, prediction.shape).reshape(-1).unsqueeze(0)

    # Create the thresholds on the prediction's device so the comparison
    # below also works for non-CPU inputs.
    thresholds = torch.linspace(
        scores.min().item(), scores.max().item(), k, device=scores.device
    )
    hard_pred = scores.unsqueeze(0) > thresholds.unsqueeze(1)  # (k, N)

    intersection = (hard_pred & truth).sum(dim=1).float()
    # Counts are non-negative integers, so clamping to 1 only changes the
    # empty-union case and prevents a division by zero there.
    union = (hard_pred | truth).sum(dim=1).float().clamp(min=1.0)

    return (intersection / union).max()


def _timing_to_masks(timing, total_seconds, reference_masks, pad_seconds=.2):
    """Convert JSON ``[start, end]`` pairs (seconds) into per-frame masks."""
    # Force a float dtype: all-integer JSON would otherwise yield an int64
    # tensor and the in-place padding below would raise.
    spans = torch.tensor([json.loads(t) for t in timing], dtype=torch.float32)
    spans[:, 0] -= pad_seconds
    spans[:, 1] += pad_seconds

    n_frames = reference_masks.shape[1]
    bounds = n_frames * (spans / total_seconds)
    first = bounds[:, 0].floor().to(torch.int64).unsqueeze(1)
    stop = bounds[:, 1].ceil().to(torch.int64).unsqueeze(1)

    # Select frames in [first, stop). Comparing against an index row, rather
    # than going through one_hot, tolerates padded windows that fall before
    # frame 0 or past the last frame.
    frame_idx = torch.arange(n_frames, device=bounds.device).unsqueeze(0)
    window = (frame_idx >= first) & (frame_idx < stop)
    return window.to(reference_masks)


def _plot_class_examples(n_imgs, class_name, frames, basic, advanced, gt_masks, semseg):
    """Show a grid with one row per example and one column per visualisation."""
    columns = [
        ("Image", frames, {}),
        (f"{class_name} Basic Heatmap", basic, {}),
    ]
    if advanced is not None:
        columns.append((f"{class_name} Advanced Heatmap", advanced, {}))
    columns.append(("True Mask", gt_masks, {}))
    columns.append(("True Seg", semseg, {"cmap": "tab20", "interpolation": 'none'}))

    n_cols = len(columns)
    _, axes = plt.subplots(n_imgs, n_cols, figsize=(4 * n_cols, n_imgs * 3))
    for col, (title, images, imshow_kwargs) in enumerate(columns):
        axes[0, col].set_title(title)
        for row in range(n_imgs):
            axes[row, col].imshow(images[row], **imshow_kwargs)

    remove_axes(axes)
    plt.tight_layout()
    plt.show()


def get_paired_heatmaps(
        model,
        results,
        class_ids,
        timing,
        class_names=None):
    """Score audio-prompted heatmaps against semantic-segmentation masks.

    For every sample a "basic" heatmap is built by pooling the model's
    similarities over ``results[AUDIO_MASK]``. When ``timing`` is given, an
    "advanced" heatmap is also built by pooling only over the frames inside
    the (padded) time window of that sample. Average precision and best-
    threshold IoU are then computed per unique class.

    Args:
        model: Object exposing ``sim_agg.get_pairwise_sims``.
        results: Mapping providing ``"semseg"``, ``AUDIO_MASK`` and, when used,
            ``"total_length"`` and ``IMAGE_INPUT``.
        class_ids: Iterable with one integer class id per sample.
        timing: ``None``, or an iterable with one JSON string per sample that
            decodes to ``[start, end]`` in seconds.
        class_names: Optional iterable with one name per sample. When given,
            example grids are plotted for every class with at least two
            samples (at most five rows each).

    Returns:
        ``(metrics, unique_classes)``. ``metrics`` maps ``"basic_ap"`` and
        ``"basic_iou"`` (plus ``"advanced_ap"`` and ``"advanced_iou"`` when
        ``timing`` is given) to lists with one entry per element of
        ``unique_classes``.
    """
    sims = model.sim_agg.get_pairwise_sims(
        results,
        raw=False,
        agg_sim=False,
        agg_heads=True
    ).squeeze(1).mean(-2)

    prompt_classes = torch.tensor(list(class_ids))
    # Binary ground truth: pixels whose label equals the prompted class.
    gt = results["semseg"] == prompt_classes.reshape(-1, 1, 1)
    basic_masks = results[AUDIO_MASK]  # BxT
    _, fullh, fullw = gt.shape
    basic_heatmaps = prep_heatmap(sims, basic_masks, fullh, fullw)

    advanced_heatmaps = None
    if timing is not None:
        # NOTE(review): assumes 16 kHz audio and that every clip shares the
        # length of the first one - confirm against the data loader.
        total_length = (results['total_length'] / 16000)[0]
        advanced_masks = _timing_to_masks(timing, total_length, basic_masks)
        advanced_heatmaps = prep_heatmap(sims, advanced_masks, fullh, fullw)

    metrics = defaultdict(list)
    unique_classes = torch.unique(prompt_classes)

    should_plot = class_names is not None
    if should_plot:
        prompt_names = np.array(list(class_names))

    for prompt_class in tqdm(unique_classes):
        subset = torch.where(prompt_classes == prompt_class)[0]
        gt_subset = gt[subset]
        basic_subset = basic_heatmaps[subset]
        metrics["basic_ap"].append(binary_average_precision(basic_subset.flatten(), gt_subset.flatten()))
        metrics["basic_iou"].append(multi_iou(basic_subset.flatten(), gt_subset.flatten()))

        # Stays None without timing so plotting can simply skip that column
        # instead of reading an undefined variable.
        advanced_subset = None
        if advanced_heatmaps is not None:
            advanced_subset = advanced_heatmaps[subset]
            metrics["advanced_ap"].append(binary_average_precision(advanced_subset.flatten(), gt_subset.flatten()))
            metrics["advanced_iou"].append(multi_iou(advanced_subset.flatten(), gt_subset.flatten()))

        if not should_plot:
            continue

        prompt_class_subset = prompt_classes[subset]
        name_subset = prompt_names[subset]
        print(prompt_class, name_subset, prompt_class_subset)
        n_imgs = min(len(subset), 5)
        if n_imgs > 1:
            frame_subset = unnorm(results[IMAGE_INPUT][subset].squeeze(1)).permute(0, 2, 3, 1)
            semseg_subset = results["semseg"][subset]
            class_name = name_subset[0].split(",")[0]
            _plot_class_examples(
                n_imgs,
                class_name,
                frame_subset,
                basic_subset,
                advanced_subset,
                gt_subset,
                semseg_subset,
            )

    return metrics, unique_classes