Spaces:
Runtime error
Runtime error
# pip install git+https://github.com/openai/CLIP.git | |
import pdb | |
from PIL import Image | |
import numpy as np | |
import torch | |
import torchvision.transforms as transforms | |
import clip | |
from .fid import compute_fid | |
def img_preprocess_clip(img_np):
    """Resize and center-crop an image array for the CLIP feature extractor.

    The array is cast to uint8, turned into an RGB PIL image, resized with
    bicubic interpolation (target size 224), and center-cropped to 224x224.

    Args:
        img_np: NumPy array holding the image pixels. Presumably laid out as
            (H, W, C) or (H, W) with values in [0, 255] -- confirm against
            the caller.

    Returns:
        A uint8 NumPy array holding the cropped RGB image.
    """
    pil_image = Image.fromarray(img_np.astype(np.uint8)).convert("RGB")

    # Apply the two geometric steps one after the other: resize, then crop.
    resize = transforms.Resize(
        224, interpolation=transforms.InterpolationMode.BICUBIC
    )
    center_crop = transforms.CenterCrop(224)
    cropped = center_crop(resize(pil_image))

    pixels = np.asarray(cropped)
    return pixels.clip(0, 255).astype(np.uint8)
class CLIP_fx:
    """Callable feature extractor backed by a CLIP image encoder.

    Attributes:
        model: The CLIP model returned by ``clip.load``, put in eval mode.
        name: Identifier derived from the model name, e.g. ``"ViT-B/32"``
            becomes ``"clip_vit_b_32"``.
    """

    # Per-channel mean / std applied after the input is scaled by 1/255.
    _MEAN = (0.48145466, 0.4578275, 0.40821073)
    _STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, name="ViT-B/32", device="cuda"):
        """Load the CLIP model ``name`` onto ``device``.

        Args:
            name: Model name passed to ``clip.load``.
            device: Device passed to ``clip.load``.
        """
        self.model, _ = clip.load(name, device=device)
        self.model.eval()
        self.name = "clip_" + name.lower().replace("-", "_").replace("/", "_")
        # Built once here instead of on every call; the statistics are fixed.
        self._normalize = transforms.Normalize(self._MEAN, self._STD)

    def __call__(self, img_t):
        """Encode an image tensor with the CLIP image encoder.

        Args:
            img_t: A torch tensor with 3 or 4 dimensions. Presumably
                channel-first with pixel values in [0, 255] -- confirm
                against the caller. A 3-D input gets a leading batch
                dimension added. The tensor is not moved to another device
                here.

        Returns:
            The result of ``self.model.encode_image`` on the normalized
            batch, computed with gradients disabled.

        Raises:
            TypeError: If ``img_t`` is not a torch tensor.
            ValueError: If the input is not 3-D or 4-D.
        """
        # Validate before doing any arithmetic, so the failure is explicit
        # and still fires under ``python -O`` (unlike an assert).
        if not torch.is_tensor(img_t):
            raise TypeError(
                f"img_t must be a torch.Tensor, got {type(img_t).__name__}"
            )

        img_x = self._normalize(img_t / 255.0)
        if img_x.ndim == 3:
            img_x = img_x.unsqueeze(0)
        if img_x.ndim != 4:
            raise ValueError(
                "expected a 3-D or 4-D image tensor, got shape "
                f"{tuple(img_x.shape)}"
            )

        with torch.no_grad():
            return self.model.encode_image(img_x)