mcding
published version
ad552d8
from PIL import Image
import numpy as np
import torch
from .features import build_feature_extractor, get_reference_statistics
from .fid import get_batch_features, fid_from_feats
from .resize import build_resizer
"""
A helper class that allowing adding the images one batch at a time.
"""
class CleanFID:
def __init__(self, mode="clean", model_name="inception_v3", device="cuda"):
self.real_features = []
self.gen_features = []
self.mode = mode
self.device = device
if model_name == "inception_v3":
self.feat_model = build_feature_extractor(mode, device)
self.fn_resize = build_resizer(mode)
elif model_name == "clip_vit_b_32":
from .clip_features import CLIP_fx, img_preprocess_clip
clip_fx = CLIP_fx("ViT-B/32")
self.feat_model = clip_fx
self.fn_resize = img_preprocess_clip
"""
Funtion that takes an image (PIL.Image or np.array or torch.tensor)
and returns the corresponding feature embedding vector.
The image x is expected to be in range [0, 255]
"""
def compute_features(self, x):
# if x is a PIL Image
if isinstance(x, Image.Image):
x_np = np.array(x)
x_np_resized = self.fn_resize(x_np)
x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0)
x_feat = get_batch_features(x_t, self.feat_model, self.device)
elif isinstance(x, np.ndarray):
x_np_resized = self.fn_resize(x)
x_t = (
torch.tensor(x_np_resized.transpose((2, 0, 1)))
.unsqueeze(0)
.to(self.device)
)
# normalization happens inside the self.feat_model, expected image range here is [0,255]
x_feat = get_batch_features(x_t, self.feat_model, self.device)
elif isinstance(x, torch.Tensor):
# pdb.set_trace()
# add the batch dimension if x is passed in as C,H,W
if len(x.shape) == 3:
x = x.unsqueeze(0)
b, c, h, w = x.shape
# convert back to np array and resize
l_x_np_resized = []
for _ in range(b):
x_np = x[_].cpu().numpy().transpose((1, 2, 0))
l_x_np_resized.append(self.fn_resize(x_np)[None,])
x_np_resized = np.concatenate(l_x_np_resized)
x_t = torch.tensor(x_np_resized.transpose((0, 3, 1, 2))).to(self.device)
# normalization happens inside the self.feat_model, expected image range here is [0,255]
x_feat = get_batch_features(x_t, self.feat_model, self.device)
else:
raise ValueError("image type could not be inferred")
return x_feat
"""
Extract the faetures from x and add to the list of reference real images
"""
def add_real_images(self, x):
x_feat = self.compute_features(x)
self.real_features.append(x_feat)
"""
Extract the faetures from x and add to the list of generated images
"""
def add_gen_images(self, x):
x_feat = self.compute_features(x)
self.gen_features.append(x_feat)
"""
Compute FID between the real and generated images added so far
"""
def calculate_fid(self, verbose=True):
feats1 = np.concatenate(self.real_features)
feats2 = np.concatenate(self.gen_features)
if verbose:
print(f"# real images = {feats1.shape[0]}")
print(f"# generated images = {feats2.shape[0]}")
return fid_from_feats(feats1, feats2)
"""
Remove the real image features added so far
"""
def reset_real_features(self):
self.real_features = []
"""
Remove the generated image features added so far
"""
def reset_gen_features(self):
self.gen_features = []