SwinTExCo / src /metrics.py
duongttr's picture
Upload folder using huggingface_hub
62ef5f4
raw
history blame
3.44 kB
# import os
# import cv2
# import glob
# import numpy as np
# from PIL import Image
# from scipy.linalg import sqrtm
# import torch
# from torch import nn
# import torchvision.transforms as transforms
# def PSNR(gt_imgs, pred_imgs):
# """
# Calculate PSNR for a batch of images
# Args:
# gt_imgs (list): list of ground truth images
# pred_imgs (list): list of predicted images
# Returns:
# float: average PSNR score
# """
# total_psnr = 0
# for idx, (gt, pred) in enumerate(zip(gt_imgs, pred_imgs)):
# assert gt.shape == pred.shape, f"Shape mismatch at {idx}: GT and prediction"
# total_psnr += cv2.PSNR(gt, pred)
# return total_psnr / len(pred_imgs)
# class FrechetDistance:
# def __init__(self, model_name="inception_v3", device="cpu"):
# self.device = torch.device(device)
# self.model = torch.hub.load("pytorch/vision:v0.10.0", model_name, pretrained=True) # .to(self.device)
# self.model.fc = nn.Identity()
# print(self.model)
# self.model.eval()
# self.transform = transforms.Compose(
# [
# transforms.ToTensor(),
# transforms.Resize(299),
# transforms.CenterCrop(299),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ]
# )
# # Return the model activations for both image sets; these are the inputs to the FID and FVD calculations
# def _calculate_act(self, images1, images2):
# images1 = [self.transform(img) for img in images1]
# images2 = [self.transform(img) for img in images2]
# images1 = torch.stack(images1).to(self.device)
# images2 = torch.stack(images2).to(self.device)
# # Get activations
# act1 = self.model(images1).detach().numpy()
# act2 = self.model(images2).detach().numpy()
# return act1, act2
# def calculate_fid(self, images1, images2):
# act1, act2 = self._calculate_act(images1, images2)
# # calculate mean and covariance statistics
# mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
# mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
# fid = (np.sum((mu1 - mu2) ** 2.0)) + np.trace(sigma1 + sigma2 - 2.0 * sqrtm(sigma1.dot(sigma2)))
# return fid
# def calculate_fvd(self, frames_list_folder1, frames_list_folder2, batch_size=2):
# frames_list1 = glob.glob(os.path.join(frames_list_folder1, "*.png"))
# frames_list2 = glob.glob(os.path.join(frames_list_folder2, "*.png"))
# assert len(frames_list1) == len(frames_list2), "Number of frames in 2 folders must be equal"
# all_act1, all_act2 = [], []
# for i in range(0, len(frames_list1), batch_size):
# batch1 = frames_list1[i : min(i + batch_size, len(frames_list1))]
# batch2 = frames_list2[i : min(i + batch_size, len(frames_list1))]
# img1 = [Image.open(img) for img in batch1]
# img2 = [Image.open(img) for img in batch2]
# act1, act2 = self._calculate_act(img1, img2)
# all_act1.append(act1)
# all_act2.append(act2)
# all_act1 = np.concatenate(all_act1, axis=0)
# all_act2 = np.concatenate(all_act2, axis=0)
# print(all_act1.shape)
# print(all_act2.shape)
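# # NOTE(review): calculate_fid() passes its arguments through _calculate_act(), which applies the image transform and the model; all_act1/all_act2 are already activations here, so they would be processed a second time - confirm before re-enabling this code.
# fid = self.calculate_fid(all_act1, all_act2)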
# return np.sqrt(fid)