Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
# Elapsed times in milliseconds, appended by every enabled Timer when it exits.
all_times = list()
class Timer:
    """Context manager that times the enclosed region with CUDA events.

    When enabled, on exit it synchronizes CUDA, then does three things with
    the elapsed time in milliseconds:
    prints it as ``"<name>: <ms>ms"``, appends it to the module-level
    ``all_times`` list, and stores it on ``self.elapsed``. When disabled,
    the timer is a no-op and never touches CUDA.

    Args:
        name: Label printed in front of the measured time.
        enabled: Whether to create/record CUDA events at all.
    """

    def __init__(self, name: str, enabled: bool = True):
        self.name = name
        self.enabled = enabled
        # Last measured duration in ms; stays None until an enabled timer exits.
        self.elapsed = None
        if self.enabled:
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        if self.enabled:
            self.start.record()
        # Return self so `with Timer(...) as t:` can read t.elapsed afterwards.
        return self

    def __exit__(self, exc_type, value, traceback):
        if self.enabled:
            self.end.record()
            # Both events must have completed before elapsed_time can be read.
            torch.cuda.synchronize()
            self.elapsed = self.start.elapsed_time(self.end)
            # append() mutates the module-level list; no `global` rebinding needed.
            all_times.append(self.elapsed)
            print(f"{self.name}: {self.elapsed:.2f}ms")
        # Implicit None (falsy): exceptions raised in the with-body propagate.
def coords_grid(b, n, h, w, **kwargs):
    """Build a float pixel-coordinate grid tiled over batch and frames.

    Args:
        b: Batch size (number of repeats along dim 0).
        n: Frame count (number of repeats along dim 1).
        h: Grid height.
        w: Grid width.
        **kwargs: Forwarded to ``torch.arange`` (e.g. ``device``).

    Returns:
        Tensor of shape (b, n, 2, h, w). Channel 0 holds the column (x)
        coordinate and channel 1 the row (y) coordinate.
    """
    rows = torch.arange(0, h, dtype=torch.float, **kwargs)
    cols = torch.arange(0, w, dtype=torch.float, **kwargs)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    # Stack x first, then y, into a fresh (2, h, w) tensor.
    xy = torch.stack([grid_x, grid_y])
    return xy.view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1)
def coords_grid_with_index(d, **kwargs):
    """Build per-pixel (x, y, d) coordinates plus a per-frame index map.

    Args:
        d: Tensor of shape (b, n, h, w). It is stacked as-is as the third
            coordinate channel next to float x/y grids.
        **kwargs: Forwarded to ``torch.arange`` (e.g. ``device``). If no
            ``device`` is given, ``d.device`` is used so the grids live
            alongside ``d``.

    Returns:
        coords: Tensor of shape (b, n, 3, h, w) holding x (column),
            y (row) and d along dim 2.
        index: Float tensor of shape (b, n, 1, h, w) whose value at frame
            ``k`` (dim 1) is ``k``.
    """
    b, n, h, w = d.shape
    # Default to d's device; without this a CUDA `d` fails in torch.stack below.
    kwargs.setdefault("device", d.device)
    x = torch.arange(0, w, dtype=torch.float, **kwargs)
    y = torch.arange(0, h, dtype=torch.float, **kwargs)
    y, x = torch.stack(torch.meshgrid(y, x, indexing="ij"))
    y = y.view(1, 1, h, w).repeat(b, n, 1, 1)
    x = x.view(1, 1, h, w).repeat(b, n, 1, 1)
    coords = torch.stack([x, y, d], dim=2)
    index = torch.arange(0, n, dtype=torch.float, **kwargs)
    index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w)
    return coords, index
def patchify(x, patch_size=3):
    """Extract every patch_size x patch_size patch from each frame of a video.

    Patches are taken with stride 1 and no padding.

    Args:
        x: Tensor of shape (b, n, c, h, w).
        patch_size: Side length of the square patches.

    Returns:
        Tensor of shape (b, n * L, c, patch_size, patch_size), where
        L = (h - patch_size + 1) * (w - patch_size + 1) is the number of
        patch positions per frame. Patches are ordered frame-major, then
        row-major within a frame.
    """
    b, n, c, h, w = x.shape
    # reshape (not view) so non-contiguous inputs, e.g. permuted b/n, also work.
    frames = x.reshape(b * n, c, h, w)
    # unfold -> (b*n, c*p*p, L); transpose so each row is one flattened patch.
    cols = F.unfold(frames, patch_size).transpose(1, 2)
    return cols.reshape(b, -1, c, patch_size, patch_size)
def pyramidify(fmap, lvls=(1,)):
    """Turn a video feature map into a list of average-pooled levels.

    Args:
        fmap: Tensor of shape (b, n, c, h, w).
        lvls: Iterable of pooling factors. Each level applies
            ``avg_pool2d`` with kernel size and stride equal to the factor,
            so a factor of 1 returns the map unchanged. The default is a
            tuple rather than a list to avoid a shared mutable default.

    Returns:
        List with one tensor per entry of ``lvls``, each of shape
        (b, n, c, h // lvl, w // lvl).
    """
    b, n, c, h, w = fmap.shape
    # Fold frames into the batch once; reshape also accepts non-contiguous input.
    frames = fmap.reshape(b * n, c, h, w)
    pyramid = []
    for lvl in lvls:
        pooled = F.avg_pool2d(frames, lvl, stride=lvl)
        # Take the pooled spatial size from the result itself.
        pyramid.append(pooled.view(b, n, c, *pooled.shape[-2:]))
    return pyramid
def all_pairs_exclusive(n, **kwargs):
    """Return all ordered index pairs (i, j) over range(n) with i != j.

    Args:
        n: Number of items.
        **kwargs: Forwarded to ``torch.arange`` (e.g. ``device``).

    Returns:
        Two 1-D tensors ``ii``, ``jj`` of length n * (n - 1), ordered with
        i as the outer and j as the inner index.
    """
    idx = torch.arange(n, **kwargs)
    # indexing="ij" is torch's current default; passing it explicitly avoids
    # the meshgrid UserWarning and matches coords_grid's usage.
    ii, jj = torch.meshgrid(idx, idx, indexing="ij")
    off_diagonal = ii != jj
    # Boolean-mask indexing already yields 1-D tensors.
    return ii[off_diagonal], jj[off_diagonal]
def set_depth(patches, depth):
    """Write ``depth`` into channel 2 of ``patches`` in place.

    Index 2 of the third-from-last dimension is overwritten. Each depth
    value is broadcast over the trailing two (spatial) dimensions.

    Args:
        patches: Tensor whose third-from-last dimension has at least 3
            entries. It is mutated in place.
        depth: Tensor whose shape matches ``patches`` with the last three
            dimensions removed (broadcastable).

    Returns:
        The same ``patches`` object, for chaining.
    """
    depth_map = depth.unsqueeze(-1).unsqueeze(-1)
    patches[..., 2, :, :] = depth_map
    return patches
def flatmeshgrid(*args, **kwargs):
    """Meshgrid the given tensors and flatten each resulting grid to 1-D.

    Args:
        *args: Tensors passed to ``torch.meshgrid``.
        **kwargs: Passed to ``torch.meshgrid``. ``indexing`` defaults to
            ``"ij"``, torch's current default. It is passed explicitly to
            avoid the meshgrid UserWarning and to pin behavior if torch's
            default changes. A caller-supplied value is respected.

    Returns:
        A generator yielding one flattened tensor per input. It can be
        consumed only once, e.g. ``ii, jj = flatmeshgrid(a, b)``.
    """
    kwargs.setdefault("indexing", "ij")
    grid = torch.meshgrid(*args, **kwargs)
    return (g.reshape(-1) for g in grid)