# import os
# import cv2
# import glob
# import numpy as np
# from PIL import Image
# from scipy.linalg import sqrtm
# import torch
# from torch import nn
# import torchvision.transforms as transforms
#
#
# def PSNR(gt_imgs, pred_imgs):
#     """
#     Calculate PSNR for a batch of images
#     Args:
#         gt_imgs (list): list of ground truth images
#         pred_imgs (list): list of predicted images
#     Returns:
#         float: average PSNR score
#     """
#     total_psnr = 0
#     for idx, (gt, pred) in enumerate(zip(gt_imgs, pred_imgs)):
#         assert gt.shape == pred.shape, f"Shape mismatch at {idx}: GT and prediction"
#         total_psnr += cv2.PSNR(gt, pred)
#     return total_psnr / len(pred_imgs)
#
#
# class FrechetDistance:
#     def __init__(self, model_name="inception_v3", device="cpu"):
#         self.device = torch.device(device)
#         self.model = torch.hub.load("pytorch/vision:v0.10.0", model_name, pretrained=True)  # .to(self.device)
#         self.model.fc = nn.Identity()
#         print(self.model)
#         self.model.eval()
#         self.transform = transforms.Compose(
#             [
#                 transforms.ToTensor(),
#                 transforms.Resize(299),
#                 transforms.CenterCrop(299),
#                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#             ]
#         )
#
#     # Return parts to calculate in FID and FVD
#     def _calculate_act(self, images1, images2):
#         images1 = [self.transform(img) for img in images1]
#         images2 = [self.transform(img) for img in images2]
#         images1 = torch.stack(images1).to(self.device)
#         images2 = torch.stack(images2).to(self.device)
#
#         # Get activations
#         act1 = self.model(images1).detach().numpy()
#         act2 = self.model(images2).detach().numpy()
#         return act1, act2
#
#     def calculate_fid(self, images1, images2):
#         act1, act2 = self._calculate_act(images1, images2)
#
#         # calculate mean and covariance statistics
#         mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
#         mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
#         fid = (np.sum((mu1 - mu2) ** 2.0)) + np.trace(sigma1 + sigma2 - 2.0 * sqrtm(sigma1.dot(sigma2)))
#         return fid
#
#     def calculate_fvd(self, frames_list_folder1, frames_list_folder2, batch_size=2):
#         frames_list1 = glob.glob(os.path.join(frames_list_folder1, "*.png"))
#         frames_list2 = glob.glob(os.path.join(frames_list_folder2, "*.png"))
#         assert len(frames_list1) == len(frames_list2), "Number of frames in 2 folders must be equal"
#         all_act1, all_act2 = [], []
#         for i in range(0, len(frames_list1), batch_size):
#             batch1 = frames_list1[i : min(i + batch_size, len(frames_list1))]
#             batch2 = frames_list2[i : min(i + batch_size, len(frames_list1))]
#             img1 = [Image.open(img) for img in batch1]
#             img2 = [Image.open(img) for img in batch2]
#             act1, act2 = self._calculate_act(img1, img2)
#             all_act1.append(act1)
#             all_act2.append(act2)
#         all_act1 = np.concatenate(all_act1, axis=0)
#         all_act2 = np.concatenate(all_act2, axis=0)
#         print(all_act1.shape)
#         print(all_act1.shape)
#         fid = self.calculate_fid(all_act1, all_act2)
#         return np.sqrt(fid)