import random
from collections import defaultdict, deque
from typing import Any

import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from torchaudio.functional import resample


class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        return image2


class SliceDataset(Dataset):
    def __init__(self, ds, start, end):
        self.ds = ds
        self.start = start
        self.end = end

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, item):
        return self.ds[item + self.start]


class SubsetDataset(Dataset):
    def __init__(self, ds, subset):
        self.ds = ds
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, item):
        return self.ds[self.subset[item]]


norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


def crop_to_divisor(x, patch_size):
    if len(x.shape) == 3:
        C, H, W = x.shape
        return x[:, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
    elif len(x.shape) == 4:
        B, C, H, W = x.shape
        return x[:, :, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
    else:
        raise ValueError("x should have 3 or 4 dimensions")


def _remove_axes(ax):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)


def get_image_featurizer(name, token_type="key", **kwargs):
    name = name.lower()
    if name == "vit":
        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
        patch_size = 16
        model = DINOFeaturizer("vit_small_patch16_224", patch_size, token_type)
        dim = 384
    elif name == "dino16":
        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
        patch_size = 16
        model = DINOFeaturizer("dino_vits16", patch_size, token_type)
        dim = 384
    elif name == "dino8":
        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
        patch_size = 8
        model = DINOFeaturizer("dino_vits8", patch_size, token_type)
        dim = 384
    elif name == "clip":
        from DenseAV.denseav.featurizers.CLIP import CLIPFeaturizer
        patch_size = 16
        model = CLIPFeaturizer()
        dim = 512
    elif name == "cavmae":
        from DenseAV.denseav.featurizers.CAVMAE import CAVMAEImageFeaturizer
        model = CAVMAEImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 768
        patch_size = 16
    elif name == "fnac":
        from DenseAV.denseav.featurizers.FNACAVL import FNACImageFeaturizer
        model = FNACImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 512
        patch_size = 16
    elif name == "imagebind":
        from DenseAV.denseav.featurizers.ImageBind import ImageBindImageFeaturizer
        model = ImageBindImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 1024
        patch_size = 16
    elif name == "resnet50":
        from torchvision import models
        model = models.resnet50(pretrained=True)
        model = torch.nn.Sequential(*list(model.children())[:-2])
        patch_size = 1
        dim = 2048
    elif name == "davenet":
        from DenseAV.denseav.featurizers.DAVENet import DavenetImageFeaturizer
        model = DavenetImageFeaturizer()
        patch_size = 1
        dim = 1024
    elif name == "dinov2":
        from DenseAV.denseav.featurizers.DINOv2 import DINOv2Featurizer
        model = DINOv2Featurizer()
        patch_size = 14
        dim = 768
    else:
        raise ValueError("unknown model: {}".format(name))
    return model, patch_size, dim
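# --- Illustrative sketch (not from the original module) ----------------------
# Example of pairing `crop_to_divisor` with the `norm`/`unnorm` transforms above:
# an image with awkward spatial dims is trimmed to a multiple of a patch size
# (14 px, as the "dinov2" branch of `get_image_featurizer` would report) and
# then normalized and un-normalized. The sizes here are arbitrary placeholders.
def _example_crop_and_normalize():
    image = torch.rand(3, 231, 317)       # C, H, W with non-divisible spatial dims
    cropped = crop_to_divisor(image, 14)  # -> 3 x 224 x 308
    normalized = norm(cropped)            # ImageNet-style normalization
    restored = unnorm(normalized)         # approximately recovers `cropped`
    return cropped.shape, (restored - cropped).abs().max()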
def get_audio_featurizer(name, **kwargs):
    if name == "davenet":
        from DenseAV.denseav.featurizers.DAVENet import DavenetAudioFeaturizer
        model = DavenetAudioFeaturizer()
        dim = 1024
    elif name == "dino8":
        model, _, dim = get_image_featurizer("dino8")
    elif name == "hubert":
        from DenseAV.denseav.featurizers.Hubert import Hubert
        model = Hubert()
        dim = 1024
    elif name == "cavmae":
        from DenseAV.denseav.featurizers.CAVMAE import CAVMAEAudioFeaturizer
        model = CAVMAEAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 768
    elif name == "imagebind":
        from DenseAV.denseav.featurizers.ImageBind import ImageBindAudioFeaturizer
        model = ImageBindAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 1024
    elif name == "audiomae":
        from DenseAV.denseav.featurizers.AudioMAE import AudioMAE
        model = AudioMAE(kwargs["output_root"], False)
        dim = 768
    elif name == "audiomae-finetuned":
        from DenseAV.denseav.featurizers.AudioMAE import AudioMAE
        model = AudioMAE(kwargs["output_root"], True)
        dim = 768
    else:
        raise ValueError("Unknown audio model type")
    return model, dim


def load_img(image_path, transform):
    return transform(Image.open(image_path)).unsqueeze(0)


def pytorch_to_pil(tensor):
    return Image.fromarray((unnorm(tensor).permute(0, 2, 3, 1).cpu() * 255)
                           .clamp(0, 255).to(torch.uint8).detach().numpy()[0])


def _get_random_window(waveform, mask, min_size, max_size):
    effective_size = mask.sum().to(torch.int64)
    if effective_size <= min_size:
        return waveform, mask
    else:
        window_size = min(
            torch.randint(low=min_size, high=min(effective_size, max_size), size=()),
            waveform.shape[0])
        if window_size == waveform.shape[0]:
            window_start = 0
        else:
            window_start = torch.randint(low=0, high=effective_size - window_size, size=())
        new_waveform = torch.zeros_like(waveform)
        new_mask = torch.zeros_like(mask)
        new_waveform[window_start:window_start + window_size] = waveform[window_start:window_start + window_size]
        new_mask[window_start:window_start + window_size] = mask[window_start:window_start + window_size]
        return new_waveform, new_mask


def _splice_clips(clip1, clip2, loc, easing_size):
    assert loc >= 0 and loc < len(clip1), "Invalid location"
    assert easing_size > 0 and easing_size <= len(clip2), "Invalid easing size"

    try:
        assert loc + clip2.shape[0] < clip1.shape[0]
    except Exception as e:
        print(loc, clip2.shape[0], clip1.shape[0])
        raise e

    # Split clip1 into three parts: before splice, easing region, after splice
    before_splice = clip1[:loc]
    after_splice = clip1[loc + clip2.shape[0]:]

    # Compute the fading weights for the easing region
    # fade_in_weights = torch.cos(torch.linspace(1, 0, easing_size, device=clip1.device))
    fade_in_weights = 0.5 * (1 + torch.cos(math.pi * torch.linspace(0, 1, easing_size)))
    fade_out_weights = 1 - fade_in_weights

    clip1_ease = torch.cat([
        fade_in_weights,
        torch.zeros(clip2.shape[0] - easing_size * 2),
        fade_out_weights,
    ])
    mask = torch.cat([
        torch.ones(loc),
        clip1_ease,
        torch.ones(clip1.shape[0] - (loc + clip2.shape[0]))])

    # Apply fading weights to clip1 and clip2 within the easing region
    splice = clip1_ease * clip1[loc:loc + clip2.shape[0]] + (1 - clip1_ease) * clip2

    # Concatenate all parts back together
    spliced_clip = torch.cat((before_splice, splice, after_splice))

    return spliced_clip, mask
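# --- Illustrative sketch (not from the original module) ----------------------
# Shows how `_splice_clips` blends a short "negative" clip into a longer
# waveform with a cosine cross-fade and returns a mask marking where the host
# clip still dominates. Sample rate and clip lengths are placeholders; the only
# constraints are the ones asserted inside `_splice_clips`.
def _example_splice_clips():
    sr = 16000
    base = torch.randn(sr * 5)     # 5 s host clip
    insert = torch.randn(sr)       # 1 s clip to splice in
    easing = sr // 4               # must satisfy 2 * easing <= len(insert)
    spliced, pos_mask = _splice_clips(base, insert, loc=sr * 2, easing_size=easing)
    return spliced.shape, pos_mask.min(), pos_mask.max()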
def _generate_random_subset(waveform, low, high):
    length = len(waveform)

    # If waveform is smaller than low or has zero length, return unmodified
    if length < low or length == 0:
        return waveform

    # Generate random start index within valid range
    start = random.randint(0, length - low)

    # Generate random subset size within valid range
    subset_size = random.randint(low, min(high, length - start))

    # Extract the random subset from the waveform
    subset = waveform[start: start + subset_size]

    return subset


def level_audio(waveform):
    waveform -= waveform.mean()
    waveform /= waveform.abs().max().clamp_min(.0001)
    return waveform


def prep_waveform(
        waveform,
        obs_sr,
        target_length,
        spec_mel_bins,
        spec_mean,
        spec_std,
        sample_rate,
        return_spec,
        random_clip,
        extra_audio_masking,
        neg_waveform,
        neg_obs_sr,
        audio_level,
        audio_aug,
):
    if obs_sr != sample_rate:
        waveform = resample(waveform, obs_sr, sample_rate)

    if audio_level:
        waveform = level_audio(waveform)

    if neg_obs_sr is not None and neg_obs_sr != sample_rate:
        neg_waveform = resample(neg_waveform, neg_obs_sr, sample_rate)
        if audio_level:
            neg_waveform = level_audio(neg_waveform)

    if neg_obs_sr is not None:  # and random.random() > .5:
        neg_waveform_clip = _generate_random_subset(neg_waveform, sample_rate, sample_rate * 4)
        if waveform.shape[0] - neg_waveform_clip.shape[0] > 0:
            start = random.randint(0, waveform.shape[0] - neg_waveform_clip.shape[0] - 1)
            easing = max(int(neg_waveform_clip.shape[0] * 1 / 4), sample_rate // 2)
            easing = min(int(neg_waveform_clip.shape[0] * 1 / 2), easing)
            waveform, pos_mask = _splice_clips(waveform, neg_waveform_clip, start, easing_size=easing)
        else:
            waveform, pos_mask = waveform, torch.ones_like(waveform)
    else:
        waveform, pos_mask = waveform, torch.ones_like(waveform)

    mask = torch.ones_like(waveform)
    original_length = waveform.shape[0]
    if target_length == 10:
        target_samples = 164200  # Result is 1024 after spec
    else:
        target_samples = int(target_length * sample_rate)

    padding = target_samples - original_length
    if padding > 0:
        p = torch.nn.ZeroPad2d((0, padding))
        waveform = p(waveform)
        mask = p(mask)
        pos_mask = p(pos_mask)
    else:
        if random_clip:
            start = torch.randint(0, waveform.shape[0] - target_samples, size=())
        else:
            start = 0
        end = start + target_samples
        waveform = waveform[start:end]
        mask = mask[start:end]
        pos_mask = pos_mask[start:end]

    audio_length = min(original_length, target_samples)
    total_length = target_samples

    if extra_audio_masking:
        min_size = sample_rate // 2
        max_size = total_length
        if original_length > min_size and random.random() > .5:
            waveform, mask = _get_random_window(waveform, mask, min_size, max_size)

    if audio_aug:
        import torchaudio_augmentations as AA
        from torchvision.transforms import RandomApply, Compose
        transform = Compose([
            RandomApply([AA.PolarityInversion()], p=0.5),
            RandomApply([AA.Noise(min_snr=0.001, max_snr=0.005)], p=0.2),
            RandomApply([AA.Gain()], p=0.2),
            RandomApply([AA.HighLowPass(sample_rate=sample_rate)], p=0.2),
            RandomApply([AA.PitchShift(n_samples=waveform.shape[-1], sample_rate=sample_rate)], p=0.2),
            RandomApply([AA.Reverb(sample_rate=sample_rate)], p=0.2)
        ])
        waveform = transform(waveform.unsqueeze(0)).squeeze(0)

    if return_spec:
        spectrogram = torchaudio.compliance.kaldi.fbank(
            waveform.unsqueeze(0) - waveform.mean(),
            htk_compat=True,
            sample_frequency=sample_rate,
            use_energy=False,
            window_type='hanning',
            num_mel_bins=spec_mel_bins,
            dither=0.0,
            frame_shift=10)
        spectrogram = ((spectrogram - spec_mean) / spec_std).unsqueeze(0)
    else:
        spectrogram = None

    if mask.mean() < .04:
        print(f"Bad entry: {mask.mean()}")

    return waveform, spectrogram, audio_length, total_length, original_length, mask, pos_mask
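# --- Illustrative sketch (not from the original module) ----------------------
# Minimal call into `prep_waveform`: a 3 s mono clip is padded to the fixed
# 164200-sample window used when target_length == 10 and a normalized log-mel
# spectrogram is returned. The spec_mean / spec_std values below are arbitrary
# placeholders, not the project's trained statistics.
def _example_prep_waveform():
    sr = 16000
    waveform = torch.randn(sr * 3)
    waveform, spec, audio_length, total_length, original_length, mask, pos_mask = prep_waveform(
        waveform, sr,
        target_length=10, spec_mel_bins=128, spec_mean=-4.0, spec_std=4.0,
        sample_rate=sr, return_spec=True, random_clip=False,
        extra_audio_masking=False, neg_waveform=None, neg_obs_sr=None,
        audio_level=False, audio_aug=False)
    return spec.shape, audio_length, total_length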
class ToTargetTensor(object):
    def __call__(self, target):
        return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)


def show_heatmap(ax,
                 image,
                 heatmap,
                 cmap="bwr",
                 color=False,
                 center=False,
                 show_negative=False,
                 cax=None,
                 vmax=None,
                 vmin=None):
    frame = []

    if color:
        frame.append(ax.imshow(image))
    else:
        bw = np.dot(np.array(image)[..., :3] / 255, [0.2989, 0.5870, 0.1140])
        bw = np.ones_like(image) * np.expand_dims(bw, -1)
        frame.append(ax.imshow(bw))

    if center:
        heatmap -= heatmap.mean()

    if not show_negative:
        heatmap = heatmap.clamp_min(0)

    heatmap = F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), (image.shape[0], image.shape[1])) \
        .squeeze(0).squeeze(0)

    if vmax is None:
        vmax = np.abs(heatmap).max()
    if vmin is None:
        vmin = -vmax

    hm = ax.imshow(heatmap, alpha=.5, cmap=cmap, vmax=vmax, vmin=vmin)
    if cax is not None:
        plt.colorbar(hm, cax=cax, orientation='vertical')
    frame.extend([hm])
    return frame


class TorchPCA(object):

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_.unsqueeze(0)
        U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
        self.components_ = V.T
        self.singular_values_ = S
        return self

    def transform(self, X):
        t0 = X - self.mean_.unsqueeze(0)
        projected = t0 @ self.components_.T
        return projected


def pca(image_feats_list, dim=3, fit_pca=None):
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    if fit_pca is None:
        # fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy()))
        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        # x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
        x_red = fit_pca.transform(flatten(feats))
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca


def merge_col(fig, axes, col):
    gs = axes[0, col].get_gridspec()
    for ax in axes[:, col]:
        ax.remove()
    return fig.add_subplot(gs[:, col])
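# --- Illustrative sketch (not from the original module) ----------------------
# Demonstrates the `pca` helper on fake dense features: the first call fits a
# `TorchPCA` model and rescales each output channel to [0, 1]; the returned
# `fit` can then be reused so a second feature map is projected into the same
# basis. Shapes are placeholders chosen to look like a ViT-style B x C x H x W map.
def _example_pca_reduction():
    feats = torch.randn(2, 384, 14, 14)
    [reduced], fit = pca([feats], dim=3)
    [reduced_again], _ = pca([torch.randn(2, 384, 14, 14)], dim=3, fit_pca=fit)
    return reduced.shape, reduced_again.shape  # both (2, 3, 14, 14)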
def visualize_av_features(
        audio,
        video,
        feat_a,
        feat_v,
        att_a,
        n_frames,
        norm_before_pca=True,
        axes=None,
        fig=None,
        modify_fig=True,
        video_time=0,
        fit_pca=None
):
    assert (len(audio.shape) == 3)  # C, F, T
    assert (len(video.shape) == 4)  # T, C, H, W
    assert (len(feat_a.shape) == 2)  # C, T
    assert (len(feat_v.shape) == 4)  # T, C, H, W
    assert (len(att_a.shape) == 2)  # F, T

    ac, af, at = audio.shape
    fac, fat = feat_a.shape

    if modify_fig:
        if axes is None:
            fig, axes = plt.subplots(3, 3, figsize=(5 * 3, 5))
            fig.tight_layout()
        bigax1 = merge_col(fig, axes, 0)
        bigax2 = merge_col(fig, axes, 1)
        _remove_axes(bigax1)
        _remove_axes(bigax2)
        remove_axes(axes[:, 2])
    else:
        bigax1 = fig.axes[-2]
        bigax2 = fig.axes[-1]

    frame_v = unnorm(video).permute(0, 2, 3, 1).detach().cpu()
    frame_v -= frame_v.min()
    frame_v /= frame_v.max()

    frame_a = audio.detach().cpu()
    frame_a -= frame_a.min()
    frame_a /= frame_a.max()

    if norm_before_pca:
        [red_feat_v], fit_pca = pca([F.normalize(feat_v, dim=1)], fit_pca=fit_pca)
        [red_feat_a], _ = pca([F.normalize(feat_a.unsqueeze(0).unsqueeze(-1), dim=1)], fit_pca=fit_pca)
    else:
        [red_feat_v], fit_pca = pca([feat_v], fit_pca=fit_pca)
        [red_feat_a], _ = pca([feat_a.unsqueeze(0).unsqueeze(-1)], fit_pca=fit_pca)

    red_feat_v = red_feat_v.permute(0, 2, 3, 1).detach().cpu()
    red_feat_a = red_feat_a.permute(0, 2, 3, 1)[0].detach().cpu()

    if red_feat_a.shape[0] == 1:
        new_height = int((frame_a.shape[0] / frame_a.shape[1]) * red_feat_a.shape[1])
        red_feat_a = torch.broadcast_to(
            red_feat_a, (new_height, red_feat_a.shape[1], red_feat_a.shape[2]))
        plt_att_a = torch.broadcast_to(att_a, (new_height, att_a.shape[1]))
    else:
        plt_att_a = att_a

    frac_signal = n_frames / fat
    n_at = int(at * frac_signal)

    return [bigax1.imshow(frame_v[video_time]),
            bigax2.imshow(red_feat_v[video_time]),
            axes[0, 2].imshow(frame_a[:, :n_at]),
            axes[0, 2].set_title("Spectrogram"),
            axes[1, 2].imshow(red_feat_a[:, :n_frames]),
            axes[1, 2].set_title("Audio Features"),
            axes[2, 2].imshow(plt_att_a[:, :n_frames], vmin=0),
            axes[2, 2].set_title("Audio Attention")], fig, fit_pca


def create_label_tensor(labels, starts, ends, max_time, n_steps):
    assert isinstance(starts, torch.Tensor)
    assert isinstance(ends, torch.Tensor)

    ends[ends < 0] = max_time
    fps = n_steps / max_time
    times = (torch.arange(0, n_steps, device=labels.device, dtype=torch.float32) + .5) / fps
    after_start = starts.unsqueeze(1) <= times.unsqueeze(0)
    before_end = ends.unsqueeze(1) >= times.unsqueeze(0)

    # Find when you are inside of a word
    in_word = (after_start * before_end)

    # Find which word you are inside of
    word_to_use = in_word.to(torch.float32).argmax(0)

    # Get the label for that word, or mask out the label if in no word
    final_labels = labels[word_to_use] * in_word.any(0).reshape(-1, 1, 1)
    return final_labels


def generate_subset(n, batch, seed=0):
    np.random.seed(seed)
    return np.random.permutation(n)[:batch]


def channel_blur(t, window=5, std_dev=1):
    tb, tc, th, tw = t.shape
    x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
    k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
    k = k / k.sum()
    pad = window // 2
    t_pad = F.pad(t, [0, 0, 0, 0, pad, pad], mode="replicate")
    tpb, tpc, tph, tpw = t_pad.shape
    flattened_t = t_pad.permute(0, 2, 3, 1).reshape(tpb * tph * tpw, 1, -1)
    return F.conv1d(flattened_t, k.reshape(1, 1, window)) \
        .reshape(tpb, tph, tpw, tc).permute(0, 3, 1, 2)


def time_blur(t, window=5, std_dev=1):
    tb, tc, tt = t.shape
    with torch.no_grad():
        x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
        k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
        k = k / k.sum()
        k = k.reshape(1, 1, window).detach()
    pad = window // 2
    t_pad = F.pad(t, [pad, pad], mode="replicate")
    return F.conv1d(t_pad.reshape(tb * tc, 1, -1), k).reshape(tb, tc, tt)


def create_model_from_cfg(clazz, cfg, extra_args):
    import inspect
    expected_args = inspect.getfullargspec(clazz.__init__).args[1:]
    new_args = {k: v for k, v in {**cfg, **extra_args}.items() if k in expected_args}
    return clazz(**new_args)


def load_trained_model(chkpt_dir, extra_args, strict=True):
    from train_av_alignment import LitAVAligner
    model = LitAVAligner.load_from_checkpoint(chkpt_dir, **extra_args, strict=strict).cuda()
    return model


def flatten(l):
    return [item for sublist in l for item in sublist]
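# --- Illustrative sketch (not from the original module) ----------------------
# `create_model_from_cfg` filters a config dict down to the kwargs a class
# constructor actually accepts, with `extra_args` taking precedence. The tiny
# module below is a hypothetical stand-in used only to show that behavior.
def _example_create_model_from_cfg():
    class TinyHead(torch.nn.Module):
        def __init__(self, dim, n_classes):
            super().__init__()
            self.proj = torch.nn.Linear(dim, n_classes)

    cfg = {"dim": 384, "n_classes": 10, "unused_key": "ignored"}
    extra_args = {"n_classes": 5}
    return create_model_from_cfg(TinyHead, cfg, extra_args)  # TinyHead(dim=384, n_classes=5)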
"metadata" in preds[0]: results["frame_files"] = flatten([list(p["metadata"]["frame_files"][0]) for p in preds]) results["audio_file"] = flatten([list(p["metadata"]["audio_file"]) for p in preds]) results["id"] = flatten([list(p["metadata"]["id"]) for p in preds]) results["index"] = torch.tensor(flatten([list(p["metadata"]["index"]) for p in preds])) return results def batch(iterable, n=1): l = len(iterable) for ndx in range(0, l, n): yield iterable[ndx:min(ndx + n, l)] class GatherLayer(torch.autograd.Function): """Gather tensors from all process, supporting backward propagation.""" @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: pass @staticmethod def forward(ctx, inputs): ctx.save_for_backward(inputs) output = [torch.zeros_like(inputs) for _ in range(dist.get_world_size())] dist.all_gather(output, inputs) return tuple(output) @staticmethod def backward(ctx, *grads): (inputs,) = ctx.saved_tensors grad_out = torch.zeros_like(inputs) grad_out[:] = grads[dist.get_rank()] return grad_out class RollingAvg: def __init__(self, length, nonzero=False): self.length = length self.nonzero = nonzero self.metrics = defaultdict(lambda: deque(maxlen=self.length)) def add(self, name, metric): if self.nonzero and metric == 0: return if isinstance(metric, torch.Tensor): metric = metric.detach() self.metrics[name].append(metric) def get(self, name): with torch.no_grad(): return torch.tensor(list(self.metrics[name])).mean() def get_all(self): return {k: self.get(k) for k in self.metrics.keys()} def add_all(self, values): for k, v in values.items(): self.add(k, v) def logall(self, log_func): for k in self.metrics.keys(): log_func(k, self.get(k)) def gaussian_kernel(k, sigma): kernel = torch.tensor([math.exp(-0.5 * (x - (k // 2)) ** 2 / sigma ** 2) for x in range(k)], dtype=torch.float32) kernel /= kernel.sum() # Normalize the kernel return kernel def blur_dim(t, window=5, std_dev=1, dim=-1): shape = t.shape n_dims = len(shape) # Create the Gaussian kernel with torch.no_grad(): x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32) k = torch.exp(-x ** 2 / (2 * std_dev ** 2)) k = k / k.sum() k = k.view(1, 1, window).detach() # Calculate padding pad = window // 2 # Move the target dimension to the end permute_order = list(range(n_dims)) permute_order.append(permute_order.pop(dim)) t_permuted = t.permute(permute_order) # Flatten all dimensions except the last one new_shape = (-1, t_permuted.size(-1)) t_flattened = t_permuted.reshape(new_shape) # Pad the tensor t_padded = F.pad(t_flattened.unsqueeze(1), (pad, pad), mode="replicate") # Apply convolution blurred = F.conv1d(t_padded, k) # Reshape back to original blurred = blurred.squeeze(1).reshape(*t_permuted.shape) blurred = blurred.permute([permute_order.index(i) for i in range(n_dims)]) return blurred