refine promptiqa.py
- PromptIQA/models/gc_loss.py +0 -99
- PromptIQA/models/monet_IPF.py +0 -397
- PromptIQA/models/monet_test.py +0 -389
- PromptIQA/models/monet_wo_prompt.py +0 -392
- PromptIQA/models/{monet.py → promptiqa.py} +2 -84
- PromptIQA/models/vit_base.py +0 -402
- PromptIQA/models/vit_large.py +0 -405
- PromptIQA/run_promptIQA copy.py +0 -109
- PromptIQA/run_promptIQA.py +2 -2
- PromptIQA/t.py +0 -2
- PromptIQA/test.py +0 -429
- PromptIQA/test.sh +0 -9
- best_model.pth.tar +3 -0
- get_examplt.py +0 -27
PromptIQA/models/gc_loss.py
DELETED
@@ -1,99 +0,0 @@
import torch.nn as nn
import torch
import numpy as np


class GC_Loss(nn.Module):
    def __init__(self, queue_len=800, alpha=0.5, beta=0.5, gamma=1):
        super(GC_Loss, self).__init__()
        self.pred_queue = list()
        self.gt_queue = list()
        self.queue_len = 0

        self.queue_max_len = queue_len
        print('The queue length is: ', self.queue_max_len)
        self.mse = torch.nn.MSELoss().cuda()

        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    def consistency(self, pred_data, gt_data):
        pred_one_batch, pred_queue = pred_data
        gt_one_batch, gt_queue = gt_data

        pred_mean = torch.mean(pred_queue)
        gt_mean = torch.mean(gt_queue)

        diff_pred = pred_one_batch - pred_mean
        diff_gt = gt_one_batch - gt_mean

        x1 = torch.sum(torch.mul(diff_pred, diff_gt))
        x2_1 = torch.sqrt(torch.sum(torch.mul(diff_pred, diff_pred)))
        x2_2 = torch.sqrt(torch.sum(torch.mul(diff_gt, diff_gt)))

        return x1 / (x2_1 * x2_2)

    def ppra(self, x):
        """
        Pairwise Preference-based Rank Approximation
        """
        x_bar, x_std = torch.mean(x), torch.std(x)
        x_n = (x - x_bar) / x_std
        x_n_T = x_n.reshape(-1, 1)

        rank_x = x_n_T - x_n_T.transpose(1, 0)
        rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x / torch.sqrt(torch.tensor(2, dtype=torch.float)))), dim=1)

        return rank_x

    @torch.no_grad()
    def enqueue(self, pred, gt):
        bs = pred.shape[0]
        self.queue_len = self.queue_len + bs

        self.pred_queue = self.pred_queue + pred.tolist()
        self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()

        if self.queue_len > self.queue_max_len:
            self.dequeue(self.queue_len - self.queue_max_len)
            self.queue_len = self.queue_max_len

    @torch.no_grad()
    def dequeue(self, n):
        for _ in range(n):
            self.pred_queue.pop(0)
            self.gt_queue.pop(0)

    def clear(self):
        self.pred_queue.clear()
        self.gt_queue.clear()

    def forward(self, x, y):
        x_queue = self.pred_queue.copy()
        y_queue = self.gt_queue.copy()

        x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
        y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)

        PLCC = self.consistency((x, x_all), (y, y_all))
        PGC = 1 - PLCC

        rank_x = self.ppra(x_all)
        rank_y = self.ppra(y_all)
        SROCC = self.consistency((rank_x[:x.shape[0]], rank_x), (rank_y[:y.shape[0]], rank_y))
        SGC = 1 - SROCC

        GC = (self.alpha * PGC + self.beta * SGC + self.gamma) * self.mse(x, y)
        self.enqueue(x, y)

        return GC


if __name__ == '__main__':
    gc = GC_Loss().cuda()
    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float).cuda()
    y = torch.tensor([6, 7, 8, 9, 15], dtype=torch.float).cuda()

    res = gc(x, y)

    print(res)
PromptIQA/models/monet_IPF.py
DELETED
@@ -1,397 +0,0 @@
"""
The completion for Mean-opinion Network(MoNet)
"""
import torch
import torch.nn as nn
import timm

from timm.models.vision_transformer import Block
from einops import rearrange
from itertools import combinations

from tqdm import tqdm


class Attention_Block(nn.Module):
    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class Self_Attention(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
        out = out.view(bs, C, w, h)

        out = self.gamma * out + inFeature

        return out


class MAL(nn.Module):
    """
    Multi-view Attention Learning (MAL) module
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()

        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention

        # Self attention module for each input feature
        self.attention_module = nn.ModuleList()
        for _ in range(feature_num):
            self.attention_module.append(Self_Attention(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim

    def forward(self, features):
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
        features = feature

        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
        bs, _, _ = input_tensor.shape  # [2, 3072, 784]

        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
                               c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768

        in_channel = input_tensor.permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768 * feature_num
        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28

        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2  # [2, 3072, 784]

        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)

        return weight_sum_res  # bs, 768, 28 * 28


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


from functools import partial


class MoNet(nn.Module):
    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(3):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_mal = MAL(feature_num=3)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )

        self.i_p_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )
        self.mlp = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 128),
        )

        self.prompt_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )

        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
                attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
            for i in range(8)])
        self.norm = nn.LayerNorm(128)

        self.score_block = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

        self.prompt_feature = {}

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = {}

    @torch.no_grad()
    def inference(self, x, data_type):
        prompt_feature = self.prompt_feature[data_type]  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
        prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128

        fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1))[:, 0, :]  # bs, 2, 1
        # fusion = self.norm(fusion)[:, 0, :]
        # fusion = self.score_block(fusion)

        # # iq_res = torch.mean(fusion, dim=1).view(-1)
        # iq_res = fusion[:, 0].view(-1)

        return fusion

    @torch.no_grad()
    def check_prompt(self, data_type):
        return data_type in self.prompt_feature

    @torch.no_grad()
    def forward_prompt(self, x, score, data_type):
        if data_type in self.prompt_feature:
            return
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-dimensional feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)

        # Fuse img_feature and score_feature into funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128

        # print('Load Prompt For Testing.', funsion_feature.shape)
        # self.prompt_feature = funsion_feature.clone()
        self.prompt_feature[data_type] = funsion_feature.clone()

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-dimensional feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)  # bs, 128

        # Fuse img_feature and score_feature into funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
        funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128

        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1))  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')

    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    def expand(self, A):
        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)

        B = None
        for index, i in enumerate(A_expanded):
            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
            if B is None:
                B = rmv
            else:
                B = torch.cat((B, rmv), dim=0)

        return B


if __name__ == '__main__':
    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda()
    model = MoNet().cuda()

    iq_res, gt_res = model(in_feature, gt_feature)

    print(iq_res.shape)
    print(gt_res.shape)
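
A detail of this deleted variant that is easy to misread is MoNet.expand, which builds leave-one-out prompts: for each image in a batch it keeps the fused (image, score) features of every other image, giving a (bs, bs - 1, 128) tensor. A minimal standalone sketch of the same operation on a toy tensor (illustration only, not code from this repository):

import torch

A = torch.arange(6, dtype=torch.float).reshape(3, 2)   # toy batch: bs = 3, dim = 2 (128 in the model)
A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)  # bs, bs, dim
# drop row i from the i-th copy, so each sample only sees the other samples
B = torch.stack([torch.cat((row[:i], row[i + 1:])) for i, row in enumerate(A_expanded)])
print(B.shape)  # torch.Size([3, 2, 2]) -> bs, bs - 1, dim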
PromptIQA/models/monet_test.py
DELETED
@@ -1,389 +0,0 @@
"""
The completion for Mean-opinion Network(MoNet)
"""
import torch
import torch.nn as nn
import timm

from timm.models.vision_transformer import Block
from einops import rearrange
from itertools import combinations

from tqdm import tqdm


class Attention_Block(nn.Module):
    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class Self_Attention(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
        out = out.view(bs, C, w, h)

        out = self.gamma * out + inFeature

        return out


class MAL(nn.Module):
    """
    Multi-view Attention Learning (MAL) module
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()

        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention

        # Self attention module for each input feature
        self.attention_module = nn.ModuleList()
        for _ in range(feature_num):
            self.attention_module.append(Self_Attention(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim

    def forward(self, features):
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
        features = feature

        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
        bs, _, _ = input_tensor.shape  # [2, 3072, 784]

        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
                               c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768

        in_channel = input_tensor.permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768 * feature_num
        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28

        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2  # [2, 3072, 784]

        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)

        return weight_sum_res  # bs, 768, 28 * 28


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MoNet(nn.Module):
    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(3):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_mal = MAL(feature_num=3)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )

        # self.score_projection = nn.Sequential(
        #     nn.Linear(1, 64),
        #     nn.GELU(),
        #     nn.Linear(64, 128),
        # )

        # self.i_p_fusion = nn.Sequential(
        #     Block(128, 8),
        #     Block(128, 8),
        #     Block(128, 8),
        # )
        self.i_p_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )
        self.mlp = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 128),
        )

        self.score_block = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            # Block(128, 4),
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

        # self.diff_block = nn.Sequential(
        #     Block(128, 8),
        #     Block(128, 8),
        #     Block(128, 8),
        #     nn.Linear(128, 64),
        #     nn.GELU(),
        #     nn.Linear(64, 1),
        # )
        self.prompt_feature = None

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = None

    @torch.no_grad()
    def inference(self, x):
        prompt_feature = self.prompt_feature  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128

        fusion = self.score_block(torch.cat((img_feature, prompt_feature), dim=1))  # bs, n, 1

        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        return iq_res

    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    @torch.no_grad()
    def forward_prompt(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-dimensional feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)

        # Fuse img_feature and score_feature into funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128

        print('Load Prompt For Testing.', funsion_feature.shape)
        self.prompt_feature = funsion_feature.clone()

    def expand(self, A):
        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)

        B = None
        for index, i in enumerate(A_expanded):
            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
            if B is None:
                B = rmv
            else:
                B = torch.cat((B, rmv), dim=0)

        return B

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-dimensional feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)  # bs, 128

        # Fuse img_feature and score_feature into funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128

        fusion = self.score_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')


if __name__ == '__main__':
    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda()
    model = MoNet().cuda()

    iq_res, gt_res = model(in_feature, gt_feature)

    print(iq_res.shape)
    print(gt_res.shape)
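
For context, this deleted test-time variant is driven in two steps: forward_prompt caches a prompt built from a few (image, score) pairs, and inference then scores new images against that cached prompt. A hypothetical usage sketch, assuming a CUDA device and the pretrained ViT weights are available (the tensors and shapes below are illustrative, not from the repository):

import torch

model = MoNet().cuda().eval()                                 # class defined in the file above
prompt_imgs = torch.rand(5, 3, 224, 224).cuda()               # images whose quality scores are known
prompt_scores = torch.rand(5, 1).cuda()                       # scores normalised to [0, 1]
model.forward_prompt(prompt_imgs, prompt_scores)              # caches self.prompt_feature (1, n, 128)
preds = model.inference(torch.rand(2, 3, 224, 224).cuda())    # scores new images against the cached prompt
print(preds.shape)                                            # torch.Size([2])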
PromptIQA/models/monet_wo_prompt.py
DELETED
@@ -1,392 +0,0 @@
"""
The completion for Mean-opinion Network(MoNet)
"""
import torch
import torch.nn as nn
import timm

from timm.models.vision_transformer import Block
from einops import rearrange
from itertools import combinations

from tqdm import tqdm
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'


class Attention_Block(nn.Module):
    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class Self_Attention(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
        out = out.view(bs, C, w, h)

        out = self.gamma * out + inFeature

        return out


class MAL(nn.Module):
    """
    Multi-view Attention Learning (MAL) module
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()

        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention

        # Self attention module for each input feature
        self.attention_module = nn.ModuleList()
        for _ in range(feature_num):
            self.attention_module.append(Self_Attention(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim

    def forward(self, features):
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
        features = feature

        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
        bs, _, _ = input_tensor.shape  # [2, 3072, 784]

        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
                               c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768

        in_channel = input_tensor.permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768 * feature_num
        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28

        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2  # [2, 3072, 784]

        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)

        return weight_sum_res  # bs, 768, 28 * 28


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


from functools import partial


class MoNet(nn.Module):
    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(3):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_mal = MAL(feature_num=3)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )

        # self.i_p_fusion = nn.Sequential(
        #     Block(128, 4),
        #     Block(128, 4),
        #     Block(128, 4),
        # )
        # self.mlp = nn.Sequential(
        #     nn.Linear(128, 64),
        #     nn.GELU(),
        #     nn.Linear(64, 128),
        # )

        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
                attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
            for i in range(8)])
        self.norm = nn.LayerNorm(128)

        self.score_block = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

        self.prompt_feature = {}

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = {}

    @torch.no_grad()
    def inference(self, x, data_type):
        # prompt_feature = self.prompt_feature[data_type]  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
        # prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128

        fusion = self.blocks(img_feature)  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)

        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        return iq_res

    @torch.no_grad()
    def check_prompt(self, data_type):
        return data_type in self.prompt_feature

    @torch.no_grad()
    def forward_prompt(self, x, score, data_type):
        pass
        # if data_type in self.prompt_feature:
        #     return
        # _x = self.vit(x)
        # x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        # self.save_output.outputs.clear()

        # x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        # x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        # x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # # Different Opinion Features (DOF)
        # DOF = torch.tensor([]).cuda()
        # for index, _ in enumerate(self.MALs):
        #     DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        # DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # # Image Quality Score Regression
        # fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        # IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        # IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        # img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # # Linearly map the score to a 128-dimensional feature
        # # score_feature = self.score_projection(score)  # bs, 128
        # score_feature = score.expand(-1, 128)

        # # Fuse img_feature and score_feature into funsion_feature
        # funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        # funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128

        # # print('Load Prompt For Testing.', funsion_feature.shape)
        # # self.prompt_feature = funsion_feature.clone()
        # self.prompt_feature[data_type] = funsion_feature.clone()

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-dimensional feature
        # score_feature = self.score_projection(score)  # bs, 128
        # score_feature = score.expand(-1, 128)  # bs, 128

        # # Fuse img_feature and score_feature into funsion_feature
        # funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        # funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        # funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
        # funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128

        fusion = self.blocks(img_feature)  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')

    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    def expand(self, A):
        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)

        B = None
        for index, i in enumerate(A_expanded):
            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
            if B is None:
                B = rmv
            else:
                B = torch.cat((B, rmv), dim=0)

        return B


if __name__ == '__main__':
    in_feature = torch.zeros((2, 3, 224, 224)).cuda()
    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2]], dtype=torch.float).cuda()
    model = MoNet().cuda()

    iq_res, gt_res = model(in_feature, gt_feature)

    print(iq_res)
    print(gt_res.shape)
|
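The deleted forward() above builds its prompt features with expand(): for a batch of bs fused image-score vectors, each image is paired with the vectors of all the other images in the batch (a leave-one-out expansion). A minimal standalone sketch of that logic, independent of the model classes:

import torch

def leave_one_out(A: torch.Tensor) -> torch.Tensor:
    # For A of shape (bs, dim), return a (bs, bs - 1, dim) tensor whose i-th row
    # holds every row of A except row i (same logic as the expand() method above).
    expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)  # bs, bs, dim
    rows = [torch.cat((r[:i], r[i + 1:])).unsqueeze(0) for i, r in enumerate(expanded)]
    return torch.cat(rows, dim=0)

A = torch.arange(6, dtype=torch.float).reshape(3, 2)      # 3 samples, 2-d features
print(leave_one_out(A).shape)                             # torch.Size([3, 2, 2])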
PromptIQA/models/{monet.py → promptiqa.py}
RENAMED
@@ -6,10 +6,8 @@ import torch.nn as nn
 import timm
 
 from timm.models.vision_transformer import Block
+from functools import partial
 from einops import rearrange
-from itertools import combinations
-
-from tqdm import tqdm
 
 class Attention_Block(nn.Module):
     def __init__(self, dim, drop=0.1):
@@ -119,20 +117,6 @@ class SaveOutput:
     def clear(self):
         self.outputs = []
 
-# utils
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
@@ -160,8 +144,7 @@ class Attention(nn.Module):
         x = self.proj_drop(x)
         return x
 
-
-class MoNet(nn.Module):
+class PromptIQA(nn.Module):
     def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
         super().__init__()
         self.img_size = img_size
@@ -273,15 +256,10 @@ class MoNet(nn.Module):
         fusion = self.norm(fusion)
         fusion = self.score_block(fusion)
 
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
         iq_res = fusion[:, 0].view(-1)
 
         return iq_res
 
-    @torch.no_grad()
-    def check_prompt(self, data_type):
-        return data_type in self.prompt_feature
-
     @torch.no_grad()
     def forward_prompt(self, x, score, data_type):
         _x = self.vit(x)
@@ -304,63 +282,13 @@ class MoNet(nn.Module):
         IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
         img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
 
-        # Linearly map the score to a 128-d feature
-        # score_feature = self.score_projection(score)  # bs, 128
         score_feature = score.expand(-1, 128)
 
-        # Fuse img_feature and score_feature into funsion_feature
         funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
         funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128
 
-        # print('Load Prompt For Testing.', funsion_feature.shape)
-        # self.prompt_feature = funsion_feature.clone()
         self.prompt_feature[data_type] = funsion_feature.clone()
 
-    def forward(self, x, score):
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # Linearly map the score to a 128-d feature
-        # score_feature = self.score_projection(score)  # bs, 128
-        score_feature = score.expand(-1, 128)  # bs, 128
-
-        # Fuse img_feature and score_feature into funsion_feature
-        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
-        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
-        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
-        funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128
-
-        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1))  # bs, 2, 1
-        fusion = self.norm(fusion)
-        fusion = self.score_block(fusion)
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
-        iq_res = fusion[:, 0].view(-1)
-
-        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
-        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
-
-        gt_res = score.view(-1)
-        # diff_gt_res = 1 - score.view(-1)
-
-        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
-
     def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
         x1 = save_output.outputs[block_index[0]][:, 1:]
         x2 = save_output.outputs[block_index[1]][:, 1:]
@@ -381,13 +309,3 @@ class MoNet(nn.Module):
             B = torch.cat((B, rmv), dim=0)
 
         return B
-
-if __name__ == '__main__':
-    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
-    gt_feature = torch.tensor([[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=torch.float).cuda()
-    model = MoNet().cuda()
-
-    iq_res, gt_res = model(in_feature, gt_feature)
-
-    print(iq_res.shape)
-    print(gt_res.shape)
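After the rename, the class is used through its prompt API: forward_prompt() encodes a set of image-score pairs into a cached prompt for a dataset tag, and inference() scores new images against that prompt. A rough usage sketch, not part of the diff (the random tensors and the 'livec' tag are placeholders; in practice the weights come from a checkpoint as in run_promptIQA.py below):

import torch
from PromptIQA.models import promptiqa

model = promptiqa.PromptIQA().cuda().eval()

prompt_imgs = torch.rand(10, 3, 224, 224).cuda()           # 10 example images
prompt_scores = torch.rand(10, 1).cuda()                   # their scores, normalized to [0, 1]
model.forward_prompt(prompt_imgs, prompt_scores, 'livec')  # cache the prompt under the 'livec' tag

query = torch.rand(1, 3, 224, 224).cuda()
with torch.no_grad():
    pred = model.inference(query, 'livec')                 # predicted quality score for the query image
print(pred)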
PromptIQA/models/vit_base.py
DELETED
@@ -1,402 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
The completion for Mean-opinion Network(MoNet)
|
3 |
-
"""
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
import timm
|
7 |
-
|
8 |
-
from timm.models.vision_transformer import Block
|
9 |
-
from einops import rearrange
|
10 |
-
from itertools import combinations
|
11 |
-
|
12 |
-
from tqdm import tqdm
|
13 |
-
|
14 |
-
|
15 |
-
class Attention_Block(nn.Module):
|
16 |
-
def __init__(self, dim, drop=0.1):
|
17 |
-
super().__init__()
|
18 |
-
self.c_q = nn.Linear(dim, dim)
|
19 |
-
self.c_k = nn.Linear(dim, dim)
|
20 |
-
self.c_v = nn.Linear(dim, dim)
|
21 |
-
self.norm_fact = dim ** -0.5
|
22 |
-
self.softmax = nn.Softmax(dim=-1)
|
23 |
-
self.proj_drop = nn.Dropout(drop)
|
24 |
-
|
25 |
-
def forward(self, x):
|
26 |
-
_x = x
|
27 |
-
B, C, N = x.shape
|
28 |
-
q = self.c_q(x)
|
29 |
-
k = self.c_k(x)
|
30 |
-
v = self.c_v(x)
|
31 |
-
|
32 |
-
attn = q @ k.transpose(-2, -1) * self.norm_fact
|
33 |
-
attn = self.softmax(attn)
|
34 |
-
x = (attn @ v).transpose(1, 2).reshape(B, C, N)
|
35 |
-
x = self.proj_drop(x)
|
36 |
-
x = x + _x
|
37 |
-
return x
|
38 |
-
|
39 |
-
|
40 |
-
class Self_Attention(nn.Module):
|
41 |
-
""" Self attention Layer"""
|
42 |
-
|
43 |
-
def __init__(self, in_dim):
|
44 |
-
super(Self_Attention, self).__init__()
|
45 |
-
|
46 |
-
self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
|
47 |
-
self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
|
48 |
-
self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
|
49 |
-
self.gamma = nn.Parameter(torch.zeros(1))
|
50 |
-
|
51 |
-
self.softmax = nn.Softmax(dim=-1)
|
52 |
-
|
53 |
-
def forward(self, inFeature):
|
54 |
-
bs, C, w, h = inFeature.size()
|
55 |
-
|
56 |
-
proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
|
57 |
-
proj_key = self.kConv(inFeature).view(bs, -1, w * h)
|
58 |
-
energy = torch.bmm(proj_query, proj_key)
|
59 |
-
attention = self.softmax(energy)
|
60 |
-
proj_value = self.vConv(inFeature).view(bs, -1, w * h)
|
61 |
-
|
62 |
-
out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
|
63 |
-
out = out.view(bs, C, w, h)
|
64 |
-
|
65 |
-
out = self.gamma * out + inFeature
|
66 |
-
|
67 |
-
return out
|
68 |
-
|
69 |
-
|
70 |
-
class three_cnn(nn.Module):
|
71 |
-
def __init__(self, in_dim) -> None:
|
72 |
-
super().__init__()
|
73 |
-
|
74 |
-
self.three_cnn = nn.Sequential(
|
75 |
-
nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1),
|
76 |
-
nn.ReLU(inplace=True),
|
77 |
-
nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1),
|
78 |
-
nn.ReLU(inplace=True),
|
79 |
-
nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1),
|
80 |
-
nn.ReLU(inplace=True),
|
81 |
-
)
|
82 |
-
|
83 |
-
def forward(self, input):
|
84 |
-
return self.three_cnn(input)
|
85 |
-
|
86 |
-
|
87 |
-
class MAL(nn.Module):
|
88 |
-
def __init__(self, in_dim=768, feature_num=4, feature_size=28):
|
89 |
-
super().__init__()
|
90 |
-
self.attention_module = nn.ModuleList()
|
91 |
-
for i in range(feature_num):
|
92 |
-
self.attention_module.append(three_cnn(in_dim))
|
93 |
-
|
94 |
-
self.feature_num = feature_num
|
95 |
-
self.in_dim = in_dim
|
96 |
-
self.feature_size = feature_size
|
97 |
-
|
98 |
-
def forward(self, features):
|
99 |
-
feature = torch.tensor([]).cuda()
|
100 |
-
for index, _ in enumerate(features):
|
101 |
-
feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1)
|
102 |
-
feature = torch.mean(feature, dim=1)
|
103 |
-
features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size)
|
104 |
-
|
105 |
-
return features # bs, 768, 28 * 28
|
106 |
-
|
107 |
-
|
108 |
-
class SaveOutput:
|
109 |
-
def __init__(self):
|
110 |
-
self.outputs = []
|
111 |
-
|
112 |
-
def __call__(self, module, module_in, module_out):
|
113 |
-
self.outputs.append(module_out)
|
114 |
-
|
115 |
-
def clear(self):
|
116 |
-
self.outputs = []
|
117 |
-
|
118 |
-
|
119 |
-
# utils
|
120 |
-
@torch.no_grad()
|
121 |
-
def concat_all_gather(tensor):
|
122 |
-
"""
|
123 |
-
Performs all_gather operation on the provided tensors.
|
124 |
-
*** Warning ***: torch.distributed.all_gather has no gradient.
|
125 |
-
"""
|
126 |
-
tensors_gather = [
|
127 |
-
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
128 |
-
]
|
129 |
-
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
130 |
-
|
131 |
-
output = torch.cat(tensors_gather, dim=0)
|
132 |
-
return output
|
133 |
-
|
134 |
-
|
135 |
-
class Attention(nn.Module):
|
136 |
-
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
137 |
-
super().__init__()
|
138 |
-
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
139 |
-
self.num_heads = num_heads
|
140 |
-
head_dim = dim // num_heads
|
141 |
-
self.scale = head_dim ** -0.5
|
142 |
-
|
143 |
-
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
144 |
-
self.attn_drop = nn.Dropout(attn_drop)
|
145 |
-
self.proj = nn.Linear(dim, dim)
|
146 |
-
self.proj_drop = nn.Dropout(proj_drop)
|
147 |
-
|
148 |
-
def forward(self, x):
|
149 |
-
B, N, C = x.shape
|
150 |
-
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
151 |
-
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
152 |
-
|
153 |
-
attn = (q @ k.transpose(-2, -1)) * self.scale
|
154 |
-
attn = attn.softmax(dim=-1)
|
155 |
-
attn = self.attn_drop(attn)
|
156 |
-
|
157 |
-
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
158 |
-
x = self.proj(x)
|
159 |
-
x = self.proj_drop(x)
|
160 |
-
return x
|
161 |
-
|
162 |
-
|
163 |
-
from functools import partial
|
164 |
-
|
165 |
-
|
166 |
-
class MoNet(nn.Module):
|
167 |
-
def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
|
168 |
-
super().__init__()
|
169 |
-
self.img_size = img_size
|
170 |
-
self.input_size = img_size // patch_size
|
171 |
-
self.dim_mlp = dim_mlp
|
172 |
-
|
173 |
-
self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
|
174 |
-
self.vit.norm = nn.Identity()
|
175 |
-
self.vit.head = nn.Identity()
|
176 |
-
|
177 |
-
self.save_output = SaveOutput()
|
178 |
-
|
179 |
-
# Register Hooks
|
180 |
-
hook_handles = []
|
181 |
-
for layer in self.vit.modules():
|
182 |
-
if isinstance(layer, Block):
|
183 |
-
handle = layer.register_forward_hook(self.save_output)
|
184 |
-
hook_handles.append(handle)
|
185 |
-
|
186 |
-
self.MALs = nn.ModuleList()
|
187 |
-
for _ in range(1):
|
188 |
-
self.MALs.append(MAL())
|
189 |
-
|
190 |
-
# Image Quality Score Regression
|
191 |
-
self.fusion_mal = MAL(feature_num=1)
|
192 |
-
self.block = Block(dim_mlp, 12)
|
193 |
-
self.cnn = nn.Sequential(
|
194 |
-
nn.Conv2d(dim_mlp, 256, 5),
|
195 |
-
nn.BatchNorm2d(256),
|
196 |
-
nn.ReLU(inplace=True),
|
197 |
-
nn.AvgPool2d((2, 2)),
|
198 |
-
nn.Conv2d(256, 128, 3),
|
199 |
-
nn.BatchNorm2d(128),
|
200 |
-
nn.ReLU(inplace=True),
|
201 |
-
nn.AvgPool2d((2, 2)),
|
202 |
-
nn.Conv2d(128, 128, 3),
|
203 |
-
nn.BatchNorm2d(128),
|
204 |
-
nn.ReLU(inplace=True),
|
205 |
-
nn.AvgPool2d((3, 3)),
|
206 |
-
)
|
207 |
-
|
208 |
-
self.i_p_fusion = nn.Sequential(
|
209 |
-
Block(128, 4),
|
210 |
-
Block(128, 4),
|
211 |
-
Block(128, 4),
|
212 |
-
)
|
213 |
-
self.mlp = nn.Sequential(
|
214 |
-
nn.Linear(128, 64),
|
215 |
-
nn.GELU(),
|
216 |
-
nn.Linear(64, 128),
|
217 |
-
)
|
218 |
-
|
219 |
-
self.prompt_fusion = nn.Sequential(
|
220 |
-
Block(128, 4),
|
221 |
-
Block(128, 4),
|
222 |
-
Block(128, 4),
|
223 |
-
)
|
224 |
-
|
225 |
-
dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule
|
226 |
-
self.blocks = nn.Sequential(*[
|
227 |
-
Block(
|
228 |
-
dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
|
229 |
-
attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
|
230 |
-
for i in range(8)])
|
231 |
-
self.norm = nn.LayerNorm(128)
|
232 |
-
|
233 |
-
self.score_block = nn.Sequential(
|
234 |
-
nn.Linear(128, 128 // 2),
|
235 |
-
nn.ReLU(),
|
236 |
-
nn.Dropout(drop),
|
237 |
-
nn.Linear(128 // 2, 1),
|
238 |
-
nn.Sigmoid()
|
239 |
-
)
|
240 |
-
|
241 |
-
self.prompt_feature = {}
|
242 |
-
|
243 |
-
@torch.no_grad()
|
244 |
-
def clear(self):
|
245 |
-
self.prompt_feature = {}
|
246 |
-
|
247 |
-
@torch.no_grad()
|
248 |
-
def inference(self, x, data_type):
|
249 |
-
prompt_feature = self.prompt_feature[data_type] # 1, n, 128
|
250 |
-
|
251 |
-
_x = self.vit(x)
|
252 |
-
x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
|
253 |
-
self.save_output.outputs.clear()
|
254 |
-
|
255 |
-
x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
|
256 |
-
x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
|
257 |
-
x = x.permute(1, 0, 2, 3, 4).contiguous() # bs, 4, 768, 28 * 28
|
258 |
-
|
259 |
-
# Different Opinion Features (DOF)
|
260 |
-
DOF = torch.tensor([]).cuda()
|
261 |
-
for index, _ in enumerate(self.MALs):
|
262 |
-
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
263 |
-
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
|
264 |
-
|
265 |
-
# Image Quality Score Regression
|
266 |
-
fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768
|
267 |
-
IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
|
268 |
-
IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
|
269 |
-
img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
|
270 |
-
|
271 |
-
prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128
|
272 |
-
prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128
|
273 |
-
|
274 |
-
fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1)) # bs, 2, 1
|
275 |
-
fusion = self.norm(fusion)
|
276 |
-
fusion = self.score_block(fusion)
|
277 |
-
|
278 |
-
# iq_res = torch.mean(fusion, dim=1).view(-1)
|
279 |
-
iq_res = fusion[:, 0].view(-1)
|
280 |
-
|
281 |
-
return iq_res
|
282 |
-
|
283 |
-
@torch.no_grad()
|
284 |
-
def check_prompt(self, data_type):
|
285 |
-
return data_type in self.prompt_feature
|
286 |
-
|
287 |
-
@torch.no_grad()
|
288 |
-
def forward_prompt(self, x, score, data_type):
|
289 |
-
if data_type in self.prompt_feature:
|
290 |
-
return
|
291 |
-
_x = self.vit(x)
|
292 |
-
x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
|
293 |
-
self.save_output.outputs.clear()
|
294 |
-
|
295 |
-
x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
|
296 |
-
x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
|
297 |
-
x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
|
298 |
-
|
299 |
-
# Different Opinion Features (DOF)
|
300 |
-
DOF = torch.tensor([]).cuda()
|
301 |
-
for index, _ in enumerate(self.MALs):
|
302 |
-
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
303 |
-
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
|
304 |
-
|
305 |
-
# Image Quality Score Regression
|
306 |
-
fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768
|
307 |
-
IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
|
308 |
-
IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
|
309 |
-
img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
|
310 |
-
|
311 |
-
# 分数线性变换为128维
|
312 |
-
# score_feature = self.score_projection(score) # bs, 128
|
313 |
-
score_feature = score.expand(-1, 128)
|
314 |
-
|
315 |
-
# img_feature 和 score_feature融合得到 funsion_feature
|
316 |
-
funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
|
317 |
-
funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128
|
318 |
-
|
319 |
-
# print('Load Prompt For Testing.', funsion_feature.shape)
|
320 |
-
# self.prompt_feature = funsion_feature.clone()
|
321 |
-
self.prompt_feature[data_type] = funsion_feature.clone()
|
322 |
-
|
323 |
-
def forward(self, x, score):
|
324 |
-
_x = self.vit(x)
|
325 |
-
x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
|
326 |
-
self.save_output.outputs.clear()
|
327 |
-
|
328 |
-
x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
|
329 |
-
x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
|
330 |
-
x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
|
331 |
-
|
332 |
-
# Different Opinion Features (DOF)
|
333 |
-
DOF = torch.tensor([]).cuda()
|
334 |
-
for index, _ in enumerate(self.MALs):
|
335 |
-
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
336 |
-
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
|
337 |
-
|
338 |
-
# Image Quality Score Regression
|
339 |
-
fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768
|
340 |
-
IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
|
341 |
-
IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
|
342 |
-
img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
|
343 |
-
|
344 |
-
# 分数线性变换为128维
|
345 |
-
# score_feature = self.score_projection(score) # bs, 128
|
346 |
-
score_feature = score.expand(-1, 128) # bs, 128
|
347 |
-
|
348 |
-
# img_feature 和 score_feature融合得到 funsion_feature
|
349 |
-
# funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1)) # bs, 2, 128
|
350 |
-
funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
|
351 |
-
funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) #bs, 128
|
352 |
-
funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128
|
353 |
-
funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128
|
354 |
-
|
355 |
-
fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, 2, 1
|
356 |
-
fusion = self.norm(fusion)
|
357 |
-
fusion = self.score_block(fusion)
|
358 |
-
# iq_res = torch.mean(fusion, dim=1).view(-1)
|
359 |
-
iq_res = fusion[:, 0].view(-1)
|
360 |
-
|
361 |
-
# differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1
|
362 |
-
# differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
|
363 |
-
|
364 |
-
gt_res = score.view(-1)
|
365 |
-
# diff_gt_res = 1 - score.view(-1)
|
366 |
-
|
367 |
-
return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
|
368 |
-
|
369 |
-
def extract_feature(self, save_output, block_index=None):
|
370 |
-
block_index = [2, 5, 8, 11]
|
371 |
-
x1 = save_output.outputs[block_index[0]][:, 1:]
|
372 |
-
x2 = save_output.outputs[block_index[1]][:, 1:]
|
373 |
-
x3 = save_output.outputs[block_index[2]][:, 1:]
|
374 |
-
x4 = save_output.outputs[block_index[3]][:, 1:]
|
375 |
-
x = torch.cat((x1, x2, x3, x4), dim=2)
|
376 |
-
return x
|
377 |
-
|
378 |
-
def expand(self, A):
|
379 |
-
A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
|
380 |
-
|
381 |
-
B = None
|
382 |
-
for index, i in enumerate(A_expanded):
|
383 |
-
rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
|
384 |
-
if B is None:
|
385 |
-
B = rmv
|
386 |
-
else:
|
387 |
-
B = torch.cat((B, rmv), dim=0)
|
388 |
-
|
389 |
-
return B
|
390 |
-
|
391 |
-
|
392 |
-
if __name__ == '__main__':
|
393 |
-
in_feature = torch.zeros((11, 3, 384, 384)).cuda()
|
394 |
-
gt_feature = torch.tensor(
|
395 |
-
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda()
|
396 |
-
gt_feature = gt_feature.reshape(-1, 1)
|
397 |
-
model = MoNet().cuda()
|
398 |
-
|
399 |
-
(iq_res, _), (_, _) = model(in_feature, gt_feature)
|
400 |
-
|
401 |
-
print(iq_res.shape)
|
402 |
-
# print(gt_res.shape)
|
|
|
|
|
|
|
|
|
|
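The deleted vit_base.py variant above keeps the same feature-extraction scheme as promptiqa.py: forward hooks record the output of every transformer block, and extract_feature() concatenates the CLS-stripped tokens of blocks 2, 5, 8 and 11 along the channel axis. A small standalone sketch of that step (the fake hook outputs below are placeholders for the hooked ViT activations):

import torch

def concat_block_tokens(block_outputs, block_index=(2, 5, 8, 11)):
    # Mimics extract_feature(): take the hooked output of four ViT blocks,
    # drop the CLS token, and concatenate along the channel dimension.
    feats = [block_outputs[i][:, 1:] for i in block_index]  # each: bs, 28 * 28, 768
    return torch.cat(feats, dim=2)                          # bs, 28 * 28, 768 * 4

# toy check with fake outputs for a 12-block ViT (batch of 2, 784 patch tokens + CLS)
outputs = [torch.randn(2, 785, 768) for _ in range(12)]
print(concat_block_tokens(outputs).shape)                   # torch.Size([2, 784, 3072])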
PromptIQA/models/vit_large.py
DELETED
@@ -1,405 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
The completion for Mean-opinion Network(MoNet)
|
3 |
-
"""
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
import timm
|
7 |
-
|
8 |
-
from timm.models.vision_transformer import Block
|
9 |
-
from einops import rearrange
|
10 |
-
from itertools import combinations
|
11 |
-
|
12 |
-
from tqdm import tqdm
|
13 |
-
|
14 |
-
|
15 |
-
class Attention_Block(nn.Module):
|
16 |
-
def __init__(self, dim, drop=0.1):
|
17 |
-
super().__init__()
|
18 |
-
self.c_q = nn.Linear(dim, dim)
|
19 |
-
self.c_k = nn.Linear(dim, dim)
|
20 |
-
self.c_v = nn.Linear(dim, dim)
|
21 |
-
self.norm_fact = dim ** -0.5
|
22 |
-
self.softmax = nn.Softmax(dim=-1)
|
23 |
-
self.proj_drop = nn.Dropout(drop)
|
24 |
-
|
25 |
-
def forward(self, x):
|
26 |
-
_x = x
|
27 |
-
B, C, N = x.shape
|
28 |
-
q = self.c_q(x)
|
29 |
-
k = self.c_k(x)
|
30 |
-
v = self.c_v(x)
|
31 |
-
|
32 |
-
attn = q @ k.transpose(-2, -1) * self.norm_fact
|
33 |
-
attn = self.softmax(attn)
|
34 |
-
x = (attn @ v).transpose(1, 2).reshape(B, C, N)
|
35 |
-
x = self.proj_drop(x)
|
36 |
-
x = x + _x
|
37 |
-
return x
|
38 |
-
|
39 |
-
|
40 |
-
class Self_Attention(nn.Module):
|
41 |
-
""" Self attention Layer"""
|
42 |
-
|
43 |
-
def __init__(self, in_dim):
|
44 |
-
super(Self_Attention, self).__init__()
|
45 |
-
|
46 |
-
self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
|
47 |
-
self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
|
48 |
-
self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
|
49 |
-
self.gamma = nn.Parameter(torch.zeros(1))
|
50 |
-
|
51 |
-
self.softmax = nn.Softmax(dim=-1)
|
52 |
-
|
53 |
-
def forward(self, inFeature):
|
54 |
-
bs, C, w, h = inFeature.size()
|
55 |
-
|
56 |
-
proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
|
57 |
-
proj_key = self.kConv(inFeature).view(bs, -1, w * h)
|
58 |
-
energy = torch.bmm(proj_query, proj_key)
|
59 |
-
attention = self.softmax(energy)
|
60 |
-
proj_value = self.vConv(inFeature).view(bs, -1, w * h)
|
61 |
-
|
62 |
-
out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
|
63 |
-
out = out.view(bs, C, w, h)
|
64 |
-
|
65 |
-
out = self.gamma * out + inFeature
|
66 |
-
|
67 |
-
return out
|
68 |
-
|
69 |
-
|
70 |
-
class three_cnn(nn.Module):
|
71 |
-
def __init__(self, in_dim) -> None:
|
72 |
-
super().__init__()
|
73 |
-
|
74 |
-
self.three_cnn = nn.Sequential(
|
75 |
-
nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1),
|
76 |
-
nn.ReLU(inplace=True),
|
77 |
-
nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1),
|
78 |
-
nn.ReLU(inplace=True),
|
79 |
-
nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1),
|
80 |
-
nn.ReLU(inplace=True),
|
81 |
-
)
|
82 |
-
|
83 |
-
def forward(self, input):
|
84 |
-
return self.three_cnn(input)
|
85 |
-
|
86 |
-
|
87 |
-
class MAL(nn.Module):
|
88 |
-
def __init__(self, in_dim=768, feature_num=4, feature_size=28):
|
89 |
-
super().__init__()
|
90 |
-
self.attention_module = nn.ModuleList()
|
91 |
-
for i in range(feature_num):
|
92 |
-
self.attention_module.append(three_cnn(in_dim))
|
93 |
-
|
94 |
-
self.feature_num = feature_num
|
95 |
-
self.in_dim = in_dim
|
96 |
-
self.feature_size = feature_size
|
97 |
-
|
98 |
-
def forward(self, features):
|
99 |
-
feature = torch.tensor([]).cuda()
|
100 |
-
for index, _ in enumerate(features):
|
101 |
-
feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1)
|
102 |
-
feature = torch.mean(feature, dim=1)
|
103 |
-
features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size)
|
104 |
-
|
105 |
-
return features # bs, 768, 28 * 28
|
106 |
-
|
107 |
-
|
108 |
-
class SaveOutput:
|
109 |
-
def __init__(self):
|
110 |
-
self.outputs = []
|
111 |
-
|
112 |
-
def __call__(self, module, module_in, module_out):
|
113 |
-
self.outputs.append(module_out)
|
114 |
-
|
115 |
-
def clear(self):
|
116 |
-
self.outputs = []
|
117 |
-
|
118 |
-
|
119 |
-
# utils
|
120 |
-
@torch.no_grad()
|
121 |
-
def concat_all_gather(tensor):
|
122 |
-
"""
|
123 |
-
Performs all_gather operation on the provided tensors.
|
124 |
-
*** Warning ***: torch.distributed.all_gather has no gradient.
|
125 |
-
"""
|
126 |
-
tensors_gather = [
|
127 |
-
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
128 |
-
]
|
129 |
-
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
130 |
-
|
131 |
-
output = torch.cat(tensors_gather, dim=0)
|
132 |
-
return output
|
133 |
-
|
134 |
-
|
135 |
-
class Attention(nn.Module):
|
136 |
-
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
137 |
-
super().__init__()
|
138 |
-
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
139 |
-
self.num_heads = num_heads
|
140 |
-
head_dim = dim // num_heads
|
141 |
-
self.scale = head_dim ** -0.5
|
142 |
-
|
143 |
-
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
144 |
-
self.attn_drop = nn.Dropout(attn_drop)
|
145 |
-
self.proj = nn.Linear(dim, dim)
|
146 |
-
self.proj_drop = nn.Dropout(proj_drop)
|
147 |
-
|
148 |
-
def forward(self, x):
|
149 |
-
B, N, C = x.shape
|
150 |
-
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
151 |
-
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
152 |
-
|
153 |
-
attn = (q @ k.transpose(-2, -1)) * self.scale
|
154 |
-
attn = attn.softmax(dim=-1)
|
155 |
-
attn = self.attn_drop(attn)
|
156 |
-
|
157 |
-
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
158 |
-
x = self.proj(x)
|
159 |
-
x = self.proj_drop(x)
|
160 |
-
return x
|
161 |
-
|
162 |
-
|
163 |
-
from functools import partial
|
164 |
-
|
165 |
-
|
166 |
-
class MoNet(nn.Module):
|
167 |
-
def __init__(self, patch_size=32, drop=0.1, dim_mlp=1024, img_size=384):
|
168 |
-
super().__init__()
|
169 |
-
self.img_size = img_size
|
170 |
-
self.input_size = img_size // patch_size
|
171 |
-
self.dim_mlp = dim_mlp
|
172 |
-
|
173 |
-
self.vit = timm.create_model('vit_large_patch32_384', pretrained=True)
|
174 |
-
self.vit.norm = nn.Identity()
|
175 |
-
self.vit.head = nn.Identity()
|
176 |
-
self.vit.head_drop = nn.Identity()
|
177 |
-
|
178 |
-
self.save_output = SaveOutput()
|
179 |
-
|
180 |
-
# Register Hooks
|
181 |
-
hook_handles = []
|
182 |
-
for layer in self.vit.modules():
|
183 |
-
if isinstance(layer, Block):
|
184 |
-
handle = layer.register_forward_hook(self.save_output)
|
185 |
-
hook_handles.append(handle)
|
186 |
-
|
187 |
-
self.MALs = nn.ModuleList()
|
188 |
-
for _ in range(3):
|
189 |
-
self.MALs.append(MAL(in_dim=dim_mlp, feature_size=self.input_size))
|
190 |
-
|
191 |
-
# Image Quality Score Regression
|
192 |
-
self.fusion_mal = MAL(in_dim=dim_mlp, feature_num=3, feature_size=self.input_size)
|
193 |
-
self.block = Block(dim_mlp, 16)
|
194 |
-
self.cnn = nn.Sequential(
|
195 |
-
nn.Conv2d(dim_mlp, 512, 5),
|
196 |
-
nn.BatchNorm2d(512),
|
197 |
-
nn.ReLU(inplace=True),
|
198 |
-
nn.AvgPool2d((2, 2)), # 4
|
199 |
-
nn.Conv2d(512, 256, 3, 1), # 2
|
200 |
-
nn.BatchNorm2d(256),
|
201 |
-
nn.ReLU(inplace=True),
|
202 |
-
nn.Conv2d(256, 256, 1),
|
203 |
-
nn.BatchNorm2d(256),
|
204 |
-
nn.ReLU(inplace=True),
|
205 |
-
nn.AvgPool2d((2, 2)),
|
206 |
-
)
|
207 |
-
|
208 |
-
self.i_p_fusion = nn.Sequential(
|
209 |
-
Block(256, 8),
|
210 |
-
Block(256, 8),
|
211 |
-
Block(256, 8),
|
212 |
-
)
|
213 |
-
self.mlp = nn.Sequential(
|
214 |
-
nn.Linear(256, 128),
|
215 |
-
nn.GELU(),
|
216 |
-
nn.Linear(128, 256),
|
217 |
-
)
|
218 |
-
|
219 |
-
self.prompt_fusion = nn.Sequential(
|
220 |
-
Block(256, 8),
|
221 |
-
Block(256, 8),
|
222 |
-
Block(256, 8),
|
223 |
-
)
|
224 |
-
|
225 |
-
dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule
|
226 |
-
self.blocks = nn.Sequential(*[
|
227 |
-
Block(dim=256, num_heads=8, mlp_ratio=4, qkv_bias=True, attn_drop=0, drop_path=dpr[i])
|
228 |
-
for i in range(8)])
|
229 |
-
self.norm = nn.LayerNorm(256)
|
230 |
-
|
231 |
-
self.score_block = nn.Sequential(
|
232 |
-
nn.Linear(256, 256 // 2),
|
233 |
-
nn.ReLU(),
|
234 |
-
nn.Dropout(drop),
|
235 |
-
nn.Linear(256 // 2, 1),
|
236 |
-
nn.Sigmoid()
|
237 |
-
)
|
238 |
-
self.prompt_feature = {}
|
239 |
-
|
240 |
-
@torch.no_grad()
|
241 |
-
def clear(self):
|
242 |
-
self.prompt_feature = {}
|
243 |
-
|
244 |
-
@torch.no_grad()
|
245 |
-
def inference(self, x, data_type):
|
246 |
-
prompt_feature = self.prompt_feature[data_type] # 1, n, 128
|
247 |
-
|
248 |
-
_x = self.vit(x)
|
249 |
-
x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
|
250 |
-
self.save_output.outputs.clear()
|
251 |
-
|
252 |
-
x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
|
253 |
-
x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
|
254 |
-
h=self.input_size) # bs, 4, 768, 28, 28
|
255 |
-
x = x.permute(1, 0, 2, 3, 4).contiguous() # bs, 4, 768, 28 * 28
|
256 |
-
|
257 |
-
# Different Opinion Features (DOF)
|
258 |
-
DOF = torch.tensor([]).cuda()
|
259 |
-
for index, _ in enumerate(self.MALs):
|
260 |
-
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
261 |
-
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
|
262 |
-
|
263 |
-
# Image Quality Score Regression
|
264 |
-
fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768
|
265 |
-
IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
|
266 |
-
IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
|
267 |
-
h=self.input_size) # bs, 768, 28, 28
|
268 |
-
img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
|
269 |
-
|
270 |
-
prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128
|
271 |
-
prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128
|
272 |
-
|
273 |
-
fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1)) # bs, 2, 1
|
274 |
-
fusion = self.norm(fusion)
|
275 |
-
fusion = self.score_block(fusion)
|
276 |
-
|
277 |
-
# iq_res = torch.mean(fusion, dim=1).view(-1)
|
278 |
-
iq_res = fusion[:, 0].view(-1)
|
279 |
-
|
280 |
-
return iq_res
|
281 |
-
|
282 |
-
@torch.no_grad()
|
283 |
-
def check_prompt(self, data_type):
|
284 |
-
return data_type in self.prompt_feature
|
285 |
-
|
286 |
-
@torch.no_grad()
|
287 |
-
def forward_prompt(self, x, score, data_type):
|
288 |
-
if data_type in self.prompt_feature:
|
289 |
-
return
|
290 |
-
_x = self.vit(x)
|
291 |
-
x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
|
292 |
-
self.save_output.outputs.clear()
|
293 |
-
|
294 |
-
x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
|
295 |
-
x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
|
296 |
-
h=self.input_size) # bs, 4, 768, 28, 28
|
297 |
-
x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
|
298 |
-
|
299 |
-
# Different Opinion Features (DOF)
|
300 |
-
DOF = torch.tensor([]).cuda()
|
301 |
-
for index, _ in enumerate(self.MALs):
|
302 |
-
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
303 |
-
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
|
304 |
-
|
305 |
-
# Image Quality Score Regression
|
306 |
-
fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768
|
307 |
-
IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
|
308 |
-
IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
|
309 |
-
h=self.input_size) # bs, 768, 28, 28
|
310 |
-
img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
|
311 |
-
|
312 |
-
# 分数线性变换为128维
|
313 |
-
# score_feature = self.score_projection(score) # bs, 128
|
314 |
-
score_feature = score.expand(-1, 256)
|
315 |
-
|
316 |
-
# img_feature 和 score_feature融合得到 funsion_feature
|
317 |
-
funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
|
318 |
-
funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128
|
319 |
-
|
320 |
-
# print('Load Prompt For Testing.', funsion_feature.shape)
|
321 |
-
# self.prompt_feature = funsion_feature.clone()
|
322 |
-
self.prompt_feature[data_type] = funsion_feature.clone()
|
323 |
-
|
324 |
-
def forward(self, x, score):
|
325 |
-
_x = self.vit(x)
|
326 |
-
x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
|
327 |
-
self.save_output.outputs.clear()
|
328 |
-
|
329 |
-
x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
|
330 |
-
x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
|
331 |
-
h=self.input_size) # bs, 4, 768, 28, 28
|
332 |
-
x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
|
333 |
-
|
334 |
-
# Different Opinion Features (DOF)
|
335 |
-
DOF = torch.tensor([]).cuda()
|
336 |
-
for index, _ in enumerate(self.MALs):
|
337 |
-
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
338 |
-
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
|
339 |
-
# Image Quality Score Regression
|
340 |
-
fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768
|
341 |
-
IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
|
342 |
-
IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
|
343 |
-
h=self.input_size) # bs, 768, 28, 28
|
344 |
-
img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
|
345 |
-
|
346 |
-
# 分数线性变换为128维
|
347 |
-
# score_feature = self.score_projection(score) # bs, 128
|
348 |
-
score_feature = score.expand(-1, 256) # bs, 128
|
349 |
-
|
350 |
-
# img_feature 和 score_feature融合得到 funsion_feature funsion_feature = self.i_p_fusion(torch.cat((
|
351 |
-
# img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1)) # bs, 2, 128
|
352 |
-
funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
|
353 |
-
funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) # bs, 128
|
354 |
-
funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128
|
355 |
-
funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128
|
356 |
-
|
357 |
-
fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, 2, 1
|
358 |
-
fusion = self.norm(fusion)
|
359 |
-
fusion = self.score_block(fusion)
|
360 |
-
# iq_res = torch.mean(fusion, dim=1).view(-1)
|
361 |
-
iq_res = fusion[:, 0].view(-1)
|
362 |
-
|
363 |
-
# differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1
|
364 |
-
# differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
|
365 |
-
|
366 |
-
gt_res = score.view(-1)
|
367 |
-
# diff_gt_res = 1 - score.view(-1)
|
368 |
-
|
369 |
-
return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
|
370 |
-
|
371 |
-
def extract_feature(self, save_output, block_index=None):
|
372 |
-
if block_index is None:
|
373 |
-
block_index = [5, 11, 17, 23]
|
374 |
-
x1 = save_output.outputs[block_index[0]][:, 1:]
|
375 |
-
x2 = save_output.outputs[block_index[1]][:, 1:]
|
376 |
-
x3 = save_output.outputs[block_index[2]][:, 1:]
|
377 |
-
x4 = save_output.outputs[block_index[3]][:, 1:]
|
378 |
-
x = torch.cat((x1, x2, x3, x4), dim=2)
|
379 |
-
return x
|
380 |
-
|
381 |
-
def expand(self, A):
|
382 |
-
A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
|
383 |
-
|
384 |
-
B = None
|
385 |
-
for index, i in enumerate(A_expanded):
|
386 |
-
rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
|
387 |
-
if B is None:
|
388 |
-
B = rmv
|
389 |
-
else:
|
390 |
-
B = torch.cat((B, rmv), dim=0)
|
391 |
-
|
392 |
-
return B
|
393 |
-
|
394 |
-
|
395 |
-
if __name__ == '__main__':
|
396 |
-
in_feature = torch.zeros((11, 3, 384, 384)).cuda()
|
397 |
-
gt_feature = torch.tensor(
|
398 |
-
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda()
|
399 |
-
gt_feature = gt_feature.reshape(-1, 1)
|
400 |
-
model = MoNet().cuda()
|
401 |
-
|
402 |
-
(iq_res, _), (_, _) = model(in_feature, gt_feature)
|
403 |
-
|
404 |
-
print(iq_res.shape)
|
405 |
-
# print(gt_res.shape)
|
|
|
|
|
|
PromptIQA/run_promptIQA copy.py
DELETED
@@ -1,109 +0,0 @@
import os
import random
import torchvision
import cv2
import torch
from models import monet as MoNet
import numpy as np
from utils.dataset.process import ToTensor, Normalize
from utils.toolkit import *
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append(os.path.dirname(__file__))

class PromptIQA():
    def __init__(self) -> None:
        pass

def load_image(img_path, size=224):
    try:
        d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        d_img = cv2.resize(d_img, (size, size), interpolation=cv2.INTER_CUBIC)
        d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
        d_img = np.array(d_img).astype('float32') / 255
        d_img = np.transpose(d_img, (2, 0, 1))
    except:
        print(img_path)

    return d_img

def load_model(pkl_path):

    model = MoNet.MoNet()
    dict_pkl = {}
    # prompt_num = torch.load(pkl_path, map_location='cpu').get('prompt_num')
    for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items():
        dict_pkl[key[7:]] = value
    model.load_state_dict(dict_pkl)
    print('Load Model From ', pkl_path)

    return model

def get_an_img_score(img_path, target):
    transform = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
    values_to_insert = np.array([0.0, 1.0])
    position_to_insert = 0
    target = np.insert(target, position_to_insert, values_to_insert)

    sample = load_image(img_path)
    samples = {'img': sample, 'gt': target}
    samples = transform(samples)

    return samples

import random
if __name__ == '__main__':
    pkl_path = "./checkpoints/best_model_five_22.pth.tar"
    model = load_model(pkl_path).cuda()
    model.eval()

    img_path = '/mnt/storage/PromptIQA_Demo/CSIQ/dst_src'

    img_tensor, gt_tensor = None, None
    img_list = os.listdir(img_path)
    random.shuffle(img_list)
    for idx, img_name in enumerate(img_list):
        if idx == 10:
            break

        img_name = os.path.join(img_path, img_name)
        score = np.array(idx / 10)
        samples = get_an_img_score(img_name, score)

        if img_tensor is None:
            img_tensor = samples['img'].unsqueeze(0)
            gt_tensor = samples['gt'].type(torch.FloatTensor).unsqueeze(0)
        else:
            img_tensor = torch.cat((img_tensor, samples['img'].unsqueeze(0)), dim=0)
            gt_tensor = torch.cat((gt_tensor, samples['gt'].type(torch.FloatTensor).unsqueeze(0)), dim=0)

    print(img_tensor.shape)
    print(gt_tensor.shape)
    print(gt_tensor)

    img = img_tensor.squeeze(0).cuda()
    label = gt_tensor.squeeze(0).cuda()
    reverse = False
    if reverse == 2:
        label = torch.rand_like(label[:, -1]).cuda()
        print(label)
    elif reverse == 3:
        print('Total Random')
        label = torch.rand_like(label[:, -1]).cuda()
        img = torch.rand_like(img).cuda()
    else:
        label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda())
    print('input label: ', label)
    model.forward_prompt(img, label.reshape(-1, 1), 'livec')

    img_name = '/mnt/storage/PromptIQA_Demo/CSIQ/src_imgs/1600.png'
    score = np.array(random.random())
    samples = get_an_img_score(img_name, score)

    img = samples['img'].unsqueeze(0).cuda()
    print(img.shape)
    pred = model.inference(img, 'livec')

    print(pred)
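In the deleted demo above, get_an_img_score() prepends the anchor values [0.0, 1.0] to the scalar score with np.insert, so the ground-truth vector always ends with the real score; that is why the script later reads label[:, -1]. A tiny illustration of the same call (the 0.3 score is made up):

import numpy as np

score = np.array(0.3)
gt = np.insert(score, 0, np.array([0.0, 1.0]))  # prepend the anchors, keep the score last
print(gt)        # [0.  1.  0.3]
print(gt[-1])    # 0.3, the value the demo actually feeds to forward_prompt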
PromptIQA/run_promptIQA.py
CHANGED
@@ -3,7 +3,7 @@ import random
 import torchvision
 import cv2
 import torch
-from PromptIQA.models import
+from PromptIQA.models import promptiqa
 import numpy as np
 from PromptIQA.utils.dataset.process import ToTensor, Normalize
 from PromptIQA.utils.toolkit import *
@@ -14,7 +14,7 @@ import sys
 sys.path.append(os.path.dirname(__file__))
 
 def load_model(pkl_path):
-    model =
+    model = promptiqa.PromptIQA()
     dict_pkl = {}
     for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items():
         dict_pkl[key[7:]] = value
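The unchanged body of load_model() still strips the first seven characters of every checkpoint key: weights saved from a DistributedDataParallel wrapper carry a 'module.' prefix, and key[7:] removes it before load_state_dict(). An equivalent, more explicit sketch (the example keys are illustrative; str.removeprefix needs Python 3.9+):

import torch

def strip_ddp_prefix(state_dict: dict) -> dict:
    # Same effect as the key[7:] loop above: drop the 'module.' prefix that
    # DistributedDataParallel adds when a checkpoint is saved from the wrapped model.
    return {k.removeprefix('module.'): v for k, v in state_dict.items()}

ckpt = {'module.vit.cls_token': torch.zeros(1), 'module.norm.weight': torch.ones(128)}
print(list(strip_ddp_prefix(ckpt)))  # ['vit.cls_token', 'norm.weight']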
PromptIQA/t.py
DELETED
@@ -1,2 +0,0 @@
a = "(1+1)**(2**2)"
print(eval(a))
PromptIQA/test.py
DELETED
@@ -1,429 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
|
3 |
-
from utils import log_writer
|
4 |
-
|
5 |
-
import argparse
|
6 |
-
import builtins
|
7 |
-
import os
|
8 |
-
import random
|
9 |
-
import shutil
|
10 |
-
import time
|
11 |
-
|
12 |
-
import torch
|
13 |
-
import torch.distributed as dist
|
14 |
-
import torch.multiprocessing as mp
|
15 |
-
import torch.nn as nn
|
16 |
-
import torch.nn.parallel
|
17 |
-
import torch.optim
|
18 |
-
import torch.utils.data
|
19 |
-
import torch.utils.data.distributed
|
20 |
-
# from models import monet as MoNet
|
21 |
-
from torch.utils.data import ConcatDataset
|
22 |
-
from utils.dataset import data_loader
|
23 |
-
|
24 |
-
from utils.toolkit import *
|
25 |
-
|
26 |
-
loger_path = None
|
27 |
-
|
28 |
-
|
29 |
-
def init(config):
|
30 |
-
global loger_path
|
31 |
-
if config.dist_url == "env://" and config.world_size == -1:
|
32 |
-
config.world_size = int(os.environ["WORLD_SIZE"])
|
33 |
-
|
34 |
-
config.distributed = config.world_size > 1 or config.multiprocessing_distributed
|
35 |
-
|
36 |
-
print("config.distributed", config.distributed)
|
37 |
-
|
38 |
-
loger_path = os.path.join(config.save_path, "inference_log")
|
39 |
-
if not os.path.isdir(loger_path):
|
40 |
-
os.makedirs(loger_path)
|
41 |
-
|
42 |
-
print("----------------------------------")
|
43 |
-
print(
|
44 |
-
"Begin Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
|
45 |
-
)
|
46 |
-
printArgs(config, loger_path)
|
47 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '2,3,4,5,6,7'
|
48 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
49 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5'
|
50 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '6,7'
|
51 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
|
52 |
-
    # setup_seed(config.seed)


def main(config):
    init(config)
    ngpus_per_node = torch.cuda.device_count()
    if config.multiprocessing_distributed:
        config.world_size = ngpus_per_node * config.world_size

        print(config.world_size, ngpus_per_node, ngpus_per_node)
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config))
    else:
        # Simply call main_worker function
        main_worker(config.gpu, ngpus_per_node, config)

    print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))


@torch.no_grad()
def gather_together(data):  # gather `data` from every GPU and return the results as a list
    dist.barrier()
    world_size = dist.get_world_size()
    gather_data = [None for _ in range(world_size)]
    dist.all_gather_object(gather_data, data)
    return gather_data


import importlib.util
def main_worker(gpu, ngpus_per_node, args):
    models_path = os.path.join(args.save_path, "training_files", 'models', 'monet.py')
    spec = importlib.util.spec_from_file_location("monet_module", models_path)
    monet_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(monet_module)
    MoNet = monet_module

    logger_path = os.path.join(args.save_path, "inference_log")
    if gpu == 0:
        sys.stdout = log_writer.Logger(os.path.join(logger_path, f"inference_log_{args.prompt_type}_{args.reverse}.log"))
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for testing".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    # create model
    model = MoNet.MoNet()
    dict_pkl = {}
    prompt_num = torch.load(args.pkl_path, map_location='cpu').get('prompt_num')
    for key, value in torch.load(args.pkl_path, map_location='cpu')['state_dict'].items():
        dict_pkl[key[7:]] = value
    model.load_state_dict(dict_pkl)
    print('Load Model From ', args.pkl_path)

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu]
            )
            print("Model Distribute.")
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)

    if prompt_num is None:
        prompt_num = args.batch_size - 1
    prompt_num = 10
    print('prompt_num', prompt_num)

    test_prompt_list, test_data_list = {}, []
    # fix_prompt = None
    for dataset in args.dataset:
        print('---Load ', dataset)
        path, train_index, test_index = get_data(dataset=dataset, split_seed=args.seed)
        # if dataset == 'spaq' and False:
        if dataset == 'spaq':
            for column in range(2, 8):
                print('spaq column train', column)
                test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, test_index, istrain=False, column=column)
                test_data_list.append(test_dataset.get_samples())

                train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, column=column)
                test_prompt_list[dataset + f'_{column}'] = train_dataset.get_prompt(prompt_num, args.prompt_type)
        else:
            test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, test_index, istrain=False, types=args.types)
            test_data_list.append(test_dataset.get_samples())

            train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, types=args.types)
            test_prompt_list[dataset] = train_dataset.get_prompt(prompt_num, args.prompt_type)
    print('args.prompt_type', args.prompt_type)

    combined_test_samples = ConcatDataset(test_data_list)
    print("test_dataset", len(combined_test_samples))
    test_sampler = torch.utils.data.distributed.DistributedSampler(combined_test_samples)

    test_loader = torch.utils.data.DataLoader(
        combined_test_samples,
        batch_size=1,
        shuffle=(test_sampler is None),
        num_workers=args.workers,
        sampler=test_sampler,
        drop_last=False,
        pin_memory=True,
    )

    if args.distributed:
        test_sampler.set_epoch(0)

    for idxsa in range(1):
        test_srocc, test_plcc, pred_scores, gt_scores, path = test(
            test_loader, model, test_prompt_list, reverse=args.reverse
        )
        print('gt_scores', len(pred_scores), len(gt_scores))
        print('Summary---')

        gt_scores = gather_together(gt_scores)  # collect the per-GPU results into one list
        pred_scores = gather_together(pred_scores)  # collect the per-GPU results into one list

        gt_score_dict, pred_score_dict = {}, {}
        for sublist in gt_scores:
            for k, v in sublist.items():
                if k not in gt_score_dict:
                    gt_score_dict[k] = v
                else:
                    gt_score_dict[k] = gt_score_dict[k] + v

        for sublist in pred_scores:
            for k, v in sublist.items():
                if k not in pred_score_dict:
                    pred_score_dict[k] = v
                else:
                    pred_score_dict[k] = pred_score_dict[k] + v

        gt_score_dict = dict(sorted(gt_score_dict.items()))
        test_srocc, test_plcc = 0, 0
        for k, v in gt_score_dict.items():
            test_srocc_, test_plcc_ = cal_srocc_plcc(gt_score_dict[k], pred_score_dict[k])
            print('\t{} Test SROCC: {}, PLCC: {}'.format(k, round(test_srocc_, 4), round(test_plcc_, 4)))
            # print('Pred: ', pred_score_dict[k][:10])
            # print('GT: ', gt_score_dict[k][:10])
            # print('-----')

            with open('{}_{}.csv'.format(idxsa, k), 'w') as f:
                for i, j in zip(gt_score_dict[k], pred_score_dict[k]):
                    f.write('{},{}\n'.format(i, j))
            test_srocc += test_srocc_
            test_plcc += test_plcc_

def test(test_loader, MoNet, prompt_data_loader, reverse=False):
    """Testing"""
    pred_scores = {}
    gt_scores = {}
    path = []

    batch_time = AverageMeter("Time", ":6.3f")
    srocc = AverageMeter("SROCC", ":6.2f")
    plcc = AverageMeter("PLCC", ":6.2f")
    progress = ProgressMeter(
        len(test_loader),
        [batch_time, srocc, plcc],
        prefix="Testing ",
    )

    print('reverse ----', reverse)
    MoNet.train(False)
    with torch.no_grad():
        for index, (img_or, label_or, paths, dataset_type) in enumerate(test_loader):
            # print(dataset_type)
            t = time.time()
            dataset_type = dataset_type[0]

            has_prompt = False
            if hasattr(MoNet.module, 'check_prompt'):
                has_prompt = MoNet.module.check_prompt(dataset_type)

            if not has_prompt:
                print('Load Prompt For ', dataset_type)
                prompt_dataset = prompt_data_loader[dataset_type]
                for img, label in prompt_dataset:
                    img = img.squeeze(0).cuda()
                    label = label.squeeze(0).cuda()
                    if reverse == 2:
                        # label = torch.tensor([random.random() for i in range(len(label[:, -1]))]).cuda()
                        label = torch.rand_like(label[:, -1]).cuda()
                        print(label)
                    elif reverse == 3:
                        print('Total Random')
                        label = torch.rand_like(label[:, -1]).cuda()
                        img = torch.rand_like(img).cuda()
                    else:
                        label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda())
                    MoNet.module.forward_prompt(img, label.reshape(-1, 1), dataset_type)

            img = img_or.squeeze(0).cuda()
            label = label_or.squeeze(0).cuda()[:, 2]

            # print(img.shape)

            pred = MoNet.module.inference(img, dataset_type)

            if dataset_type not in pred_scores:
                pred_scores[dataset_type] = []

            if dataset_type not in gt_scores:
                gt_scores[dataset_type] = []

            pred_scores[dataset_type] = pred_scores[dataset_type] + pred.cpu().tolist()
            gt_scores[dataset_type] = gt_scores[dataset_type] + label.cpu().tolist()
            path = path + list(paths)

            batch_time.update(time.time() - t)

            if index % 100 == 0:
                for k, v in pred_scores.items():
                    test_srocc, test_plcc = cal_srocc_plcc(pred_scores[k], gt_scores[k])
                    # print('\t{}, SROCC: {}, PLCC: {}'.format(k, round(test_srocc, 4), round(test_plcc, 4)))
                    srocc.update(test_srocc)
                    plcc.update(test_plcc)

                progress.display(index)

    MoNet.module.clear()
    # MoNet.train(True)
    return 'test_srocc', 'test_plcc', pred_scores, gt_scores, path

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed",
        dest="seed",
        type=int,
        default=570908,
        help="Random seed for result reproduction.",
    )

    parser.add_argument(
        "--mal_num",
        dest="mal_num",
        type=int,
        default=2,
        help="The number of MAL modules.",
    )

    # data related
    parser.add_argument(
        "--dataset",
        dest="dataset",
        nargs='+', default=None,
        help="Supported datasets: livec|koniq10k|bid|spaq",
    )

    # training related
    parser.add_argument(
        "--queue_ratio",
        dest="queue_ratio",
        type=float,
        default=0.6,
        help="Ratio of queue length used in GC loss to training set length.",
    )

    parser.add_argument(
        "--loss",
        dest="loss",
        type=str,
        default="MSE",
        help="Loss function to use. Supported losses: GC|MAE|MSE.",
    )

    parser.add_argument(
        "--lr", dest="lr", type=float, default=1e-5, help="Learning rate"
    )

    parser.add_argument(
        "--weight_decay",
        dest="weight_decay",
        type=float,
        default=1e-5,
        help="Weight decay",
    )
    parser.add_argument(
        "--batch_size", dest="batch_size", type=int, default=11, help="Batch size"
    )
    parser.add_argument(
        "--epochs", dest="epochs", type=int, default=50, help="Epochs for training"
    )
    parser.add_argument(
        "--T_max",
        dest="T_max",
        type=int,
        default=50,
        help="Hyper-parameter for CosineAnnealingLR",
    )
    parser.add_argument(
        "--eta_min",
        dest="eta_min",
        type=int,
        default=0,
        help="Hyper-parameter for CosineAnnealingLR",
    )

    parser.add_argument(
        "-j",
        "--workers",
        default=32,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 32)",
    )

    # result related
    parser.add_argument(
        "--save_path",
        dest="save_path",
        type=str,
        default="./save_logs/Matrix_Comparation_Koniq_bs_25",
        help="The path where the model and logs will be saved.",
    )

    parser.add_argument(
        "--world-size",
        default=-1,
        type=int,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--rank", default=-1, type=int, help="node rank for distributed training"
    )
    parser.add_argument(
        "--dist-url",
        default="tcp://224.66.41.62:23456",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--multiprocessing-distributed",
        action="store_true",
        help="Use multi-processing distributed training to launch "
        "N processes per node, which has N GPUs. This is the "
        "fastest way to use PyTorch for either single node or "
        "multi node data parallel training",
    )

    parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
    parser.add_argument("--pkl_path", required=True, type=str)
    parser.add_argument("--prompt_type", required=True, type=str)
    parser.add_argument("--reverse", required=True, type=int)
    parser.add_argument("--types", default='SSIM', type=str)

    config = parser.parse_args()

    config.save_path = os.path.dirname(config.pkl_path)

    main(config)
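
A side note on the checkpoint loading in main_worker above: the `key[7:]` slice drops the `module.` prefix that DistributedDataParallel prepends to parameter names. A minimal sketch of the same idea as a standalone helper, assuming the checkpoint layout used here; the helper name and the commented usage are illustrative, not part of the repository:

def strip_module_prefix(state_dict):
    # Drop the "module." prefix added by (Distributed)DataParallel, if present,
    # so the weights can be loaded into a bare, unwrapped model.
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }

# Illustrative usage:
# ckpt = torch.load("best_model.pth.tar", map_location="cpu")
# model.load_state_dict(strip_module_prefix(ckpt["state_dict"]))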
PromptIQA/test.sh
DELETED
@@ -1,9 +0,0 @@
# python test.py --dist-url 'tcp://localhost:10055' --dataset spaq tid2013 livec bid spaq flive --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 0 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar
# python test.py --dist-url 'tcp://localhost:12755' --dataset csiq --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 3 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Training_log/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar
python test.py --dist-url 'tcp://localhost:12755' --dataset livec bid csiq --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar
# reverse 0 no, 1 yes, 2 random

python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar --types 'SSIM'


CUDA_VISIBLE_DEVICES="0" python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Publication/PromptIQA_2024_WO_Norm_Score/best_model_five_22.pth.tar --types 'SSIM'
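
For reference, the --reverse flag in these commands selects how test() in the deleted test.py perturbs the prompt scores before inference: 0 keeps the ground-truth scores, 1 flips them, 2 replaces them with random scores, and 3 randomizes both scores and images. A minimal sketch of that mapping, with an illustrative helper name that is not part of the repository:

import torch

def perturb_prompt(img, label, reverse):
    # Mirrors the branches in test(): 0 -> keep scores, 1 -> flip scores,
    # 2 -> random scores, 3 -> random scores and random images.
    if reverse == 2:
        return img, torch.rand_like(label)
    if reverse == 3:
        return torch.rand_like(img), torch.rand_like(label)
    return img, (1 - label) if reverse == 1 else label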
best_model.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:993555b9efaeae660d2dd6f4056f13c6957628ca592a2ce74ff2e8eb5a4a2280
size 1272842308
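
The checkpoint itself is stored via Git LFS, so only this pointer (its sha256 oid and byte size) lives in the repository. As an optional check, a small sketch that verifies a locally downloaded copy against the recorded hash and size; the helper name and the file path are illustrative:

import hashlib
import os

def matches_lfs_pointer(path, oid_hex, size):
    # Compare a local file against the sha256 oid and byte size from the LFS pointer.
    if os.path.getsize(path) != size:
        return False
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    return sha.hexdigest() == oid_hex

# matches_lfs_pointer("best_model.pth.tar",
#                     "993555b9efaeae660d2dd6f4056f13c6957628ca592a2ce74ff2e8eb5a4a2280",
#                     1272842308)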
get_examplt.py
DELETED
@@ -1,27 +0,0 @@
import os
from copy import deepcopy


isp_json = []
path = './Examples'
for img_dir in sorted(os.listdir(path)):
    if os.path.isdir(os.path.join(path, img_dir)):
        ISPP = os.path.join(path, img_dir, 'ISPP')

        ispp = {}
        ispp['Example_id'] = img_dir
        ispp['ISPP'] = []
        img_list = []
        for idx, img in enumerate(sorted(os.listdir(ISPP))):
            ispp['ISPP'].append([os.path.join(ISPP, img), idx / 10 if '1' in img_dir else 1 - idx / 10])

        for file in os.listdir(os.path.join(path, img_dir)):
            if os.path.isfile(os.path.join(path, img_dir, file)):
                img_list.append(file)
                ispp['Image'] = [os.path.join(path, img_dir, file), 7]
        ispp['Remark'] = []
        isp_json.append(deepcopy(ispp))

with open('example2.json', 'w') as f:
    import json
    json.dump(isp_json, f, indent=4)
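
For context, each entry this script appends to example2.json pairs an example id with its ISPP image/score list and a query image. An illustration of the expected shape of one entry; the folder name, file names, and image count below are made up for the sketch, not taken from the repository:

# Illustrative only: one entry produced by the loop above, assuming a folder
# ./Examples/1 whose id contains '1' (so ISPP scores count up from 0.0).
example_entry = {
    "Example_id": "1",
    "ISPP": [
        ["./Examples/1/ISPP/ispp_00.png", 0.0],
        ["./Examples/1/ISPP/ispp_01.png", 0.1],
        ["./Examples/1/ISPP/ispp_02.png", 0.2],
    ],
    "Image": ["./Examples/1/query.png", 7],
    "Remark": [],
}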