|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
from pytorch_fid.inception import InceptionV3 |
|
from pytorch_fid.fid_score import calculate_frechet_distance |
|
|
|
class FIDMetric:
    """Accumulate Inception activations for inputs and reconstructions, then compute FID.

    Call :meth:`update` once per batch pair, then :meth:`result` to obtain the
    Frechet Inception Distance between the two accumulated activation sets.
    :meth:`result` clears the accumulated state afterwards.
    """

    def __init__(self, device, dims=2048):
        """Build the Inception feature extractor.

        Args:
            device: torch device the Inception model runs on; incoming batches
                are moved here before the forward pass.
            dims (int): feature dimensionality; must be a key of
                ``InceptionV3.BLOCK_INDEX_BY_DIM``.
        """
        self.device = device
        # NOTE(review): not read anywhere in this class; kept so external
        # callers that may reference it keep working.
        self.num_workers = 32

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        self.model = InceptionV3([block_idx]).to(device)
        self.model.eval()

        self.reset_metrics()

    def reset_metrics(self):
        """Discard all accumulated activations."""
        self.x_pred = []
        self.x_rec_pred = []

    @torch.no_grad()
    def get_activates(self, x: torch.Tensor):
        """Return Inception activations for ``x`` as a ``(batch, dims)`` numpy array.

        Args:
            x (torch.Tensor): image batch. Presumably ``(N, 3, H, W)`` with
                values in [0, 1], as pytorch_fid's InceptionV3 expects — confirm
                against callers.
        """
        # Ensure the batch lives on the model's device (no-op if it already does).
        x = x.to(self.device)
        pred = self.model(x)[0]

        # If the selected block's output is not already spatially 1x1,
        # global-average-pool it down to one vector per sample.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))

        # Drop only the two trailing spatial axes. A bare ``squeeze()`` would
        # also remove the batch axis when N == 1, yielding a 1-D array that
        # breaks (or silently corrupts) the concatenation in ``result``.
        return pred.flatten(start_dim=1).cpu().numpy()

    def update(self, x: torch.Tensor, x_rec: torch.Tensor):
        """Record activations for one batch of inputs and their reconstructions.

        Args:
            x (torch.Tensor): input tensor range from 0 to 1
            x_rec (torch.Tensor): reconstructed tensor range from 0 to 1
        """
        self.x_pred.append(self.get_activates(x))
        self.x_rec_pred.append(self.get_activates(x_rec))

    def result(self):
        """Compute FID between all accumulated inputs and reconstructions.

        Fits a Gaussian (mean and covariance over samples) to each activation
        set and passes both to ``calculate_frechet_distance``. Accumulated
        state is reset afterwards.

        Returns:
            The value returned by ``calculate_frechet_distance``.

        Raises:
            AssertionError: if :meth:`update` was never called since the last
                reset (note: ``assert`` is skipped under ``python -O``).
        """
        assert len(self.x_pred) != 0, "No data to compute FID"
        # Stack per-batch (batch, dims) arrays into one (num_samples, dims) array.
        x = np.concatenate(self.x_pred, axis=0)
        x_rec = np.concatenate(self.x_rec_pred, axis=0)

        # rowvar=False: each row is one sample, each column one feature.
        x_mu = np.mean(x, axis=0)
        x_sigma = np.cov(x, rowvar=False)

        x_rec_mu = np.mean(x_rec, axis=0)
        x_rec_sigma = np.cov(x_rec, rowvar=False)

        fid_score = calculate_frechet_distance(x_mu, x_sigma, x_rec_mu, x_rec_sigma)
        self.reset_metrics()
        return fid_score