import math
import os

import einops
import numpy as np
import torch
import torch.nn.functional as F

def load_i3d_pretrained(device=torch.device('cpu')):
    i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI"
    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')
    print(filepath)
    if not os.path.exists(filepath):
        print(f"Downloading I3D weights from {i3D_WEIGHTS_URL}; you can also download the file manually and place it at the path above.")
        os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
    from .pytorch_i3d import InceptionI3d
    i3d = InceptionI3d(400, in_channels=3).eval().to(device)
    i3d.load_state_dict(torch.load(filepath, map_location=device))
    i3d = torch.nn.DataParallel(i3d)
    return i3d

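# Usage sketch (assumption: a CUDA device is available and the weights
# download succeeds on first call):
#   device = torch.device('cuda')
#   i3d = load_i3d_pretrained(device)
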
def preprocess_single(video, resolution, sequence_length=None):
    # video: THWC, {0, ..., 255}
    video = video.permute(0, 3, 1, 2).float() / 255.  # TCHW

    t, c, h, w = video.shape

    # temporal crop
    if sequence_length is not None:
        assert sequence_length <= t
        video = video[:sequence_length]

    # scale shorter side to resolution
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    video = F.interpolate(video, size=target_size, mode='bilinear',
                          align_corners=False)

    # center crop
    t, c, h, w = video.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]

    video = video.permute(1, 0, 2, 3).contiguous()  # CTHW
    video -= 0.5
    return video

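# Shape sketch (hypothetical numbers): a THWC clip of shape (16, 360, 640, 3)
# with resolution=224 is rescaled so its shorter side becomes 224
# (-> 224 x 399), center-cropped, and returned as a CTHW float tensor of
# shape (3, 16, 224, 224) with values in [-0.5, 0.5].
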
def preprocess(videos, target_resolution=224):
    # input:  videos in [0, 1], [b, c, t, h, w], torch.float (on CPU)
    # intermediate: videos in {0, ..., 255}, [b, t, h, w, c], np.uint8
    videos = einops.rearrange(videos, 'b c t h w -> b t h w c')
    videos = (videos * 255).numpy().astype(np.uint8)
    b, t, h, w, c = videos.shape
    videos = torch.from_numpy(videos)
    videos = torch.stack([preprocess_single(video, target_resolution) for video in videos])
    return videos * 2  # [-0.5, 0.5] -> [-1, 1]

def get_fvd_logits(videos, i3d, device, bs=10):
    videos = preprocess(videos)
    embeddings = get_logits(i3d, videos, device, bs=bs)  # honor the caller's batch size
    return embeddings

# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
def _symmetric_matrix_square_root(mat, eps=1e-10):
    u, s, v = torch.svd(mat)
    # avoid taking the square root of (numerically) tiny singular values
    si = torch.where(s < eps, s, torch.sqrt(s))
    return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())

# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
def trace_sqrt_product(sigma, sigma_v):
    sqrt_sigma = _symmetric_matrix_square_root(sigma)
    sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
    return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))

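# Background: for Gaussians N(m, sigma) and N(m_w, sigma_w), the squared
# Fréchet distance is
#   ||m - m_w||^2 + Tr(sigma + sigma_w - 2 * (sigma @ sigma_w)^(1/2)).
# trace_sqrt_product computes the Tr((sigma @ sigma_w)^(1/2)) term via
# sqrt(sigma) @ sigma_w @ sqrt(sigma), which shares its eigenvalues with
# sigma @ sigma_w but is symmetric PSD, so its square root is well behaved.
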
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def cov(m, rowvar=False):
    '''Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    Args:
        m: A 1-D or 2-D array containing multiple variables and observations.
            Each row of `m` represents a variable, and each column a single
            observation of all those variables.
        rowvar: If `rowvar` is True, then each row represents a
            variable, with observations in the columns. Otherwise, the
            relationship is transposed: each column represents a variable,
            while the rows contain observations.

    Returns:
        The covariance matrix of the variables.
    '''
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    fact = 1.0 / (m.size(1) - 1)  # unbiased estimate
    m = m - torch.mean(m, dim=1, keepdim=True)  # center without mutating the input
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()

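# Sanity check (sketch): for 2-D input this matches numpy's estimator, e.g.
#   x = torch.randn(100, 5)
#   np.allclose(cov(x).numpy(), np.cov(x.numpy(), rowvar=False), atol=1e-5)
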
def frechet_distance(x1, x2):
    x1 = x1.flatten(start_dim=1)
    x2 = x2.flatten(start_dim=1)
    m, m_w = x1.mean(dim=0), x2.mean(dim=0)
    sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
    mean = torch.sum((m - m_w) ** 2)
    if x1.shape[0] > 1:
        sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
        trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
        fd = trace + mean
    else:
        # with a single sample the covariances are undefined, so fall back
        # to the squared mean distance only
        fd = mean
    return float(fd)

def get_logits(i3d, videos, device, bs=10):
    # assert videos.shape[0] % 16 == 0
    with torch.no_grad():
        logits = []
        for i in range(0, videos.shape[0], bs):
            batch = videos[i:i + bs].to(device)
            # use the model's forward pass (classifier logits) rather than
            # i3d.module.extract_features(batch), which returns pooled features
            logits.append(i3d(batch))
        logits = torch.cat(logits, dim=0)
    return logits

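if __name__ == "__main__":
    # Minimal end-to-end sketch with random data (assumptions: videos are CPU
    # float tensors in [0, 1] with shape [b, c, t, h, w]; this module is run
    # as part of its package, since pytorch_i3d is a relative import; a real
    # FVD evaluation would use far more clips for stable covariance estimates).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    real_videos = torch.rand(8, 3, 16, 64, 64)
    fake_videos = torch.rand(8, 3, 16, 64, 64)
    i3d = load_i3d_pretrained(device)
    real_logits = get_fvd_logits(real_videos, i3d, device, bs=4)
    fake_logits = get_fvd_logits(fake_videos, i3d, device, bs=4)
    print("FVD:", frechet_distance(real_logits, fake_logits))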