# Spaces:
# Sleeping
# Sleeping
# NOTE(review): the three lines above look like page-scrape residue (e.g. a
# "Spaces: Sleeping" status banner) — commented out so the file parses; confirm
# they carry no meaning and then delete.
# import os
# import cv2
# import glob
# import numpy as np
# from PIL import Image
# from scipy.linalg import sqrtm
# import torch
# from torch import nn
# import torchvision.transforms as transforms
# def PSNR(gt_imgs, pred_imgs):
#     """
#     Calculate PSNR for a batch of images
#     Args:
#         gt_imgs (list): list of ground truth images
#         pred_imgs (list): list of predicted images
#     Returns:
#         float: average PSNR score
#     """
#     total_psnr = 0
#     for idx, (gt, pred) in enumerate(zip(gt_imgs, pred_imgs)):
#         assert gt.shape == pred.shape, f"Shape mismatch at {idx}: GT and prediction"
#         total_psnr += cv2.PSNR(gt, pred)
#     return total_psnr / len(pred_imgs)
# class FrechetDistance:
#     def __init__(self, model_name="inception_v3", device="cpu"):
#         self.device = torch.device(device)
#         # NOTE(review): the model is never moved to self.device (the
#         # `.to(self.device)` below is commented out) even though the inputs
#         # are — this would break for any device other than "cpu" if revived.
#         self.model = torch.hub.load("pytorch/vision:v0.10.0", model_name, pretrained=True)  # .to(self.device)
#         self.model.fc = nn.Identity()
#         print(self.model)
#         self.model.eval()
#         self.transform = transforms.Compose(
#             [
#                 transforms.ToTensor(),
#                 transforms.Resize(299),
#                 transforms.CenterCrop(299),
#                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#             ]
#         )
#
#     # Return parts to calculate in FID and FVD
#     def _calculate_act(self, images1, images2):
#         images1 = [self.transform(img) for img in images1]
#         images2 = [self.transform(img) for img in images2]
#         images1 = torch.stack(images1).to(self.device)
#         images2 = torch.stack(images2).to(self.device)
#         # Get activations
#         # NOTE(review): `.detach().numpy()` needs an intervening `.cpu()` to
#         # work on non-CPU devices; also consider wrapping in torch.no_grad().
#         act1 = self.model(images1).detach().numpy()
#         act2 = self.model(images2).detach().numpy()
#         return act1, act2
#
#     def calculate_fid(self, images1, images2):
#         act1, act2 = self._calculate_act(images1, images2)
#         # calculate mean and covariance statistics
#         mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
#         mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
#         # NOTE(review): sqrtm can return a complex array from numerical
#         # noise — its real part is usually taken before the trace.
#         fid = (np.sum((mu1 - mu2) ** 2.0)) + np.trace(sigma1 + sigma2 - 2.0 * sqrtm(sigma1.dot(sigma2)))
#         return fid
#
#     def calculate_fvd(self, frames_list_folder1, frames_list_folder2, batch_size=2):
#         frames_list1 = glob.glob(os.path.join(frames_list_folder1, "*.png"))
#         frames_list2 = glob.glob(os.path.join(frames_list_folder2, "*.png"))
#         assert len(frames_list1) == len(frames_list2), "Number of frames in 2 folders must be equal"
#         all_act1, all_act2 = [], []
#         for i in range(0, len(frames_list1), batch_size):
#             batch1 = frames_list1[i : min(i + batch_size, len(frames_list1))]
#             batch2 = frames_list2[i : min(i + batch_size, len(frames_list1))]
#             img1 = [Image.open(img) for img in batch1]
#             img2 = [Image.open(img) for img in batch2]
#             act1, act2 = self._calculate_act(img1, img2)
#             all_act1.append(act1)
#             all_act2.append(act2)
#         all_act1 = np.concatenate(all_act1, axis=0)
#         all_act2 = np.concatenate(all_act2, axis=0)
#         print(all_act1.shape)
#         print(all_act1.shape)
#         # NOTE(review): BUG if revived — all_act1/all_act2 are already
#         # activations, but calculate_fid treats its arguments as raw images
#         # and runs them through _calculate_act/self.transform again. The
#         # mean/covariance step should be applied to these activations
#         # directly instead of calling calculate_fid.
#         fid = self.calculate_fid(all_act1, all_act2)
#         return np.sqrt(fid)