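"""Image and video quality metrics: batch PSNR (via OpenCV) and Frechet
distances (FID over Inception-v3 features, plus a frame-wise FVD variant)."""
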
import os
import cv2
import glob
import numpy as np
from PIL import Image
from scipy.linalg import sqrtm

import torch
from torch import nn
import torchvision.transforms as transforms


def PSNR(gt_imgs, pred_imgs):
    """
    Calculate the average PSNR over a batch of images.
    Args:
        gt_imgs (list): list of ground-truth images (uint8 arrays)
        pred_imgs (list): list of predicted images, same shapes as gt_imgs
    Returns:
        float: average PSNR score in dB
    """
    total_psnr = 0.0
    for idx, (gt, pred) in enumerate(zip(gt_imgs, pred_imgs)):
        assert gt.shape == pred.shape, f"Shape mismatch at index {idx}: {gt.shape} vs {pred.shape}"
        total_psnr += cv2.PSNR(gt, pred)
    return total_psnr / len(pred_imgs)
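
# Hypothetical usage sketch (the folder names are assumptions, not part of
# this module): load two aligned frame lists with OpenCV and compare.
#
#   gt_imgs = [cv2.imread(p) for p in sorted(glob.glob("gt_frames/*.png"))]
#   pred_imgs = [cv2.imread(p) for p in sorted(glob.glob("pred_frames/*.png"))]
#   print(f"Average PSNR: {PSNR(gt_imgs, pred_imgs):.2f} dB")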


class FrechetDistance:
    """Frechet distance over Inception-v3 features (FID, and a frame-wise FVD)."""

    def __init__(self, model_name="inception_v3", device="cpu"):
        self.device = torch.device(device)
        # Pretrained backbone from torchvision; replacing the final fc layer
        # with an identity makes the forward pass return pooled features.
        self.model = torch.hub.load("pytorch/vision:v0.10.0", model_name, pretrained=True).to(self.device)
        self.model.fc = nn.Identity()
        self.model.eval()

        # Standard ImageNet preprocessing at the 299x299 Inception input size
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(299),
                transforms.CenterCrop(299),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    # Feature activations for two image batches, shared by FID and FVD
    def _calculate_act(self, images1, images2):
        images1 = torch.stack([self.transform(img) for img in images1]).to(self.device)
        images2 = torch.stack([self.transform(img) for img in images2]).to(self.device)

        # Get activations without tracking gradients
        with torch.no_grad():
            act1 = self.model(images1).cpu().numpy()
            act2 = self.model(images2).cpu().numpy()

        return act1, act2

    @staticmethod
    def _fid_from_act(act1, act2):
        # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
        mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
        mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)

        # sqrtm may return a complex result from numerical error; keep the real part
        covmean = sqrtm(sigma1.dot(sigma2))
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        return np.sum((mu1 - mu2) ** 2.0) + np.trace(sigma1 + sigma2 - 2.0 * covmean)

    def calculate_fid(self, images1, images2):
        act1, act2 = self._calculate_act(images1, images2)
        return self._fid_from_act(act1, act2)

    def calculate_fvd(self, frames_list_folder1, frames_list_folder2, batch_size=2):
        # Sort so frames from the two folders stay aligned
        frames_list1 = sorted(glob.glob(os.path.join(frames_list_folder1, "*.png")))
        frames_list2 = sorted(glob.glob(os.path.join(frames_list_folder2, "*.png")))

        assert len(frames_list1) == len(frames_list2), "Both folders must contain the same number of frames"

        all_act1, all_act2 = [], []
        for i in range(0, len(frames_list1), batch_size):
            batch1 = frames_list1[i : i + batch_size]
            batch2 = frames_list2[i : i + batch_size]

            imgs1 = [Image.open(path).convert("RGB") for path in batch1]
            imgs2 = [Image.open(path).convert("RGB") for path in batch2]

            act1, act2 = self._calculate_act(imgs1, imgs2)
            all_act1.append(act1)
            all_act2.append(act2)

        all_act1 = np.concatenate(all_act1, axis=0)
        all_act2 = np.concatenate(all_act2, axis=0)

        # Frame-wise FVD: square root of the FID over all accumulated frames
        fid = self._fid_from_act(all_act1, all_act2)
        return np.sqrt(fid)
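

if __name__ == "__main__":
    # Minimal usage sketch; "gt_frames" and "pred_frames" are hypothetical
    # folders of aligned PNG frames, not part of the original module.
    fd = FrechetDistance(device="cuda" if torch.cuda.is_available() else "cpu")
    print(f"FVD: {fd.calculate_fvd('gt_frames', 'pred_frames', batch_size=8):.3f}")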