r"""CLIP-IQA metric, proposed by
Exploring CLIP for Assessing the Look and Feel of Images.
Jianyi Wang Kelvin C.K. Chan Chen Change Loy.
AAAI 2023.
Ref url: https://github.com/IceClear/CLIP-IQA
Re-implmented by: Chaofeng Chen (https://github.com/chaofengc) with the following modification:
- We assemble multiple prompts to improve the results of clipiqa model.
"""
import torch
import torch.nn as nn

import clip
from pyiqa.archs.arch_util import load_file_from_url, load_pretrained_network

from .constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from .clip_model import load

default_model_urls = {
    'clipiqa+': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/CLIP-IQA+_learned_prompts-603f3273.pth',
    'clipiqa+_rn50_512': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/CLIPIQA+_RN50_512-89f5d940.pth',
    'clipiqa+_vitL14_512': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/CLIPIQA+_ViTL14_512-e66488f2.pth',
}


class PromptLearner(nn.Module):
    """
    Disclaimer:
        This implementation follows exactly the official codes in: https://github.com/IceClear/CLIP-IQA.
        We have no idea why some of the tricks are implemented this way, which include:
            1. Using n_ctx prefix characters "X"
            2. Appending an extra "." at the end
            3. Inserting the original text embedding in the middle
    """
    def __init__(self, clip_model, n_ctx=16) -> None:
        super().__init__()

        # For the following codes about prompts, we follow the official codes to get the same results
        prompt_prefix = " ".join(["X"] * n_ctx) + ' '
        init_prompts = [prompt_prefix + 'Good photo..', prompt_prefix + 'Bad photo..']
        with torch.no_grad():
            txt_token = clip.tokenize(init_prompts)
            self.tokenized_prompts = txt_token
            init_embedding = clip_model.token_embedding(txt_token)

        init_ctx = init_embedding[:, 1: 1 + n_ctx]
        self.ctx = nn.Parameter(init_ctx)
        self.n_ctx = n_ctx
        self.n_cls = len(init_prompts)
        self.name_lens = [3, 3]  # hard coded length, which does not include the extra "." at the end

        self.register_buffer("token_prefix", init_embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", init_embedding[:, 1 + n_ctx:, :])  # CLS, EOS

    def get_prompts_with_middel_class(self):
        ctx = self.ctx.to(self.token_prefix)
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        half_n_ctx = self.n_ctx // 2
        prompts = []
        for i in range(self.n_cls):
            name_len = self.name_lens[i]
            prefix_i = self.token_prefix[i: i + 1, :, :]
            class_i = self.token_suffix[i: i + 1, :name_len, :]
            suffix_i = self.token_suffix[i: i + 1, name_len:, :]
            ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :]
            ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :]
            prompt = torch.cat(
                [
                    prefix_i,     # (1, 1, dim)
                    ctx_i_half1,  # (1, n_ctx//2, dim)
                    class_i,      # (1, name_len, dim)
                    ctx_i_half2,  # (1, n_ctx//2, dim)
                    suffix_i,     # (1, *, dim)
                ],
                dim=1,
            )
            prompts.append(prompt)

        prompts = torch.cat(prompts, dim=0)
        return prompts
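
    # Shape note (added for clarity): the returned tensor is (n_cls, 77, dim), one full
    # CLIP context window (77 tokens) per class prompt, since the pieces above are just a
    # re-ordering of the original token embeddings with the "X" placeholders swapped for
    # the learned context.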

    def forward(self, clip_model):
        prompts = self.get_prompts_with_middel_class()
        x = prompts + clip_model.positional_embedding.type(clip_model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = clip_model.ln_final(x).type(clip_model.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), self.tokenized_prompts.argmax(dim=-1)] @ clip_model.text_projection
        return x
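
    # Shape note (added for clarity): for the two prompts built in __init__ this returns
    # text features of shape (2, embed_dim), which are compared against image features
    # inside the clip model's forward when called from CLIPIQA.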


class CLIPIQA(nn.Module):
    def __init__(self,
                 model_type='clipiqa+_vitL14_512',
                 backbone='ViT-L/14',
                 pretrained=True,
                 pos_embedding=False,
                 ) -> None:
        super().__init__()

        self.clip_model = [load(backbone, 'cpu')]  # avoid saving clip weights

        # Different from the original paper, we assemble multiple prompts to improve performance
        self.prompt_pairs = clip.tokenize([
            'Good image', 'bad image',
            'Sharp image', 'blurry image',
            'sharp edges', 'blurry edges',
            'High resolution image', 'low resolution image',
            'Noise-free image', 'noisy image',
        ])
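
        # Note (added for clarity): the list above is ordered as (positive, negative)
        # antonym pairs. The commented-out lines in forward() show how a quality score
        # would be derived from them: reshape the image-text logits to (batch, n_pairs, 2),
        # softmax over each pair, and average the positive probabilities.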

        self.model_type = model_type
        self.pos_embedding = pos_embedding

        if 'clipiqa+' in model_type:
            self.prompt_learner = PromptLearner(self.clip_model[0])

        self.default_mean = torch.Tensor(OPENAI_CLIP_MEAN).view(1, 3, 1, 1)
        self.default_std = torch.Tensor(OPENAI_CLIP_STD).view(1, 3, 1, 1)

        for p in self.clip_model[0].parameters():
            p.requires_grad = False

        if pretrained and 'clipiqa+' in model_type:
            if model_type == 'clipiqa+' and backbone == 'RN50':
                self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+']))
            elif model_type in default_model_urls.keys():
                load_pretrained_network(self, default_model_urls[model_type], True, 'params')
            else:
                raise ValueError(f'No pretrained model for {model_type}')

    def forward(self, x, multi=False, layer=-1):
        # No need to preprocess the image here; it is assumed to be preprocessed already.
        # x = (x - self.default_mean.to(x)) / self.default_std.to(x)
        clip_model = self.clip_model[0].to(x)

        if self.model_type == 'clipiqa':
            prompts = self.prompt_pairs.to(x.device)
            logits_per_image, logits_per_text, image_feature, token_feature = clip_model(
                x, prompts, pos_embedding=self.pos_embedding)
        elif 'clipiqa+' in self.model_type:
            learned_prompt_feature = self.prompt_learner(clip_model)
            logits_per_image, logits_per_text, image_feature, token_feature = clip_model(
                x, None, text_features=learned_prompt_feature, pos_embedding=self.pos_embedding)

        # probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(dim=-1)
        # return probs[..., 0].mean(dim=1, keepdim=True), image_feature
        return image_feature, token_feature
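

# Minimal usage sketch (added as an illustration, not from the original file). It assumes
# the `clip` package and the pyiqa dependencies above are importable, that `load` can
# fetch the chosen CLIP backbone, and that the custom clip model's forward accepts the
# `pos_embedding` / `text_features` keywords exactly as used in CLIPIQA.forward.
if __name__ == '__main__':
    # Dummy RGB batch in [0, 1]; a real caller would pass an already-preprocessed image.
    dummy = torch.rand(1, 3, 224, 224)
    # 'clipiqa' uses the fixed antonym prompt pairs and needs no extra pretrained weights.
    model = CLIPIQA(model_type='clipiqa', backbone='RN50', pretrained=False)
    with torch.no_grad():
        image_feature, token_feature = model(dummy)
    print(image_feature.shape, token_feature.shape)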