from PIL import Image
import numpy as np
import torch
from .features import build_feature_extractor
from .fid import get_batch_features, fid_from_feats
from .resize import build_resizer
""" | |
A helper class that allowing adding the images one batch at a time. | |
""" | |
class CleanFID:
    def __init__(self, mode="clean", model_name="inception_v3", device="cuda"):
        self.real_features = []
        self.gen_features = []
        self.mode = mode
        self.device = device
        if model_name == "inception_v3":
            self.feat_model = build_feature_extractor(mode, device)
            self.fn_resize = build_resizer(mode)
        elif model_name == "clip_vit_b_32":
            from .clip_features import CLIP_fx, img_preprocess_clip
            clip_fx = CLIP_fx("ViT-B/32")
            self.feat_model = clip_fx
            self.fn_resize = img_preprocess_clip
        else:
            raise ValueError(f"unsupported model_name: {model_name}")
""" | |
Funtion that takes an image (PIL.Image or np.array or torch.tensor) | |
and returns the corresponding feature embedding vector. | |
The image x is expected to be in range [0, 255] | |
""" | |
    def compute_features(self, x):
        # single PIL image: convert to a numpy array, resize, add a batch dimension
        if isinstance(x, Image.Image):
            x_np = np.array(x)
            x_np_resized = self.fn_resize(x_np)
            x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0)
            x_feat = get_batch_features(x_t, self.feat_model, self.device)
        # single H,W,C numpy array
        elif isinstance(x, np.ndarray):
            x_np_resized = self.fn_resize(x)
            x_t = (
                torch.tensor(x_np_resized.transpose((2, 0, 1)))
                .unsqueeze(0)
                .to(self.device)
            )
            # normalization happens inside self.feat_model; expected image range here is [0, 255]
            x_feat = get_batch_features(x_t, self.feat_model, self.device)
        # batched B,C,H,W (or single C,H,W) tensor
        elif isinstance(x, torch.Tensor):
            # add the batch dimension if x is passed in as C,H,W
            if len(x.shape) == 3:
                x = x.unsqueeze(0)
            b, c, h, w = x.shape
            # convert each image back to an H,W,C numpy array and resize
            l_x_np_resized = []
            for i in range(b):
                x_np = x[i].cpu().numpy().transpose((1, 2, 0))
                l_x_np_resized.append(self.fn_resize(x_np)[None,])
            x_np_resized = np.concatenate(l_x_np_resized)
            x_t = torch.tensor(x_np_resized.transpose((0, 3, 1, 2))).to(self.device)
            # normalization happens inside self.feat_model; expected image range here is [0, 255]
            x_feat = get_batch_features(x_t, self.feat_model, self.device)
        else:
            raise ValueError("image type could not be inferred")
        return x_feat
""" | |
Extract the faetures from x and add to the list of reference real images | |
""" | |
    def add_real_images(self, x):
        x_feat = self.compute_features(x)
        self.real_features.append(x_feat)
""" | |
Extract the faetures from x and add to the list of generated images | |
""" | |
    def add_gen_images(self, x):
        x_feat = self.compute_features(x)
        self.gen_features.append(x_feat)
""" | |
Compute FID between the real and generated images added so far | |
""" | |
def calculate_fid(self, verbose=True): | |
feats1 = np.concatenate(self.real_features) | |
feats2 = np.concatenate(self.gen_features) | |
if verbose: | |
print(f"# real images = {feats1.shape[0]}") | |
print(f"# generated images = {feats2.shape[0]}") | |
return fid_from_feats(feats1, feats2) | |
""" | |
Remove the real image features added so far | |
""" | |
def reset_real_features(self): | |
self.real_features = [] | |
""" | |
Remove the generated image features added so far | |
""" | |
def reset_gen_features(self): | |
self.gen_features = [] | |
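
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Assumes this
# file lives inside the cleanfid package so the relative imports above resolve,
# and that a CUDA device is available (pass device="cpu" otherwise). The random
# uint8 arrays below are stand-ins for real data; meaningful FID values require
# thousands of images per side.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    fid_tracker = CleanFID(mode="clean", model_name="inception_v3", device="cuda")
    for _ in range(4):
        # each call adds one image's features; images are H,W,C uint8 in [0, 255]
        real_img = np.random.randint(0, 256, (299, 299, 3), dtype=np.uint8)
        gen_img = np.random.randint(0, 256, (299, 299, 3), dtype=np.uint8)
        fid_tracker.add_real_images(real_img)
        fid_tracker.add_gen_images(gen_img)
    score = fid_tracker.calculate_fid()
    print(f"FID: {score:.2f}")
    # clear generated features to score a new model against the same real set
    fid_tracker.reset_gen_features()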