import json
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics.functional.classification import binary_average_precision
from tqdm import tqdm

from constants import *
from DenseAV.denseav.shared import unnorm, remove_axes


def prep_heatmap(sims, masks, h, w):
    """Collapse the time axis of ``sims`` with ``masks`` and resize to ``(h, w)``.

    Args:
        sims: 4-D tensor indexed as ``(b, h', w', t)`` (fixed by the einsum).
        masks: 2-D tensor indexed as ``(b, t)``; cast to float32 and used as
            per-timestep weights.
        h: Output height.
        w: Output width.

    Returns:
        A ``(b, h, w)`` tensor. The heatmaps are min/max normalised using the
        minimum and maximum taken over the *whole batch*, not per sample.
    """
    masks = masks.to(torch.float32)
    weight = masks.sum(-1).reshape(-1, 1, 1)
    # An all-zero mask would divide by zero; because the normalisation below
    # is batch-global, the resulting NaN would poison every heatmap.
    weight = torch.where(weight == 0, torch.ones_like(weight), weight)
    hm = torch.einsum("bhwt,bt->bhw", sims, masks) / weight
    hm -= hm.min()
    peak = hm.max()
    # A constant heatmap has peak 0 after the shift; leave it at zero
    # instead of producing NaN.
    if peak > 0:
        hm /= peak
    return F.interpolate(hm.unsqueeze(1), (h, w), mode="bilinear").squeeze(1)


def iou(prediction, target):
    """Return the intersection-over-union of two binarised tensors.

    ``prediction`` is binarised at ``> 0.0`` and ``target`` at ``> 0.5``.

    Returns:
        A Python float. When both binarised tensors are empty (union of
        zero) the result is defined as ``1.0``.
    """
    prediction = prediction > 0.0
    target = target > 0.5
    intersection = torch.logical_and(prediction, target).sum().float()
    union = torch.logical_or(prediction, target).sum().float()
    if union == 0:
        return 1.0
    return (intersection / union).item()  # Convert to Python scalar


def multi_iou(prediction, target, k=20):
    """Return the best IoU over ``k`` evenly spaced prediction thresholds.

    Thresholds are spread linearly between ``prediction.min()`` and
    ``prediction.max()``. For each threshold the prediction is binarised
    with ``>`` and compared to ``target > 0.5``; the IoU is computed over
    all elements at once.

    Args:
        prediction: Tensor (or array-like) of scores, any shape.
        target: Tensor (or array-like) broadcastable with ``prediction``.
        k: Number of thresholds to try.

    Returns:
        A 0-dim tensor holding the highest IoU found. The threshold that
        achieved it is not returned.
    """
    # ``as_tensor`` avoids the copy (and UserWarning) that ``torch.tensor``
    # triggers when handed an existing tensor.
    prediction = torch.as_tensor(prediction)
    target = torch.as_tensor(target, device=prediction.device) > 0.5
    prediction, target = torch.broadcast_tensors(prediction, target)

    thresholds = torch.linspace(
        float(prediction.min()), float(prediction.max()), k,
        device=prediction.device,
    )

    # Flattening makes the computation independent of the input rank.
    scores = prediction.reshape(1, -1)
    truth = target.reshape(1, -1)
    hard_pred = scores > thresholds.reshape(k, 1)  # (k, N)

    # Calculate IoU for each threshold
    intersection = torch.logical_and(hard_pred, truth).sum(dim=1).float()
    union = torch.logical_or(hard_pred, truth).sum(dim=1).float()
    union = torch.where(union == 0, torch.ones_like(union), union)  # Avoid division by zero
    iou_scores = intersection / union

    return torch.max(iou_scores, dim=0).values


def _timing_to_masks(timing, total_length, reference_masks):
    """Build ``(B, T)`` masks that are on between each padded ``[start, end]``.

    ``T`` is ``reference_masks.shape[1]``; the result is cast to the dtype
    and device of ``reference_masks``.
    """
    n_steps = reference_masks.shape[1]
    # Force float so all-integer JSON timings still allow the padding below.
    spans = torch.tensor([json.loads(t) for t in timing], dtype=torch.float32)
    spans[:, 0] -= .2
    spans[:, 1] += .2

    # NOTE(review): 16000 looks like an audio sample rate, and only the first
    # entry of total_length is used - presumably all clips share one length;
    # confirm against the caller.
    clip_length = (total_length / 16000)[0]
    fracs = spans / clip_length
    bounds = n_steps * fracs
    bounds[:, 0] = bounds[:, 0].floor()
    bounds[:, 1] = bounds[:, 1].ceil()
    # The padding can push bounds below 0 or up to/past T, both of which
    # F.one_hot rejects. Clamp to [0, T] and use T + 1 classes so an end
    # bound of exactly T ("through the last step") stays representable.
    bounds = bounds.to(torch.int64).clamp(0, n_steps)
    edges = F.one_hot(bounds, n_steps + 1)  # (B, 2, T + 1)
    # The cumulative sum is 1 from each edge onwards, so the two rows sum to
    # exactly 1 for start <= index < end.
    inside = edges.cumsum(-1).sum(-2) == 1
    return inside[:, :n_steps].to(reference_masks)


def _plot_class_examples(frames, basic, advanced, gt_masks, semseg, class_name):
    """Show one row per example: image, heatmap(s), true mask, segmentation.

    ``advanced`` may be ``None``, in which case its column is omitted.
    """
    panels = [
        ("Image", frames, {}),
        (f"{class_name} Basic Heatmap", basic, {}),
    ]
    if advanced is not None:
        panels.append((f"{class_name} Advanced Heatmap", advanced, {}))
    panels.append(("True Mask", gt_masks, {}))
    panels.append(("True Seg", semseg, {"cmap": "tab20", "interpolation": "none"}))

    n_rows, n_cols = len(frames), len(panels)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, n_rows * 3))
    for col, (title, images, imshow_kwargs) in enumerate(panels):
        axes[0, col].set_title(title)
        for row, image in enumerate(images[:n_rows]):
            axes[row, col].imshow(image, **imshow_kwargs)

    remove_axes(axes)
    plt.tight_layout()
    plt.show()


def get_paired_heatmaps(
        model,
        results,
        class_ids,
        timing,
        class_names=None):
    """Compute per-class AP and best-threshold IoU of audio-prompted heatmaps.

    A "basic" heatmap is built from ``results[AUDIO_MASK]``. When ``timing``
    is given, an "advanced" heatmap is also built from masks restricted to
    each prompt's ``[start, end]`` span, padded by 0.2 on each side.

    Args:
        model: Object exposing ``sim_agg.get_pairwise_sims``.
        results: Mapping read at ``"semseg"``, ``AUDIO_MASK``,
            ``"total_length"`` (only with ``timing``) and ``IMAGE_INPUT``
            (only when plotting).
        class_ids: Iterable with one class id per sample.
        timing: ``None``, or an iterable with one JSON string per sample that
            decodes to ``[start, end]`` (presumably seconds - confirm).
        class_names: Optional iterable with one name per sample. When given,
            up to five examples per class are plotted (only for classes with
            more than one sample).

    Returns:
        ``(metrics, unique_classes)``. ``metrics`` maps ``"basic_ap"`` and
        ``"basic_iou"`` (plus ``"advanced_ap"`` and ``"advanced_iou"`` when
        ``timing`` is given) to lists with one entry per class, in the order
        of ``unique_classes``.
    """
    sims = model.sim_agg.get_pairwise_sims(
        results,
        raw=False,
        agg_sim=False,
        agg_heads=True
    ).squeeze(1).mean(-2)

    prompt_classes = torch.tensor(list(class_ids))
    gt = results["semseg"] == prompt_classes.reshape(-1, 1, 1)
    basic_masks = results[AUDIO_MASK]  # BxT
    _, fullh, fullw = gt.shape
    basic_heatmaps = prep_heatmap(sims, basic_masks, fullh, fullw)

    has_timing = timing is not None
    advanced_heatmaps = None
    if has_timing:
        advanced_masks = _timing_to_masks(timing, results['total_length'], basic_masks)
        advanced_heatmaps = prep_heatmap(sims, advanced_masks, fullh, fullw)

    metrics = defaultdict(list)
    unique_classes = torch.unique(prompt_classes)

    should_plot = class_names is not None
    if should_plot:
        prompt_names = np.array(list(class_names))

    for prompt_class in tqdm(unique_classes):
        subset = torch.where(prompt_classes == prompt_class)[0]
        gt_subset = gt[subset]
        gt_flat = gt_subset.flatten()

        basic_subset = basic_heatmaps[subset]
        metrics["basic_ap"].append(binary_average_precision(basic_subset.flatten(), gt_flat))
        metrics["basic_iou"].append(multi_iou(basic_subset.flatten(), gt_flat))

        # Stays None without timing so the plot can skip the advanced column
        # rather than reference an undefined variable.
        advanced_subset = None
        if has_timing:
            advanced_subset = advanced_heatmaps[subset]
            metrics["advanced_ap"].append(binary_average_precision(advanced_subset.flatten(), gt_flat))
            metrics["advanced_iou"].append(multi_iou(advanced_subset.flatten(), gt_flat))

        if should_plot:
            prompt_class_subset = prompt_classes[subset]
            # Index with an explicit NumPy array rather than relying on
            # implicit tensor-to-index conversion.
            name_subset = prompt_names[subset.cpu().numpy()]
            print(prompt_class, name_subset, prompt_class_subset)
            n_imgs = min(len(subset), 5)
            if n_imgs > 1:
                frame_subset = unnorm(results[IMAGE_INPUT][subset].squeeze(1)).permute(0, 2, 3, 1)
                semseg_subset = results["semseg"][subset]
                class_name = name_subset[0].split(",")[0]
                _plot_class_examples(
                    frame_subset[:n_imgs],
                    basic_subset,
                    advanced_subset,
                    gt_subset,
                    semseg_subset,
                    class_name,
                )

    return metrics, unique_classes