hma / common /calculate_fvd.py
LeroyWaa's picture
draft
246c106
raw
history blame
2.49 kB
# Code adapted from https://github.com/JunyaoHu/common_metrics_on_video_quality
import numpy as np
import torch
from tqdm import tqdm
def trans(x):
# if greyscale images add channel
if x.shape[-3] == 1:
x = x.repeat(1, 1, 3, 1, 1)
# permute BTCHW -> BCTHW
x = x.permute(0, 2, 1, 3, 4)
return x
def calculate_fvd(videos1, videos2, device="cuda", method='styleganv'):
if method == 'styleganv':
from .fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
elif method == 'videogpt':
from .fvd.videogpt.fvd import load_i3d_pretrained
from .fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
from .fvd.videogpt.fvd import frechet_distance
# videos [batch_size, timestamps, channel, h, w]
assert videos1.shape == videos2.shape
i3d = load_i3d_pretrained(device=device)
fvd_results = []
# support grayscale input, if grayscale -> channel*3
# BTCHW -> BCTHW
# videos -> [batch_size, channel, timestamps, h, w]
videos1 = trans(videos1)
videos2 = trans(videos2)
# fvd_results = {}
# for calculate FVD, each clip_timestamp must >= 10
for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
# print("clip_timestamp", clip_timestamp)
# get a video clip
# videos_clip [batch_size, channel, timestamps[:clip], h, w]
videos_clip1 = videos1[:, :, : clip_timestamp]
videos_clip2 = videos2[:, :, : clip_timestamp]
# get FVD features
feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
# calculate FVD when timestamps[:clip]
fvd_results.append(frechet_distance(feats1, feats2))
return fvd_results[-1] # only the last step
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
device = torch.device("cuda")
# device = torch.device("cpu")
import json
result = calculate_fvd(videos1, videos2, device, method='videogpt')
print(json.dumps(result, indent=4))
result = calculate_fvd(videos1, videos2, device, method='styleganv')
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()