import random
from collections import defaultdict, deque
from typing import Any
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from torchaudio.functional import resample

class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        return image2


class SliceDataset(Dataset):
    def __init__(self, ds, start, end):
        self.ds = ds
        self.start = start
        self.end = end

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, item):
        return self.ds[item + self.start]


class SubsetDataset(Dataset):
    def __init__(self, ds, subset):
        self.ds = ds
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, item):
        return self.ds[self.subset[item]]


norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

def crop_to_divisor(x, patch_size):
    if len(x.shape) == 3:
        C, H, W = x.shape
        return x[:, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
    elif len(x.shape) == 4:
        B, C, H, W = x.shape
        return x[:, :, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
    else:
        raise ValueError("x should have 3 or 4 dimensions")


def _remove_axes(ax):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)

def get_image_featurizer(name, token_type="key", **kwargs):
    name = name.lower()
    if name == "vit":
        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
        patch_size = 16
        model = DINOFeaturizer("vit_small_patch16_224", patch_size, token_type)
        dim = 384
    elif name == "dino16":
        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
        patch_size = 16
        model = DINOFeaturizer("dino_vits16", patch_size, token_type)
        dim = 384
    elif name == "dino8":
        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
        patch_size = 8
        model = DINOFeaturizer("dino_vits8", patch_size, token_type)
        dim = 384
    elif name == "clip":
        from DenseAV.denseav.featurizers.CLIP import CLIPFeaturizer
        patch_size = 16
        model = CLIPFeaturizer()
        dim = 512
    elif name == "cavmae":
        from DenseAV.denseav.featurizers.CAVMAE import CAVMAEImageFeaturizer
        model = CAVMAEImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 768
        patch_size = 16
    elif name == "fnac":
        from DenseAV.denseav.featurizers.FNACAVL import FNACImageFeaturizer
        model = FNACImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 512
        patch_size = 16
    elif name == "imagebind":
        from DenseAV.denseav.featurizers.ImageBind import ImageBindImageFeaturizer
        model = ImageBindImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 1024
        patch_size = 16
    elif name == "resnet50":
        from torchvision import models
        model = models.resnet50(pretrained=True)
        model = torch.nn.Sequential(*list(model.children())[:-2])
        patch_size = 1
        dim = 2048
    elif name == "davenet":
        from DenseAV.denseav.featurizers.DAVENet import DavenetImageFeaturizer
        model = DavenetImageFeaturizer()
        patch_size = 1
        dim = 1024
    elif name == "dinov2":
        from DenseAV.denseav.featurizers.DINOv2 import DINOv2Featurizer
        model = DINOv2Featurizer()
        patch_size = 14
        dim = 768
    else:
        raise ValueError("unknown model: {}".format(name))
    return model, patch_size, dim

def get_audio_featurizer(name, **kwargs):
    if name == "davenet":
        from DenseAV.denseav.featurizers.DAVENet import DavenetAudioFeaturizer
        model = DavenetAudioFeaturizer()
        dim = 1024
    elif name == "dino8":
        model, _, dim = get_image_featurizer("dino8")
    elif name == "hubert":
        from DenseAV.denseav.featurizers.Hubert import Hubert
        model = Hubert()
        dim = 1024
    elif name == "cavmae":
        from DenseAV.denseav.featurizers.CAVMAE import CAVMAEAudioFeaturizer
        model = CAVMAEAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 768
    elif name == "imagebind":
        from DenseAV.denseav.featurizers.ImageBind import ImageBindAudioFeaturizer
        model = ImageBindAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
        dim = 1024
    elif name == "audiomae":
        from DenseAV.denseav.featurizers.AudioMAE import AudioMAE
        model = AudioMAE(kwargs["output_root"], False)
        dim = 768
    elif name == "audiomae-finetuned":
        from DenseAV.denseav.featurizers.AudioMAE import AudioMAE
        model = AudioMAE(kwargs["output_root"], True)
        dim = 768
    else:
        raise ValueError("Unknown audio model type")
    return model, dim
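
# Illustrative usage sketch (not part of the original file): both factories return a
# backbone plus its feature dimension. Model names like "dino16" and "hubert" come
# from the branches above; "output_root" is only required by the featurizers that
# load checkpoints from disk.
#
#   image_model, patch_size, image_dim = get_image_featurizer("dino16", token_type="key")
#   audio_model, audio_dim = get_audio_featurizer("hubert")
#   cavmae_model, cavmae_dim = get_audio_featurizer("cavmae", output_root="path/to/checkpoints")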

def load_img(image_path, transform):
    return transform(Image.open(image_path)).unsqueeze(0)


def pytorch_to_pil(tensor):
    return Image.fromarray((unnorm(tensor).permute(0, 2, 3, 1).cpu() * 255)
                           .clamp(0, 255).to(torch.uint8).detach().numpy()[0])

def _get_random_window(waveform, mask, min_size, max_size):
    effective_size = mask.sum().to(torch.int64)
    if effective_size <= min_size:
        return waveform, mask
    else:
        window_size = min(torch.randint(low=min_size, high=min(effective_size, max_size), size=()), waveform.shape[0])
        if window_size == waveform.shape[0]:
            window_start = 0
        else:
            window_start = torch.randint(low=0, high=effective_size - window_size, size=())
        new_waveform = torch.zeros_like(waveform)
        new_mask = torch.zeros_like(mask)
        new_waveform[window_start:window_start + window_size] = waveform[window_start:window_start + window_size]
        new_mask[window_start:window_start + window_size] = mask[window_start:window_start + window_size]
        return new_waveform, new_mask


def _splice_clips(clip1, clip2, loc, easing_size):
    assert loc >= 0 and loc < len(clip1), "Invalid location"
    assert easing_size > 0 and easing_size <= len(clip2), "Invalid easing size"
    try:
        assert loc + clip2.shape[0] < clip1.shape[0]
    except Exception as e:
        print(loc, clip2.shape[0], clip1.shape[0])
        raise e

    # Split clip1 into three parts: before splice, easing region, after splice
    before_splice = clip1[:loc]
    after_splice = clip1[loc + clip2.shape[0]:]

    # Compute the fading weights for the easing region
    # fade_in_weights = torch.cos(torch.linspace(1, 0, easing_size, device=clip1.device))
    fade_in_weights = 0.5 * (1 + torch.cos(math.pi * torch.linspace(0, 1, easing_size)))
    fade_out_weights = 1 - fade_in_weights

    clip1_ease = torch.cat([
        fade_in_weights,
        torch.zeros(clip2.shape[0] - easing_size * 2),
        fade_out_weights,
    ])
    mask = torch.cat([torch.ones(loc), clip1_ease, torch.ones(clip1.shape[0] - (loc + clip2.shape[0]))])

    # Apply fading weights to clip1 and clip2 within the easing region
    splice = clip1_ease * clip1[loc:loc + clip2.shape[0]] + (1 - clip1_ease) * clip2

    # Concatenate all parts back together
    spliced_clip = torch.cat((before_splice, splice, after_splice))

    return spliced_clip, mask


def _generate_random_subset(waveform, low, high):
    length = len(waveform)

    # If waveform is smaller than low or has zero length, return unmodified
    if length < low or length == 0:
        return waveform

    # Generate random start index within valid range
    start = random.randint(0, length - low)

    # Generate random subset size within valid range
    subset_size = random.randint(low, min(high, length - start))

    # Extract the random subset from the waveform
    subset = waveform[start: start + subset_size]

    return subset

def level_audio(waveform):
    waveform -= waveform.mean()
    waveform /= waveform.abs().max().clamp_min(.0001)
    return waveform

def prep_waveform(waveform,
                  obs_sr,
                  target_length,
                  spec_mel_bins,
                  spec_mean,
                  spec_std,
                  sample_rate,
                  return_spec,
                  random_clip,
                  extra_audio_masking,
                  neg_waveform,
                  neg_obs_sr,
                  audio_level,
                  audio_aug,
                  ):
    if obs_sr != sample_rate:
        waveform = resample(waveform, obs_sr, sample_rate)

    if audio_level:
        waveform = level_audio(waveform)

    if neg_obs_sr is not None and neg_obs_sr != sample_rate:
        neg_waveform = resample(neg_waveform, neg_obs_sr, sample_rate)
        if audio_level:
            neg_waveform = level_audio(neg_waveform)

    if neg_obs_sr is not None:  # and random.random() > .5:
        neg_waveform_clip = _generate_random_subset(neg_waveform, sample_rate, sample_rate * 4)
        if waveform.shape[0] - neg_waveform_clip.shape[0] > 0:
            start = random.randint(0, waveform.shape[0] - neg_waveform_clip.shape[0] - 1)
            easing = max(int(neg_waveform_clip.shape[0] * 1 / 4), sample_rate // 2)
            easing = min(int(neg_waveform_clip.shape[0] * 1 / 2), easing)
            waveform, pos_mask = _splice_clips(waveform, neg_waveform_clip, start, easing_size=easing)
        else:
            waveform, pos_mask = waveform, torch.ones_like(waveform)
    else:
        waveform, pos_mask = waveform, torch.ones_like(waveform)

    mask = torch.ones_like(waveform)
    original_length = waveform.shape[0]

    if target_length == 10:
        target_samples = 164200  # Result is 1024 after spec
    else:
        target_samples = int(target_length * sample_rate)

    padding = target_samples - original_length
    if padding > 0:
        p = torch.nn.ZeroPad2d((0, padding))
        waveform = p(waveform)
        mask = p(mask)
        pos_mask = p(pos_mask)
    else:
        if random_clip:
            start = torch.randint(0, waveform.shape[0] - target_samples, size=())
        else:
            start = 0
        end = start + target_samples
        waveform = waveform[start:end]
        mask = mask[start:end]
        pos_mask = pos_mask[start:end]

    audio_length = min(original_length, target_samples)
    total_length = target_samples

    if extra_audio_masking:
        min_size = sample_rate // 2
        max_size = total_length
        if original_length > min_size and random.random() > .5:
            waveform, mask = _get_random_window(waveform, mask, min_size, max_size)

    if audio_aug:
        import torchaudio_augmentations as AA
        from torchvision.transforms import RandomApply, Compose
        transform = Compose([
            RandomApply([AA.PolarityInversion()], p=0.5),
            RandomApply([AA.Noise(min_snr=0.001, max_snr=0.005)], p=0.2),
            RandomApply([AA.Gain()], p=0.2),
            RandomApply([AA.HighLowPass(sample_rate=sample_rate)], p=0.2),
            RandomApply([AA.PitchShift(n_samples=waveform.shape[-1], sample_rate=sample_rate)], p=0.2),
            RandomApply([AA.Reverb(sample_rate=sample_rate)], p=0.2)
        ])
        waveform = transform(waveform.unsqueeze(0)).squeeze(0)

    if return_spec:
        spectrogram = torchaudio.compliance.kaldi.fbank(
            waveform.unsqueeze(0) - waveform.mean(),
            htk_compat=True,
            sample_frequency=sample_rate,
            use_energy=False,
            window_type='hanning',
            num_mel_bins=spec_mel_bins,
            dither=0.0,
            frame_shift=10)
        spectrogram = ((spectrogram - spec_mean) / spec_std).unsqueeze(0)
    else:
        spectrogram = None

    if mask.mean() < .04:
        print(f"Bad entry: {mask.mean()}")

    return waveform, spectrogram, audio_length, total_length, original_length, mask, pos_mask
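
# Rough usage sketch (illustrative only; the numbers below are placeholders, not the
# project's actual normalization statistics): the waveform is resampled to
# `sample_rate`, optionally spliced with a negative clip, padded or cropped to
# `target_length` seconds, and optionally converted to a normalized Kaldi fbank
# spectrogram.
#
#   waveform, obs_sr = torchaudio.load("example.wav")
#   outputs = prep_waveform(
#       waveform.mean(0), obs_sr, target_length=10, spec_mel_bins=128,
#       spec_mean=-4.0, spec_std=4.0, sample_rate=16000, return_spec=True,
#       random_clip=False, extra_audio_masking=False, neg_waveform=None,
#       neg_obs_sr=None, audio_level=False, audio_aug=False)
#   waveform, spec, audio_len, total_len, orig_len, mask, pos_mask = outputs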

class ToTargetTensor(object):
    def __call__(self, target):
        return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)


def show_heatmap(ax,
                 image,
                 heatmap,
                 cmap="bwr",
                 color=False,
                 center=False,
                 show_negative=False,
                 cax=None,
                 vmax=None,
                 vmin=None):
    frame = []

    if color:
        frame.append(ax.imshow(image))
    else:
        bw = np.dot(np.array(image)[..., :3] / 255, [0.2989, 0.5870, 0.1140])
        bw = np.ones_like(image) * np.expand_dims(bw, -1)
        frame.append(ax.imshow(bw))

    if center:
        heatmap -= heatmap.mean()

    if not show_negative:
        heatmap = heatmap.clamp_min(0)

    heatmap = F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), (image.shape[0], image.shape[1])) \
        .squeeze(0).squeeze(0)

    if vmax is None:
        vmax = np.abs(heatmap).max()
    if vmin is None:
        vmin = -vmax

    hm = ax.imshow(heatmap, alpha=.5, cmap=cmap, vmax=vmax, vmin=vmin)
    if cax is not None:
        plt.colorbar(hm, cax=cax, orientation='vertical')
    frame.extend([hm])
    return frame

class TorchPCA(object):
    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_.unsqueeze(0)
        U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
        self.components_ = V.T
        self.singular_values_ = S
        return self

    def transform(self, X):
        t0 = X - self.mean_.unsqueeze(0)
        projected = t0 @ self.components_.T
        return projected

def pca(image_feats_list, dim=3, fit_pca=None):
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    if fit_pca is None:
        # fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy()))
        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        # x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
        x_red = fit_pca.transform(flatten(feats))
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca
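
# Minimal sketch of turning dense features into an RGB visualization with the helpers
# above (shapes and the random tensor are placeholders):
#
#   feats = torch.randn(1, 384, 28, 28)                      # (B, C, H, W) patch features
#   [rgb], fit = pca([feats], dim=3)                         # fit a 3-component TorchPCA
#   plt.imshow(rgb[0].permute(1, 2, 0).cpu())                # values are min-max scaled
#   [rgb2], _ = pca([feats * 2], dim=3, fit_pca=fit)         # reuse the fitted projection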

def merge_col(fig, axes, col):
    gs = axes[0, col].get_gridspec()
    for ax in axes[:, col]:
        ax.remove()
    return fig.add_subplot(gs[:, col])

def visualize_av_features(
        audio,
        video,
        feat_a,
        feat_v,
        att_a,
        n_frames,
        norm_before_pca=True,
        axes=None,
        fig=None,
        modify_fig=True,
        video_time=0,
        fit_pca=None
):
    assert (len(audio.shape) == 3)  # C, F, T
    assert (len(video.shape) == 4)  # T, C, H, W
    assert (len(feat_a.shape) == 2)  # C, T
    assert (len(feat_v.shape) == 4)  # T, C, H, W
    assert (len(att_a.shape) == 2)  # F, T
    ac, af, at = audio.shape
    fac, fat = feat_a.shape

    if modify_fig:
        if axes is None:
            fig, axes = plt.subplots(3, 3, figsize=(5 * 3, 5))
        fig.tight_layout()
        bigax1 = merge_col(fig, axes, 0)
        bigax2 = merge_col(fig, axes, 1)
        _remove_axes(bigax1)
        _remove_axes(bigax2)
        remove_axes(axes[:, 2])
    else:
        bigax1 = fig.axes[-2]
        bigax2 = fig.axes[-1]

    frame_v = unnorm(video).permute(0, 2, 3, 1).detach().cpu()
    frame_v -= frame_v.min()
    frame_v /= frame_v.max()

    frame_a = audio.detach().cpu()
    frame_a -= frame_a.min()
    frame_a /= frame_a.max()

    if norm_before_pca:
        [red_feat_v], fit_pca = pca([F.normalize(feat_v, dim=1)], fit_pca=fit_pca)
        [red_feat_a], _ = pca([F.normalize(feat_a.unsqueeze(0).unsqueeze(-1), dim=1)], fit_pca=fit_pca)
    else:
        [red_feat_v], fit_pca = pca([feat_v], fit_pca=fit_pca)
        [red_feat_a], _ = pca([feat_a.unsqueeze(0).unsqueeze(-1)], fit_pca=fit_pca)

    red_feat_v = red_feat_v.permute(0, 2, 3, 1).detach().cpu()
    red_feat_a = red_feat_a.permute(0, 2, 3, 1)[0].detach().cpu()

    if red_feat_a.shape[0] == 1:
        new_height = int((frame_a.shape[0] / frame_a.shape[1]) * red_feat_a.shape[1])
        red_feat_a = torch.broadcast_to(
            red_feat_a, (new_height, red_feat_a.shape[1], red_feat_a.shape[2]))
        plt_att_a = torch.broadcast_to(att_a, (new_height, att_a.shape[1]))
    else:
        plt_att_a = att_a

    frac_signal = n_frames / fat
    n_at = int(at * frac_signal)

    return [bigax1.imshow(frame_v[video_time]),
            bigax2.imshow(red_feat_v[video_time]),
            axes[0, 2].imshow(frame_a[:, :n_at]),
            axes[0, 2].set_title("Spectrogram"),
            axes[1, 2].imshow(red_feat_a[:, :n_frames]),
            axes[1, 2].set_title("Audio Features"),
            axes[2, 2].imshow(plt_att_a[:, :n_frames], vmin=0),
            axes[2, 2].set_title("Audio Attention")], fig, fit_pca

def create_label_tensor(labels, starts, ends, max_time, n_steps):
    assert isinstance(starts, torch.Tensor)
    assert isinstance(ends, torch.Tensor)
    ends[ends < 0] = max_time
    fps = n_steps / max_time
    times = (torch.arange(0, n_steps, device=labels.device, dtype=torch.float32) + .5) / fps
    after_start = starts.unsqueeze(1) <= times.unsqueeze(0)
    before_end = ends.unsqueeze(1) >= times.unsqueeze(0)
    # Find when you are inside of a word
    in_word = (after_start * before_end)
    # Find which word you are inside of
    word_to_use = in_word.to(torch.float32).argmax(0)
    # Get the label for that word, or mask out the label if in no word
    final_labels = labels[word_to_use] * in_word.any(0).reshape(-1, 1, 1)
    return final_labels
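
# Worked example (illustrative): rasterize three word labels, given start/end times in
# seconds, onto a 10-step time axis for a 5 second clip. Each step inside a word gets
# that word's label; steps outside every word are zeroed out.
#
#   labels = torch.eye(3).unsqueeze(-1)          # (n_words, n_classes, 1)
#   starts = torch.tensor([0.0, 1.5, 3.0])
#   ends = torch.tensor([1.0, 2.5, -1.0])        # -1 means "runs until max_time"
#   label_track = create_label_tensor(labels, starts, ends, max_time=5.0, n_steps=10)
#   # label_track has shape (10, 3, 1)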

def generate_subset(n, batch, seed=0):
    np.random.seed(seed)
    return np.random.permutation(n)[:batch]

def channel_blur(t, window=5, std_dev=1):
    tb, tc, th, tw = t.shape
    x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
    k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
    k = k / k.sum()
    pad = window // 2
    t_pad = F.pad(t, [0, 0, 0, 0, pad, pad], mode="replicate")
    tpb, tpc, tph, tpw = t_pad.shape
    flattened_t = t_pad.permute(0, 2, 3, 1).reshape(tpb * tph * tpw, 1, -1)
    return F.conv1d(flattened_t, k.reshape(1, 1, window)).reshape(tpb, tph, tpw, tc).permute(0, 3, 1, 2)


def time_blur(t, window=5, std_dev=1):
    tb, tc, tt = t.shape
    with torch.no_grad():
        x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
        k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
        k = k / k.sum()
        k = k.reshape(1, 1, window).detach()
    pad = window // 2
    t_pad = F.pad(t, [pad, pad], mode="replicate")
    return F.conv1d(t_pad.reshape(tb * tc, 1, -1), k).reshape(tb, tc, tt)

def create_model_from_cfg(clazz, cfg, extra_args):
    import inspect
    expected_args = inspect.getfullargspec(clazz.__init__).args[1:]
    new_args = {k: v for k, v in {**cfg, **extra_args}.items() if k in expected_args}
    return clazz(**new_args)
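
# Small sketch of the config filtering (TinyModel is a hypothetical class for
# illustration): only keys matching the constructor signature are forwarded, so a large
# shared config dict can be reused across model classes without "unexpected keyword
# argument" errors, and extra_args override cfg.
#
#   class TinyModel(torch.nn.Module):
#       def __init__(self, dim, depth):
#           super().__init__()
#           self.layers = torch.nn.Sequential(*[torch.nn.Linear(dim, dim) for _ in range(depth)])
#
#   cfg = {"dim": 64, "depth": 2, "unused_key": 123}
#   model = create_model_from_cfg(TinyModel, cfg, {"depth": 4})   # depth becomes 4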

def load_trained_model(chkpt_dir, extra_args, strict=True):
    from train_av_alignment import LitAVAligner
    model = LitAVAligner.load_from_checkpoint(chkpt_dir, **extra_args, strict=strict).cuda()
    return model

def flatten(l):
    return [item for sublist in l for item in sublist]


def flatten_preds(preds):
    results = {}
    for k in preds[0].keys():
        if k == "caption_labels":
            continue
        if isinstance(preds[0][k], torch.Tensor):
            results[k] = torch.cat([p[k] for p in preds], dim=0)
    if "caption" in preds[0]:
        results["caption"] = flatten([p["caption"] for p in preds])
    if "metadata" in preds[0]:
        results["frame_files"] = flatten([list(p["metadata"]["frame_files"][0]) for p in preds])
        results["audio_file"] = flatten([list(p["metadata"]["audio_file"]) for p in preds])
        results["id"] = flatten([list(p["metadata"]["id"]) for p in preds])
        results["index"] = torch.tensor(flatten([list(p["metadata"]["index"]) for p in preds]))
    return results


def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation."""

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        pass

    @staticmethod
    def forward(ctx, inputs):
        ctx.save_for_backward(inputs)
        output = [torch.zeros_like(inputs) for _ in range(dist.get_world_size())]
        dist.all_gather(output, inputs)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (inputs,) = ctx.saved_tensors
        grad_out = torch.zeros_like(inputs)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
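
# Usage note (illustrative; assumes torch.distributed has been initialized, and `model`,
# `batch_inputs`, and `contrastive_loss` are placeholders): unlike a plain
# dist.all_gather, GatherLayer keeps the autograd graph for the local shard, which is
# what contrastive losses need when negatives are pooled across GPUs.
#
#   local_feats = model(batch_inputs)                        # (local_batch, dim), requires grad
#   all_feats = torch.cat(GatherLayer.apply(local_feats), dim=0)
#   loss = contrastive_loss(all_feats)                       # gradients flow back to local_feats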

class RollingAvg:
    def __init__(self, length, nonzero=False):
        self.length = length
        self.nonzero = nonzero
        self.metrics = defaultdict(lambda: deque(maxlen=self.length))

    def add(self, name, metric):
        if self.nonzero and metric == 0:
            return
        if isinstance(metric, torch.Tensor):
            metric = metric.detach()
        self.metrics[name].append(metric)

    def get(self, name):
        with torch.no_grad():
            return torch.tensor(list(self.metrics[name])).mean()

    def get_all(self):
        return {k: self.get(k) for k in self.metrics.keys()}

    def add_all(self, values):
        for k, v in values.items():
            self.add(k, v)

    def logall(self, log_func):
        for k in self.metrics.keys():
            log_func(k, self.get(k))
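
# Illustrative usage: RollingAvg keeps the last `length` values per metric name and
# reports their mean, which smooths noisy per-step losses in training logs
# (`logger.log_metric` below is a placeholder for whatever logging function you use).
#
#   avg = RollingAvg(length=100)
#   for step in range(1000):
#       avg.add("loss", torch.rand(()))
#   print(avg.get("loss"))             # mean of the last 100 values
#   avg.logall(logger.log_metric)      # push every tracked metric to a logger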

def gaussian_kernel(k, sigma):
    kernel = torch.tensor([math.exp(-0.5 * (x - (k // 2)) ** 2 / sigma ** 2) for x in range(k)], dtype=torch.float32)
    kernel /= kernel.sum()  # Normalize the kernel
    return kernel

def blur_dim(t, window=5, std_dev=1, dim=-1):
    shape = t.shape
    n_dims = len(shape)

    # Create the Gaussian kernel
    with torch.no_grad():
        x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
        k = torch.exp(-x ** 2 / (2 * std_dev ** 2))
        k = k / k.sum()
        k = k.view(1, 1, window).detach()

    # Calculate padding
    pad = window // 2

    # Move the target dimension to the end
    permute_order = list(range(n_dims))
    permute_order.append(permute_order.pop(dim))
    t_permuted = t.permute(permute_order)

    # Flatten all dimensions except the last one
    new_shape = (-1, t_permuted.size(-1))
    t_flattened = t_permuted.reshape(new_shape)

    # Pad the tensor
    t_padded = F.pad(t_flattened.unsqueeze(1), (pad, pad), mode="replicate")

    # Apply convolution
    blurred = F.conv1d(t_padded, k)

    # Reshape back to original
    blurred = blurred.squeeze(1).reshape(*t_permuted.shape)
    blurred = blurred.permute([permute_order.index(i) for i in range(n_dims)])

    return blurred
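
# Illustrative example: blur_dim applies the same 1-D Gaussian smoothing idea as
# channel_blur / time_blur, but along an arbitrary dimension, e.g. smoothing an
# audio attention track over time.
#
#   att = torch.rand(2, 64, 300)                             # (batch, freq, time)
#   smoothed = blur_dim(att, window=5, std_dev=1, dim=-1)
#   assert smoothed.shape == att.shape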