# Code adapted from https://github.com/JunyaoHu/common_metrics_on_video_quality
import numpy as np
import torch
from tqdm import tqdm

# FVD is only computed for clips with at least this many timestamps.
_MIN_FVD_FRAMES = 10

# Values accepted for the ``method`` argument of ``calculate_fvd``.
_FVD_METHODS = ("styleganv", "videogpt")


def trans(x):
    """Rearrange a batch of videos from BTCHW to BCTHW.

    If the channel axis has size 1 (greyscale), it is repeated three times
    before the axes are permuted.
    """
    # if greyscale images add channel
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)

    # permute BTCHW -> BCTHW
    return x.permute(0, 2, 1, 3, 4)


def calculate_fvd(videos1, videos2, device="cuda", method='styleganv'):
    """Compute the FVD between two equally shaped batches of videos.

    Args:
        videos1: tensor laid out as [batch_size, timestamps, channel, h, w].
        videos2: tensor with exactly the same shape as ``videos1``.
        device: passed to the I3D loader and to the feature extractor.
        method: ``'styleganv'`` or ``'videogpt'``; selects which
            implementation under ``.fvd`` provides the I3D loader, the
            feature extractor and the Frechet distance.

    Returns:
        The value returned by the selected ``frechet_distance`` for the
        features of the full-length videos.

    Raises:
        ValueError: if ``method`` is not supported, if the two inputs differ
            in shape, or if the videos have fewer than 10 timestamps.
    """
    # Validate everything cheap first so bad input fails before the
    # pretrained I3D network is loaded.
    if method not in _FVD_METHODS:
        raise ValueError(
            f"unknown FVD method {method!r}; expected one of {_FVD_METHODS}"
        )

    # videos [batch_size, timestamps, channel, h, w]
    if videos1.shape != videos2.shape:
        raise ValueError(
            "videos1 and videos2 must have the same shape, got "
            f"{tuple(videos1.shape)} and {tuple(videos2.shape)}"
        )

    # support grayscale input, if grayscale -> channel*3
    # BTCHW -> BCTHW
    # videos -> [batch_size, channel, timestamps, h, w]
    videos1 = trans(videos1)
    videos2 = trans(videos2)

    # for calculate FVD, each clip_timestamp must >= 10
    num_frames = videos1.shape[-3]
    if num_frames < _MIN_FVD_FRAMES:
        raise ValueError(
            f"FVD needs at least {_MIN_FVD_FRAMES} timestamps per video, "
            f"got {num_frames}"
        )

    if method == 'styleganv':
        from .fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
    else:  # method == 'videogpt'
        from .fvd.videogpt.fvd import load_i3d_pretrained
        from .fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
        from .fvd.videogpt.fvd import frechet_distance

    i3d = load_i3d_pretrained(device=device)

    # Only the score of the full-length clip is returned, so the features
    # are extracted once on the whole videos instead of once per prefix
    # length.
    feats1 = get_fvd_feats(videos1, i3d=i3d, device=device)
    feats2 = get_fvd_feats(videos2, i3d=i3d, device=device)

    return frechet_distance(feats1, feats2)


# test code / using example
def main():
    """Run both FVD variants on an all-zeros and an all-ones batch.

    ``calculate_fvd`` uses package-relative imports, so this module has to be
    executed as part of its package (``python -m <package>.<module>``).
    """
    import json

    number_of_videos = 8
    video_length = 50
    channel = 3
    size = 64

    videos1 = torch.zeros(number_of_videos, video_length, channel, size, size, requires_grad=False)
    videos2 = torch.ones(number_of_videos, video_length, channel, size, size, requires_grad=False)

    # Prefer the GPU but keep the example runnable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    result = calculate_fvd(videos1, videos2, device, method='videogpt')
    print(json.dumps(result, indent=4))

    result = calculate_fvd(videos1, videos2, device, method='styleganv')
    print(json.dumps(result, indent=4))


if __name__ == "__main__":
    main()