Spaces:
Running
Running
File size: 5,499 Bytes
e6d4b46 099ac14 e6d4b46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics.functional.classification import binary_average_precision
from tqdm import tqdm
from constants import *
from DenseAV.denseav.shared import unnorm, remove_axes
def prep_heatmap(sims, masks, h, w):
    """Collapse per-timestep similarity maps into one normalised heatmap per sample.

    Args:
        sims: Similarity tensor indexed as ``(b, h, w, t)`` (the layout fixed
            by the einsum below).
        masks: ``(b, t)`` weights selecting which timesteps contribute; cast
            to float32 before use.
        h: Output height passed to the bilinear resize.
        w: Output width passed to the bilinear resize.

    Returns:
        Tensor of shape ``(b, h, w)``: the mask-weighted mean of ``sims`` over
        ``t``, shifted and scaled with the min/max of the *whole batch* (not
        per sample), then bilinearly resized.
    """
    masks = masks.to(torch.float32)

    # A row whose mask is entirely zero would give 0/0 = NaN, and because the
    # min/max below span the whole batch that single NaN would turn every
    # heatmap into NaN. Dividing such rows by 1 leaves them at zero instead.
    weight = masks.sum(-1).reshape(-1, 1, 1)
    weight = torch.where(weight == 0, torch.ones_like(weight), weight)

    hm = torch.einsum("bhwt,bt->bhw", sims, masks) / weight
    hm = hm - hm.min()

    # A constant heatmap has max == 0 after the shift; skip the rescale rather
    # than divide by zero.
    peak = hm.max()
    hm = hm / torch.where(peak > 0, peak, torch.ones_like(peak))

    return F.interpolate(hm.unsqueeze(1), (h, w), mode="bilinear").squeeze(1)
def iou(prediction, target):
    """Intersection-over-union of two binarised tensors, as a Python float.

    ``prediction`` counts as positive where it is strictly above 0.0 and
    ``target`` where it is strictly above 0.5. When neither has any positive
    element the union is empty and the score is defined as 1.0.
    """
    pred_mask = prediction > 0.0
    true_mask = target > 0.5

    n_union = (pred_mask | true_mask).sum().float()
    if n_union == 0:
        # Both masks empty: treat as a perfect match.
        return 1.0

    n_overlap = (pred_mask & true_mask).sum().float()
    return (n_overlap / n_union).item()  # plain Python scalar
def multi_iou(prediction, target, k=20):
    """Return the best IoU obtained over ``k`` binarisation thresholds.

    ``prediction`` is thresholded at ``k`` evenly spaced values between its
    own minimum and maximum (strict ``>``), ``target`` is binarised at 0.5,
    and the IoU is computed for every threshold over all elements.

    Args:
        prediction: Scores as a tensor or anything ``torch.as_tensor``
            accepts. Any rank is supported; must be non-empty.
        target: Ground truth, broadcastable to ``prediction``'s shape.
        k: Number of thresholds to try.

    Returns:
        0-dim tensor holding the highest IoU across the thresholds. A
        threshold whose union is empty scores 0.
    """
    # as_tensor reuses an existing tensor rather than copy-constructing it
    # (torch.tensor(tensor) copies and emits a UserWarning on every call).
    prediction = torch.as_tensor(prediction)
    target = torch.as_tensor(target, device=prediction.device) > 0.5

    # Flatten so the computation does not depend on the input rank.
    flat_target = torch.broadcast_to(target, prediction.shape).reshape(1, -1)
    flat_pred = prediction.reshape(1, -1)

    # Build the thresholds on the same device as the scores they are compared to.
    thresholds = torch.linspace(
        prediction.min().item(),
        prediction.max().item(),
        k,
        device=prediction.device,
    )
    hard_pred = flat_pred > thresholds.reshape(k, 1)

    # IoU for each threshold
    intersection = torch.logical_and(hard_pred, flat_target).sum(dim=1).float()
    union = torch.logical_or(hard_pred, flat_target).sum(dim=1).float()
    # Counts are whole numbers, so clamping to 1 only touches empty unions
    # and avoids a division by zero there.
    union = union.clamp(min=1.0)

    return (intersection / union).max()
def _timing_window_masks(timing, total_seconds, like, pad=.2):
    """Convert JSON ``[start, end]`` second pairs into per-frame window masks."""
    n_frames = like.shape[1]

    # Force float32 so timings that parse as integers can still be padded
    # in place (an int64 tensor rejects ``-= .2``).
    windows = torch.tensor([json.loads(t) for t in timing], dtype=torch.float32)
    windows[:, 0] -= pad
    windows[:, 1] += pad

    bounds = n_frames * (windows / total_seconds)
    # Padding can push a window before frame 0 or past the last frame; clamp
    # so that such windows are truncated rather than rejected.
    start = bounds[:, 0].floor().clamp(0, n_frames).to(torch.int64)
    end = bounds[:, 1].ceil().clamp(0, n_frames).to(torch.int64)

    # Frame i is selected when start <= i < end.
    frame_idx = torch.arange(n_frames, device=start.device)
    inside = (frame_idx >= start.unsqueeze(-1)) & (frame_idx < end.unsqueeze(-1))
    return inside.to(like)


def get_paired_heatmaps(
        model,
        results,
        class_ids,
        timing,
        class_names=None):
    """Score (and optionally plot) audio-prompted heatmaps against semantic masks.

    For every sample a "basic" heatmap is built from ``results[AUDIO_MASK]``.
    When ``timing`` is given, an "advanced" heatmap is also built from a mask
    restricted to the prompt's time window, padded by 0.2 s on each side.
    Average precision and best-threshold IoU are then computed per unique
    prompt class.

    Args:
        model: Object exposing ``sim_agg.get_pairwise_sims``.
        results: Mapping read at ``"semseg"``, ``AUDIO_MASK``,
            ``"total_length"`` (only when ``timing`` is given) and
            ``IMAGE_INPUT`` (only when plotting).
        class_ids: Iterable with one prompt class id per sample.
        timing: Iterable of JSON strings, each decoding to ``[start, end]``,
            or ``None`` to skip the advanced heatmaps.
            NOTE(review): values are treated as seconds, and ``total_length``
            is divided by 16000 (presumably the audio sample rate) -- confirm.
        class_names: Optional iterable of per-sample class names. Supplying
            it enables printing and plotting.

    Returns:
        ``(metrics, unique_classes)``. ``metrics`` is a ``defaultdict(list)``
        with keys ``"basic_ap"`` and ``"basic_iou"`` (plus ``"advanced_ap"``
        and ``"advanced_iou"`` when ``timing`` is given), holding one entry
        per element of ``unique_classes`` in the same order.
    """
    sims = model.sim_agg.get_pairwise_sims(
        results,
        raw=False,
        agg_sim=False,
        agg_heads=True
    ).squeeze(1).mean(-2)

    prompt_classes = torch.tensor(list(class_ids))
    # Per-sample binary ground truth: pixels whose label equals the prompt class.
    gt = results["semseg"] == prompt_classes.reshape(-1, 1, 1)
    basic_masks = results[AUDIO_MASK]  # BxT
    _, fullh, fullw = gt.shape
    basic_heatmaps = prep_heatmap(sims, basic_masks, fullh, fullw)

    has_timing = timing is not None
    advanced_heatmaps = None
    if has_timing:
        # Only the first sample's length is used, as the duration of every clip.
        total_length = (results['total_length'] / 16000)[0]
        advanced_masks = _timing_window_masks(list(timing), total_length, basic_masks)
        advanced_heatmaps = prep_heatmap(sims, advanced_masks, fullh, fullw)

    metrics = defaultdict(list)
    unique_classes = torch.unique(prompt_classes)

    should_plot = class_names is not None
    if should_plot:
        prompt_names = np.array(list(class_names))

    for prompt_class in tqdm(unique_classes):
        subset = torch.where(prompt_classes == prompt_class)[0]
        gt_subset = gt[subset]
        gt_flat = gt_subset.flatten()

        basic_subset = basic_heatmaps[subset]
        metrics["basic_ap"].append(binary_average_precision(basic_subset.flatten(), gt_flat))
        metrics["basic_iou"].append(multi_iou(basic_subset.flatten(), gt_flat))

        # Stays None without timing so the plotting code can tell there is
        # nothing to draw instead of hitting an unbound local.
        advanced_subset = None
        if has_timing:
            advanced_subset = advanced_heatmaps[subset]
            metrics["advanced_ap"].append(binary_average_precision(advanced_subset.flatten(), gt_flat))
            metrics["advanced_iou"].append(multi_iou(advanced_subset.flatten(), gt_flat))

        if not should_plot:
            continue

        prompt_class_subset = prompt_classes[subset]
        name_subset = prompt_names[subset]
        print(prompt_class, name_subset, prompt_class_subset)

        n_imgs = min(len(subset), 5)
        # Classes with a single sample are not plotted (the 2-D ``axes``
        # indexing below needs at least two rows).
        if n_imgs <= 1:
            continue

        _, axes = plt.subplots(n_imgs, 5, figsize=(4 * 5, n_imgs * 3))
        frame_subset = unnorm(results[IMAGE_INPUT][subset].squeeze(1)).permute(0, 2, 3, 1)
        semseg_subset = results["semseg"][subset]

        for img_num in range(n_imgs):
            axes[img_num, 0].imshow(frame_subset[img_num])
            axes[img_num, 1].imshow(basic_subset[img_num])
            if advanced_subset is not None:
                axes[img_num, 2].imshow(advanced_subset[img_num])
            axes[img_num, 3].imshow(gt_subset[img_num])
            axes[img_num, 4].imshow(semseg_subset[img_num], cmap="tab20", interpolation='none')

        axes[0, 0].set_title("Image")
        class_name = name_subset[0].split(",")[0]
        axes[0, 1].set_title(f"{class_name} Basic Heatmap")
        if advanced_subset is not None:
            axes[0, 2].set_title(f"{class_name} Advanced Heatmap")
        axes[0, 3].set_title("True Mask")
        axes[0, 4].set_title("True Seg")
        remove_axes(axes)
        plt.tight_layout()
        plt.show()

    return metrics, unique_classes
|