from pathlib import Path

import clip
import torch
from torch import Tensor
from torchmetrics.utilities import rank_zero_info
from tqdm import tqdm
from PIL import Image
def read_image(imgid):
    """Resolve imgid either as a direct path or relative to data_en/images/."""
    vanilla = Path(imgid)
    fixed = Path(f"data_en/images/{imgid}")
    # Exactly one of the two candidate paths must exist; both existing or
    # both missing is an error.
    assert vanilla.exists() != fixed.exists()
    path = vanilla if vanilla.exists() else fixed
    return Image.open(path).convert("RGB")
class MID:
    def __init__(self, device="cuda"):
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=device)
        self.device = device

    def batchify(self, targets, batch_size):
        return [targets[i:i + batch_size] for i in range(0, len(targets), batch_size)]
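    # batchify example: batchify(list(range(5)), 2) -> [[0, 1], [2, 3], [4]];
    # the final batch may be shorter than batch_size.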
    def __call__(self, mt_list, refs_list, img_list, no_ref=False):
        B = 32
        mt_list, refs_list, img_list = [self.batchify(x, B) for x in [mt_list, refs_list, img_list]]
        assert len(mt_list) == len(refs_list) == len(img_list)
        img_feats, mt_feats, ref_feats = [], [], []
        for mt, refs, imgs in (pbar := tqdm(zip(mt_list, refs_list, img_list), total=len(mt_list))):
            pbar.set_description("MID")
            imgs = [read_image(imgid) for imgid in imgs]
            refs_token = []
            for ref_list in refs:
                refs_token.append([clip.tokenize(ref, truncate=True).to(self.device) for ref in ref_list])
            refs = torch.cat([torch.cat(ref, dim=0) for ref in refs_token], dim=0)
            mts = clip.tokenize(mt, truncate=True).to(self.device)
            imgs = torch.cat([self.clip_preprocess(img).unsqueeze(0) for img in imgs], dim=0).to(self.device)
            with torch.no_grad():
                img_feats.append(self.clip.encode_image(imgs))
                mt_feats.append(self.clip.encode_text(mts))
                ref_feats.append(self.clip.encode_text(refs))
        # compute_pmi estimates covariances over the whole sample sets (it asserts
        # N >= the feature dimension) and requires double precision, hence the
        # accumulation across batches and the cast.
        imgs = torch.cat(img_feats, dim=0).double()
        mts = torch.cat(mt_feats, dim=0).double()
        refs = torch.cat(ref_feats, dim=0).double()
        return compute_pmi(imgs, refs, mts)
def log_det(X):
    # Log-determinant via singular values, which is numerically stabler than
    # det() for ill-conditioned covariance matrices.
    eigenvalues = X.svd()[1]
    return eigenvalues.log().sum()


def robust_inv(x, eps=0):
    # Inverse with optional Tikhonov regularization: (x + eps * I)^-1.
    Id = torch.eye(x.shape[0]).to(x.device)
    return (x + eps * Id).inverse()


def exp_smd(a, b, reduction=True):
    # Squared Mahalanobis distance terms: tr(a^-1 b) when b is a covariance
    # matrix (reduction), else per-sample diag(b a^-1 b^T) for centered samples b.
    a_inv = robust_inv(a)
    if reduction:
        assert b.shape[0] == b.shape[1]
        return (a_inv @ b).trace()
    else:
        return (b @ a_inv @ b.t()).diag()
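
# A sketch of the math compute_pmi below estimates, assuming jointly Gaussian
# CLIP features (the notation here is ours, not from the original code): for
# image features x, text features y, and their concatenation z = [x, y],
#
#   MI(X; Y) = (log det(Sigma_x) + log det(Sigma_y) - log det(Sigma_z)) / 2,
#
# and the correction for generated samples adds expected squared Mahalanobis
# distances of the fake features under each covariance, which is exactly what
# the log_det and exp_smd helpers above evaluate.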
def compute_pmi(x: Tensor, y: Tensor, x0: Tensor, limit: int = 30000,
                reduction: bool = True, full: bool = False) -> Tensor:
    r"""
    A numerically stable version of the MID score.

    Args:
        x (Tensor): features of real samples
        y (Tensor): features of text samples
        x0 (Tensor): features of fake samples
        limit (int): limit on the number of samples
        reduction (bool): return the expectation of PMI if True, else sample-wise results
        full (bool): use the full set of samples from real images

    Returns:
        Scalar value of the mutual information divergence between the sets.
    """
    N = x.shape[0]
    excess = N - limit
    if 0 < excess:
        # Keep all real-image samples when full=True; always cap text and fakes.
        if not full:
            x = x[:-excess]
        y = y[:-excess]
        x0 = x0[:-excess]
        N = x.shape[0]
    M = x0.shape[0]
    assert N >= x.shape[1], "not full rank for matrix inversion!"
    if x.shape[0] < 30000:
        rank_zero_info("if it underperforms, please consider using "
                       "an epsilon of 5e-4 or so.")
    z = torch.cat([x, y], dim=-1)
    z0 = torch.cat([x0, y[:x0.shape[0]]], dim=-1)
    x_mean = x.mean(dim=0, keepdim=True)
    y_mean = y.mean(dim=0, keepdim=True)
    z_mean = torch.cat([x_mean, y_mean], dim=-1)
    x0_mean = x0.mean(dim=0, keepdim=True)
    z0_mean = z0.mean(dim=0, keepdim=True)
    X = (x - x_mean).t() @ (x - x_mean) / (N - 1)
    Y = (y - y_mean).t() @ (y - y_mean) / (N - 1)
    Z = (z - z_mean).t() @ (z - z_mean) / (N - 1)
    X0 = (x0 - x_mean).t() @ (x0 - x_mean) / (M - 1)  # use the reference mean
    Z0 = (z0 - z_mean).t() @ (z0 - z_mean) / (M - 1)  # use the reference mean
    alternative_comp = False
    # Note: this factorized form can be numerically unstable, so it is unused.
    if alternative_comp:
        def factorized_cov(x, m):
            N = x.shape[0]
            return (x.t() @ x - N * m.t() @ m) / (N - 1)
        X0 = factorized_cov(x0, x_mean)
        Z0 = factorized_cov(z0, z_mean)

    # assert double precision
    for _ in [X, Y, Z, X0, Z0]:
        assert _.dtype == torch.float64
    # Expectation of PMI
    mi = (log_det(X) + log_det(Y) - log_det(Z)) / 2
    rank_zero_info(f"MI of real images: {mi:.4f}")

    # Squared Mahalanobis distance terms
    if reduction:
        smd = (exp_smd(X, X0) + exp_smd(Y, Y) - exp_smd(Z, Z0)) / 2
    else:
        smd = (exp_smd(X, x0 - x_mean, False) + exp_smd(Y, y - y_mean, False)
               - exp_smd(Z, z0 - z_mean, False)) / 2
        mi = mi.unsqueeze(0)  # for broadcasting against per-sample distances
    return mi + smd
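
# Minimal self-contained smoke test of compute_pmi on random double-precision
# features; the shapes are illustrative assumptions, not values from the
# original code. Real use feeds CLIP features through MID.__call__, and the
# sample count must exceed the concatenated feature dimension (here 1024) for
# the joint covariance to be full rank.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(3000, 512, dtype=torch.float64)   # "real" image features
    y = torch.randn(3000, 512, dtype=torch.float64)   # text features
    x0 = torch.randn(3000, 512, dtype=torch.float64)  # "fake" features
    print(compute_pmi(x, y, x0))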