lorocksUMD's picture
Upload 32 files
099ac14 verified
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics.functional.classification import binary_average_precision
from tqdm import tqdm
from constants import *
from DenseAV.denseav.shared import unnorm, remove_axes
def prep_heatmap(sims, masks, h, w):
    """Collapse per-timestep similarity maps into normalised, resized heatmaps.

    The similarities are averaged over the time steps weighted by ``masks``,
    shifted and scaled by the minimum and maximum taken over the *whole
    batch* (not per sample), and bilinearly resized to ``(h, w)``.

    Args:
        sims: Tensor indexed as ``(batch, height, width, time)`` by the einsum.
        masks: Tensor indexed as ``(batch, time)``; cast to float32 and used
            as per-timestep weights.
        h: Output heatmap height.
        w: Output heatmap width.

    Returns:
        Tensor of shape ``(batch, h, w)``.
    """
    weights = masks.to(torch.float32)

    # A row with no active time step would give 0 / 0 = NaN, and a single NaN
    # then poisons the batch-wide min/max below. Divide by 1 in that case so
    # the row simply becomes an all-zero heatmap.
    totals = weights.sum(-1)
    totals = torch.where(totals == 0, torch.ones_like(totals), totals)

    hm = torch.einsum("bhwt,bt->bhw", sims, weights) / totals.reshape(-1, 1, 1)
    hm = hm - hm.min()

    # A perfectly flat heatmap has a peak of 0 after the shift; keep it at
    # zero instead of turning it into NaN.
    peak = hm.max()
    peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    hm = hm / peak

    return F.interpolate(hm.unsqueeze(1), (h, w), mode="bilinear").squeeze(1)
def iou(prediction, target):
    """Compute the intersection-over-union of two binarised tensors.

    ``prediction`` counts as foreground where it is strictly positive and
    ``target`` where it exceeds 0.5.

    Returns:
        The IoU as a Python float, or ``1.0`` when both masks are empty.
    """
    pred_mask = prediction > 0.0
    true_mask = target > 0.5

    union_count = (pred_mask | true_mask).sum().float()
    if union_count == 0:
        # Nothing predicted and nothing to find: treated as a perfect match.
        return 1.0

    overlap_count = (pred_mask & true_mask).sum().float()
    return (overlap_count / union_count).item()
def multi_iou(prediction, target, k=20):
    """Return the best IoU obtainable over ``k`` binarisation thresholds.

    ``prediction`` is binarised at ``k`` evenly spaced thresholds between its
    minimum and maximum (strict ``>`` comparison), ``target`` is binarised at
    0.5, and the IoU over all elements is computed for every threshold.

    Args:
        prediction: Tensor or array-like of scores, any shape.
        target: Tensor or array-like broadcastable with ``prediction``.
        k: Number of thresholds to try.

    Returns:
        0-dim float tensor holding the highest IoU. A threshold whose union
        is empty contributes an IoU of 0.
    """
    # as_tensor avoids the copy (and the UserWarning) that torch.tensor()
    # produces when it is handed something that is already a tensor.
    prediction = torch.as_tensor(prediction)
    target = torch.as_tensor(target, device=prediction.device) > 0.5
    prediction, target = torch.broadcast_tensors(prediction, target)

    # Build the thresholds on the prediction's device so the comparison below
    # also works for non-CPU inputs.
    thresholds = torch.linspace(
        prediction.min().item(),
        prediction.max().item(),
        k,
        device=prediction.device,
    )

    # Flattening makes the computation independent of the input rank; the
    # counts are taken over every element either way.
    flat_pred = prediction.reshape(1, -1)
    flat_target = target.reshape(1, -1)
    hard_pred = flat_pred > thresholds.reshape(k, 1)

    intersection = (hard_pred & flat_target).sum(dim=1).float()
    union = (hard_pred | flat_target).sum(dim=1).float()
    union = torch.where(union == 0, torch.ones_like(union), union)  # Avoid division by zero

    return (intersection / union).max()
def _timing_to_masks(timing, total_length, like, pad=.2):
    """Build one time-step mask per prompt from JSON ``[start, end]`` windows.

    Each window is widened by ``pad`` on both sides, converted to a fraction
    of ``total_length`` and scaled to the number of mask columns in ``like``.
    The result has the dtype and device of ``like``.
    """
    num_steps = like.shape[1]

    windows = torch.tensor([json.loads(t) for t in timing])
    if not windows.is_floating_point():
        # Windows written with integers only would otherwise yield an integer
        # tensor, on which the in-place padding below cannot operate.
        windows = windows.to(torch.get_default_dtype())
    windows[:, 0] -= pad
    windows[:, 1] += pad

    bounds = num_steps * (windows / total_length)
    first = bounds[:, 0].floor().to(torch.int64)
    last = bounds[:, 1].ceil().to(torch.int64)

    # Select the half-open range [first, last) by comparison rather than via
    # F.one_hot: padding can push `first` below 0 or `last` to num_steps and
    # beyond, and one_hot rejects such indices.
    steps = torch.arange(num_steps, device=first.device).unsqueeze(0)
    selected = (steps >= first.unsqueeze(1)) & (steps < last.unsqueeze(1))
    return selected.to(like)


def _plot_class_examples(frames, basic, advanced, gt_masks, semseg, class_name, n_imgs):
    """Show ``n_imgs`` rows of image / heatmap(s) / ground truth for one class.

    The advanced-heatmap column is only drawn when ``advanced`` is given.
    Expects ``n_imgs > 1`` so that ``plt.subplots`` returns a 2-D axes grid.
    """
    panels = [
        ("Image", frames, {}),
        (f"{class_name} Basic Heatmap", basic, {}),
    ]
    if advanced is not None:
        panels.append((f"{class_name} Advanced Heatmap", advanced, {}))
    panels.append(("True Mask", gt_masks, {}))
    panels.append(("True Seg", semseg, {"cmap": "tab20", "interpolation": 'none'}))

    n_cols = len(panels)
    _, axes = plt.subplots(n_imgs, n_cols, figsize=(4 * n_cols, n_imgs * 3))
    for col, (title, images, imshow_kwargs) in enumerate(panels):
        for row in range(n_imgs):
            axes[row, col].imshow(images[row], **imshow_kwargs)
        axes[0, col].set_title(title)

    remove_axes(axes)
    plt.tight_layout()
    plt.show()


def get_paired_heatmaps(
        model,
        results,
        class_ids,
        timing,
        class_names=None):
    """Score (and optionally plot) per-class heatmaps against segmentation masks.

    A "basic" heatmap is built from ``results[AUDIO_MASK]``. When ``timing``
    is given, an "advanced" heatmap is built from masks restricted to each
    prompt's padded time window. For every unique class id, average precision
    and best-threshold IoU are computed against ``results["semseg"] == class``.

    Args:
        model: Object exposing ``sim_agg.get_pairwise_sims``.
        results: Mapping read at ``"semseg"``, ``AUDIO_MASK``, and (depending
            on the options) ``'total_length'`` and ``IMAGE_INPUT``.
        class_ids: Iterable with one class id per sample.
        timing: Iterable with one JSON ``[start, end]`` string per sample, or
            ``None`` to skip the advanced heatmaps.
        class_names: Optional iterable with one name per sample; when given,
            example figures are shown for every class with more than one
            sample.

    Returns:
        ``(metrics, unique_classes)``. ``metrics`` maps ``"basic_ap"`` and
        ``"basic_iou"`` (plus ``"advanced_ap"`` and ``"advanced_iou"`` when
        ``timing`` is given) to lists with one entry per unique class, in the
        order of ``unique_classes``.
    """
    sims = model.sim_agg.get_pairwise_sims(
        results,
        raw=False,
        agg_sim=False,
        agg_heads=True
    ).squeeze(1).mean(-2)

    prompt_classes = torch.tensor(list(class_ids))
    gt = results["semseg"] == prompt_classes.reshape(-1, 1, 1)
    basic_masks = results[AUDIO_MASK]  # BxT
    _, fullh, fullw = gt.shape
    basic_heatmaps = prep_heatmap(sims, basic_masks, fullh, fullw)

    has_timing = timing is not None
    advanced_heatmaps = None
    if has_timing:
        # NOTE(review): 16000 is presumably the audio sample rate, and only the
        # first sample's length is used for the whole batch - confirm.
        total_length = (results['total_length'] / 16000)[0]
        advanced_masks = _timing_to_masks(timing, total_length, basic_masks)
        advanced_heatmaps = prep_heatmap(sims, advanced_masks, fullh, fullw)

    metrics = defaultdict(list)
    unique_classes = torch.unique(prompt_classes)

    should_plot = class_names is not None
    if should_plot:
        prompt_names = np.array(list(class_names))

    for prompt_class in tqdm(unique_classes):
        subset = torch.where(prompt_classes == prompt_class)[0]
        gt_subset = gt[subset]
        gt_flat = gt_subset.flatten()

        basic_subset = basic_heatmaps[subset]
        metrics["basic_ap"].append(binary_average_precision(basic_subset.flatten(), gt_flat))
        metrics["basic_iou"].append(multi_iou(basic_subset.flatten(), gt_flat))

        # Stays None without timing so the plot below can skip that column
        # instead of referencing an undefined variable.
        advanced_subset = None
        if has_timing:
            advanced_subset = advanced_heatmaps[subset]
            metrics["advanced_ap"].append(binary_average_precision(advanced_subset.flatten(), gt_flat))
            metrics["advanced_iou"].append(multi_iou(advanced_subset.flatten(), gt_flat))

        if not should_plot:
            continue

        prompt_class_subset = prompt_classes[subset]
        name_subset = prompt_names[subset]
        print(prompt_class, name_subset, prompt_class_subset)

        n_imgs = min(len(subset), 5)
        if n_imgs > 1:
            frame_subset = unnorm(results[IMAGE_INPUT][subset].squeeze(1)).permute(0, 2, 3, 1)
            semseg_subset = results["semseg"][subset]
            class_name = name_subset[0].split(",")[0]
            _plot_class_examples(
                frame_subset,
                basic_subset,
                advanced_subset,
                gt_subset,
                semseg_subset,
                class_name,
                n_imgs,
            )

    return metrics, unique_classes