Zevin2023 committed
Commit ccbd00a · 1 Parent(s): ed2472a

refine promptiqa.py
PromptIQA/models/gc_loss.py DELETED
@@ -1,99 +0,0 @@
-import torch.nn as nn
-import torch
-import numpy as np
-
-
-class GC_Loss(nn.Module):
-    def __init__(self, queue_len=800, alpha=0.5, beta=0.5, gamma=1):
-        super(GC_Loss, self).__init__()
-        self.pred_queue = list()
-        self.gt_queue = list()
-        self.queue_len = 0
-
-        self.queue_max_len = queue_len
-        print('The queue length is: ', self.queue_max_len)
-        self.mse = torch.nn.MSELoss().cuda()
-
-        self.alpha, self.beta, self.gamma = alpha, beta, gamma
-
-    def consistency(self, pred_data, gt_data):
-        pred_one_batch, pred_queue = pred_data
-        gt_one_batch, gt_queue = gt_data
-
-        pred_mean = torch.mean(pred_queue)
-        gt_mean = torch.mean(gt_queue)
-
-        diff_pred = pred_one_batch - pred_mean
-        diff_gt = gt_one_batch - gt_mean
-
-        x1 = torch.sum(torch.mul(diff_pred, diff_gt))
-        x2_1 = torch.sqrt(torch.sum(torch.mul(diff_pred, diff_pred)))
-        x2_2 = torch.sqrt(torch.sum(torch.mul(diff_gt, diff_gt)))
-
-        return x1 / (x2_1 * x2_2)
-
-    def ppra(self, x):
-        """
-        Pairwise Preference-based Rank Approximation
-        """
-        x_bar, x_std = torch.mean(x), torch.std(x)
-        x_n = (x - x_bar) / x_std
-        x_n_T = x_n.reshape(-1, 1)
-
-        rank_x = x_n_T - x_n_T.transpose(1, 0)
-        rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x / torch.sqrt(torch.tensor(2, dtype=torch.float)))), dim=1)
-
-        return rank_x
-
-    @torch.no_grad()
-    def enqueue(self, pred, gt):
-        bs = pred.shape[0]
-        self.queue_len = self.queue_len + bs
-
-        self.pred_queue = self.pred_queue + pred.tolist()
-        self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()
-
-        if self.queue_len > self.queue_max_len:
-            self.dequeue(self.queue_len - self.queue_max_len)
-            self.queue_len = self.queue_max_len
-
-    @torch.no_grad()
-    def dequeue(self, n):
-        for _ in range(n):
-            self.pred_queue.pop(0)
-            self.gt_queue.pop(0)
-
-    def clear(self):
-        self.pred_queue.clear()
-        self.gt_queue.clear()
-
-    def forward(self, x, y):
-        x_queue = self.pred_queue.copy()
-        y_queue = self.gt_queue.copy()
-
-        x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
-        y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)
-
-        PLCC = self.consistency((x, x_all), (y, y_all))
-        PGC = 1 - PLCC
-
-        rank_x = self.ppra(x_all)
-        rank_y = self.ppra(y_all)
-        SROCC = self.consistency((rank_x[:x.shape[0]], rank_x), (rank_y[:y.shape[0]], rank_y))
-        SGC = 1 - SROCC
-
-        GC = (self.alpha * PGC + self.beta * SGC + self.gamma) * self.mse(x, y)
-        self.enqueue(x, y)
-
-        return GC
-
-
-if __name__ == '__main__':
-    gc = GC_Loss().cuda()
-    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float).cuda()
-    y = torch.tensor([6, 7, 8, 9, 15], dtype=torch.float).cuda()
-
-    res = gc(x, y)
-
-    print(res)
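Editor's note: for quick reference, the deleted GC loss combines a Pearson term, a Spearman term computed on differentiable soft ranks, and plain MSE. Reading the formula off the code above (notation mine, not from the repo):

    GC(x, y) = (alpha * (1 - PLCC) + beta * (1 - SROCC) + gamma) * MSE(x, y)

Here PLCC is a Pearson-style correlation of the current batch, with the means taken over the queue-augmented sets x_all and y_all, and ppra replaces hard ranks by the soft approximation

    rank(x_i) ≈ Σ_j Φ(x̃_i − x̃_j),   x̃ = (x − mean(x)) / std(x),

using the identity (1/2)(1 + erf(z / √2)) = Φ(z), the standard normal CDF. SROCC is then the same Pearson-style correlation applied to the two soft-rank vectors, which is what makes the rank-consistency term differentiable.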
PromptIQA/models/monet_IPF.py DELETED
@@ -1,397 +0,0 @@
-"""
-Implementation of the Mean-Opinion Network (MoNet).
-"""
-import torch
-import torch.nn as nn
-import timm
-
-from timm.models.vision_transformer import Block
-from einops import rearrange
-from itertools import combinations
-
-from tqdm import tqdm
-
-
-class Attention_Block(nn.Module):
-    def __init__(self, dim, drop=0.1):
-        super().__init__()
-        self.c_q = nn.Linear(dim, dim)
-        self.c_k = nn.Linear(dim, dim)
-        self.c_v = nn.Linear(dim, dim)
-        self.norm_fact = dim ** -0.5
-        self.softmax = nn.Softmax(dim=-1)
-        self.proj_drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        _x = x
-        B, C, N = x.shape
-        q = self.c_q(x)
-        k = self.c_k(x)
-        v = self.c_v(x)
-
-        attn = q @ k.transpose(-2, -1) * self.norm_fact
-        attn = self.softmax(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
-        x = self.proj_drop(x)
-        x = x + _x
-        return x
-
-
-class Self_Attention(nn.Module):
-    """ Self attention Layer"""
-
-    def __init__(self, in_dim):
-        super(Self_Attention, self).__init__()
-
-        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
-        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
-        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
-        self.gamma = nn.Parameter(torch.zeros(1))
-
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, inFeature):
-        bs, C, w, h = inFeature.size()
-
-        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
-        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
-        energy = torch.bmm(proj_query, proj_key)
-        attention = self.softmax(energy)
-        proj_value = self.vConv(inFeature).view(bs, -1, w * h)
-
-        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
-        out = out.view(bs, C, w, h)
-
-        out = self.gamma * out + inFeature
-
-        return out
-
-
-class MAL(nn.Module):
-    """
-    Multi-view Attention Learning (MAL) module
-    """
-
-    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
-        super().__init__()
-
-        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
-        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention
-
-        # Self attention module for each input feature
-        self.attention_module = nn.ModuleList()
-        for _ in range(feature_num):
-            self.attention_module.append(Self_Attention(in_dim))
-
-        self.feature_num = feature_num
-        self.in_dim = in_dim
-
-    def forward(self, features):
-        feature = torch.tensor([]).cuda()
-        for index, _ in enumerate(features):
-            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
-        features = feature
-
-        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
-        bs, _, _ = input_tensor.shape  # [2, 3072, 784]
-
-        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
-                               c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
-        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768
-
-        in_channel = input_tensor.permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768 * feature_num
-        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28
-
-        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
-                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2  # [2, 3072, 784]
-
-        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)
-
-        return weight_sum_res  # bs, 768, 28 * 28
-
-
-class SaveOutput:
-    def __init__(self):
-        self.outputs = []
-
-    def __call__(self, module, module_in, module_out):
-        self.outputs.append(module_out)
-
-    def clear(self):
-        self.outputs = []
-
-
-# utils
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
-        super().__init__()
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-import torch
-from functools import partial
-
-
-class MoNet(nn.Module):
-    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
-        super().__init__()
-        self.img_size = img_size
-        self.input_size = img_size // patch_size
-        self.dim_mlp = dim_mlp
-
-        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
-        self.vit.norm = nn.Identity()
-        self.vit.head = nn.Identity()
-
-        self.save_output = SaveOutput()
-
-        # Register Hooks
-        hook_handles = []
-        for layer in self.vit.modules():
-            if isinstance(layer, Block):
-                handle = layer.register_forward_hook(self.save_output)
-                hook_handles.append(handle)
-
-        self.MALs = nn.ModuleList()
-        for _ in range(3):
-            self.MALs.append(MAL())
-
-        # Image Quality Score Regression
-        self.fusion_mal = MAL(feature_num=3)
-        self.block = Block(dim_mlp, 12)
-        self.cnn = nn.Sequential(
-            nn.Conv2d(dim_mlp, 256, 5),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((2, 2)),
-            nn.Conv2d(256, 128, 3),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((2, 2)),
-            nn.Conv2d(128, 128, 3),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((3, 3)),
-        )
-
-        self.i_p_fusion = nn.Sequential(
-            Block(128, 4),
-            Block(128, 4),
-            Block(128, 4),
-        )
-        self.mlp = nn.Sequential(
-            nn.Linear(128, 64),
-            nn.GELU(),
-            nn.Linear(64, 128),
-        )
-
-        self.prompt_fusion = nn.Sequential(
-            Block(128, 4),
-            Block(128, 4),
-            Block(128, 4),
-        )
-
-        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
-        self.blocks = nn.Sequential(*[
-            Block(
-                dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
-                attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
-            for i in range(8)])
-        self.norm = nn.LayerNorm(128)
-
-        self.score_block = nn.Sequential(
-            nn.Linear(128, 128 // 2),
-            nn.ReLU(),
-            nn.Dropout(drop),
-            nn.Linear(128 // 2, 1),
-            nn.Sigmoid()
-        )
-
-        self.prompt_feature = {}
-
-    @torch.no_grad()
-    def clear(self):
-        self.prompt_feature = {}
-
-    @torch.no_grad()
-    def inference(self, x, data_type):
-        prompt_feature = self.prompt_feature[data_type]  # 1, n, 128
-
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
-        prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128
-
-        fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1))[:, 0, :]  # bs, 128
-        # fusion = self.norm(fusion)[:, 0, :]
-        # fusion = self.score_block(fusion)
-
-        # # iq_res = torch.mean(fusion, dim=1).view(-1)
-        # iq_res = fusion[:, 0].view(-1)
-
-        return fusion
-
-    @torch.no_grad()
-    def check_prompt(self, data_type):
-        return data_type in self.prompt_feature
-
-    @torch.no_grad()
-    def forward_prompt(self, x, score, data_type):
-        if data_type in self.prompt_feature:
-            return
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # Linearly map the score to a 128-d feature
-        # score_feature = self.score_projection(score)  # bs, 128
-        score_feature = score.expand(-1, 128)
-
-        # Fuse img_feature and score_feature to obtain funsion_feature
-        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
-        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128
-
-        # print('Load Prompt For Testing.', funsion_feature.shape)
-        # self.prompt_feature = funsion_feature.clone()
-        self.prompt_feature[data_type] = funsion_feature.clone()
-
-    def forward(self, x, score):
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # Linearly map the score to a 128-d feature
-        # score_feature = self.score_projection(score)  # bs, 128
-        score_feature = score.expand(-1, 128)  # bs, 128
-
-        # Fuse img_feature and score_feature to obtain funsion_feature
-        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
-        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
-        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
-        funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128
-
-        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1))  # bs, bs, 128
-        fusion = self.norm(fusion)
-        fusion = self.score_block(fusion)
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
-        iq_res = fusion[:, 0].view(-1)
-
-        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
-        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
-
-        gt_res = score.view(-1)
-        # diff_gt_res = 1 - score.view(-1)
-
-        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
-
-    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
-        x1 = save_output.outputs[block_index[0]][:, 1:]
-        x2 = save_output.outputs[block_index[1]][:, 1:]
-        x3 = save_output.outputs[block_index[2]][:, 1:]
-        x4 = save_output.outputs[block_index[3]][:, 1:]
-        x = torch.cat((x1, x2, x3, x4), dim=2)
-        return x
-
-    def expand(self, A):
-        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
-
-        B = None
-        for index, i in enumerate(A_expanded):
-            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
-            if B is None:
-                B = rmv
-            else:
-                B = torch.cat((B, rmv), dim=0)
-
-        return B
-
-
-if __name__ == '__main__':
-    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
-    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda()
-    model = MoNet().cuda()
-
-    iq_res, gt_res = model(in_feature, gt_feature)
-
-    print(iq_res.shape)
-    print(gt_res.shape)
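Editor's note: the expand helper above (shared by every MoNet variant in this commit) builds leave-one-out prompt batches: each sample is scored against the fused prompts of the other samples in its batch. A minimal sketch of its effect, restated with a list comprehension and hypothetical toy values:

import torch

def leave_one_out(A):
    # Same behaviour as MoNet.expand: for each of the B rows of A (shape B, D),
    # keep the other B - 1 rows, giving a (B, B - 1, D) tensor.
    A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
    rows = [torch.cat((row[:i], row[i + 1:])).unsqueeze(0)
            for i, row in enumerate(A_expanded)]
    return torch.cat(rows, dim=0)

A = torch.arange(6.).reshape(3, 2)   # B = 3 prompt vectors of dim D = 2
print(leave_one_out(A).shape)        # torch.Size([3, 2, 2])
print(leave_one_out(A)[0])           # rows 1 and 2: the prompts that sample 0 sees

This is why the shape comments in forward read "bs, bs - 1, 128": during training each image is paired with prompts derived from the rest of the batch, never from itself.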
PromptIQA/models/monet_test.py DELETED
@@ -1,389 +0,0 @@
-"""
-Implementation of the Mean-Opinion Network (MoNet).
-"""
-import torch
-import torch.nn as nn
-import timm
-
-from timm.models.vision_transformer import Block
-from einops import rearrange
-from itertools import combinations
-
-from tqdm import tqdm
-
-
-class Attention_Block(nn.Module):
-    def __init__(self, dim, drop=0.1):
-        super().__init__()
-        self.c_q = nn.Linear(dim, dim)
-        self.c_k = nn.Linear(dim, dim)
-        self.c_v = nn.Linear(dim, dim)
-        self.norm_fact = dim ** -0.5
-        self.softmax = nn.Softmax(dim=-1)
-        self.proj_drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        _x = x
-        B, C, N = x.shape
-        q = self.c_q(x)
-        k = self.c_k(x)
-        v = self.c_v(x)
-
-        attn = q @ k.transpose(-2, -1) * self.norm_fact
-        attn = self.softmax(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
-        x = self.proj_drop(x)
-        x = x + _x
-        return x
-
-
-class Self_Attention(nn.Module):
-    """ Self attention Layer"""
-
-    def __init__(self, in_dim):
-        super(Self_Attention, self).__init__()
-
-        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
-        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
-        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
-        self.gamma = nn.Parameter(torch.zeros(1))
-
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, inFeature):
-        bs, C, w, h = inFeature.size()
-
-        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
-        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
-        energy = torch.bmm(proj_query, proj_key)
-        attention = self.softmax(energy)
-        proj_value = self.vConv(inFeature).view(bs, -1, w * h)
-
-        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
-        out = out.view(bs, C, w, h)
-
-        out = self.gamma * out + inFeature
-
-        return out
-
-
-class MAL(nn.Module):
-    """
-    Multi-view Attention Learning (MAL) module
-    """
-
-    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
-        super().__init__()
-
-        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
-        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention
-
-        # Self attention module for each input feature
-        self.attention_module = nn.ModuleList()
-        for _ in range(feature_num):
-            self.attention_module.append(Self_Attention(in_dim))
-
-        self.feature_num = feature_num
-        self.in_dim = in_dim
-
-    def forward(self, features):
-        feature = torch.tensor([]).cuda()
-        for index, _ in enumerate(features):
-            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
-        features = feature
-
-        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
-        bs, _, _ = input_tensor.shape  # [2, 3072, 784]
-
-        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
-                               c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
-        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768
-
-        in_channel = input_tensor.permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768 * feature_num
-        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28
-
-        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
-                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2  # [2, 3072, 784]
-
-        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)
-
-        return weight_sum_res  # bs, 768, 28 * 28
-
-
-class SaveOutput:
-    def __init__(self):
-        self.outputs = []
-
-    def __call__(self, module, module_in, module_out):
-        self.outputs.append(module_out)
-
-    def clear(self):
-        self.outputs = []
-
-
-# utils
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
-        super().__init__()
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-class MoNet(nn.Module):
-    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
-        super().__init__()
-        self.img_size = img_size
-        self.input_size = img_size // patch_size
-        self.dim_mlp = dim_mlp
-
-        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
-        self.vit.norm = nn.Identity()
-        self.vit.head = nn.Identity()
-
-        self.save_output = SaveOutput()
-
-        # Register Hooks
-        hook_handles = []
-        for layer in self.vit.modules():
-            if isinstance(layer, Block):
-                handle = layer.register_forward_hook(self.save_output)
-                hook_handles.append(handle)
-
-        self.MALs = nn.ModuleList()
-        for _ in range(3):
-            self.MALs.append(MAL())
-
-        # Image Quality Score Regression
-        self.fusion_mal = MAL(feature_num=3)
-        self.block = Block(dim_mlp, 12)
-        self.cnn = nn.Sequential(
-            nn.Conv2d(dim_mlp, 256, 5),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((2, 2)),
-            nn.Conv2d(256, 128, 3),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((2, 2)),
-            nn.Conv2d(128, 128, 3),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((3, 3)),
-        )
-
-        # self.score_projection = nn.Sequential(
-        #     nn.Linear(1, 64),
-        #     nn.GELU(),
-        #     nn.Linear(64, 128),
-        # )
-
-        # self.i_p_fusion = nn.Sequential(
-        #     Block(128, 8),
-        #     Block(128, 8),
-        #     Block(128, 8),
-        # )
-        self.i_p_fusion = nn.Sequential(
-            Block(128, 4),
-            Block(128, 4),
-            Block(128, 4),
-        )
-        self.mlp = nn.Sequential(
-            nn.Linear(128, 64),
-            nn.GELU(),
-            nn.Linear(64, 128),
-        )
-
-        self.score_block = nn.Sequential(
-            Block(128, 4),
-            Block(128, 4),
-            # Block(128, 4),
-            nn.Linear(128, 128 // 2),
-            nn.ReLU(),
-            nn.Dropout(drop),
-            nn.Linear(128 // 2, 1),
-            nn.Sigmoid()
-        )
-
-        # self.diff_block = nn.Sequential(
-        #     Block(128, 8),
-        #     Block(128, 8),
-        #     Block(128, 8),
-        #     nn.Linear(128, 64),
-        #     nn.GELU(),
-        #     nn.Linear(64, 1),
-        # )
-        self.prompt_feature = None
-
-    @torch.no_grad()
-    def clear(self):
-        self.prompt_feature = None
-
-    @torch.no_grad()
-    def inference(self, x):
-        prompt_feature = self.prompt_feature  # 1, n, 128
-
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
-
-        fusion = self.score_block(torch.cat((img_feature, prompt_feature), dim=1))  # bs, n, 1
-
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
-        iq_res = fusion[:, 0].view(-1)
-
-        return iq_res
-
-    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
-        x1 = save_output.outputs[block_index[0]][:, 1:]
-        x2 = save_output.outputs[block_index[1]][:, 1:]
-        x3 = save_output.outputs[block_index[2]][:, 1:]
-        x4 = save_output.outputs[block_index[3]][:, 1:]
-        x = torch.cat((x1, x2, x3, x4), dim=2)
-        return x
-
-    @torch.no_grad()
-    def forward_prompt(self, x, score):
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # Linearly map the score to a 128-d feature
-        # score_feature = self.score_projection(score)  # bs, 128
-        score_feature = score.expand(-1, 128)
-
-        # Fuse img_feature and score_feature to obtain funsion_feature
-        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
-        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128
-
-        print('Load Prompt For Testing.', funsion_feature.shape)
-        self.prompt_feature = funsion_feature.clone()
-
-    def expand(self, A):
-        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
-
-        B = None
-        for index, i in enumerate(A_expanded):
-            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
-            if B is None:
-                B = rmv
-            else:
-                B = torch.cat((B, rmv), dim=0)
-
-        return B
-
-    def forward(self, x, score):
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # Linearly map the score to a 128-d feature
-        # score_feature = self.score_projection(score)  # bs, 128
-        score_feature = score.expand(-1, 128)  # bs, 128
-
-        # Fuse img_feature and score_feature to obtain funsion_feature
-        funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1))  # bs, 2, 128
-        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
-        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
-
-        fusion = self.score_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
-        iq_res = fusion[:, 0].view(-1)
-
-        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
-        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
-
-        gt_res = score.view(-1)
-        # diff_gt_res = 1 - score.view(-1)
-
-        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
-
-
-if __name__ == '__main__':
-    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
-    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda()
-    model = MoNet().cuda()
-
-    iq_res, gt_res = model(in_feature, gt_feature)
-
-    print(iq_res.shape)
-    print(gt_res.shape)
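Editor's note: every variant in this commit captures intermediate ViT features with the same SaveOutput forward-hook pattern. A self-contained sketch of just that mechanism (pretrained=False here to avoid a weight download; shapes assume vit_base_patch8_224 as used above):

import timm
import torch
from timm.models.vision_transformer import Block

class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

vit = timm.create_model('vit_base_patch8_224', pretrained=False)
saver = SaveOutput()
for layer in vit.modules():
    if isinstance(layer, Block):
        layer.register_forward_hook(saver)  # fires after each transformer block

with torch.no_grad():
    vit(torch.zeros(1, 3, 224, 224))

print(len(saver.outputs))       # 12: one capture per block in ViT-Base
print(saver.outputs[2].shape)   # torch.Size([1, 785, 768]): 28*28 patch tokens + CLS

extract_feature then drops the CLS token ([:, 1:]) and concatenates the outputs of blocks [2, 5, 8, 11] along the channel dimension, which is where the "768 * 4" in the shape comments comes from.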
PromptIQA/models/monet_wo_prompt.py DELETED
@@ -1,392 +0,0 @@
-"""
-Implementation of the Mean-Opinion Network (MoNet).
-"""
-import torch
-import torch.nn as nn
-import timm
-
-from timm.models.vision_transformer import Block
-from einops import rearrange
-from itertools import combinations
-
-from tqdm import tqdm
-import os
-# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
-
-
-class Attention_Block(nn.Module):
-    def __init__(self, dim, drop=0.1):
-        super().__init__()
-        self.c_q = nn.Linear(dim, dim)
-        self.c_k = nn.Linear(dim, dim)
-        self.c_v = nn.Linear(dim, dim)
-        self.norm_fact = dim ** -0.5
-        self.softmax = nn.Softmax(dim=-1)
-        self.proj_drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        _x = x
-        B, C, N = x.shape
-        q = self.c_q(x)
-        k = self.c_k(x)
-        v = self.c_v(x)
-
-        attn = q @ k.transpose(-2, -1) * self.norm_fact
-        attn = self.softmax(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
-        x = self.proj_drop(x)
-        x = x + _x
-        return x
-
-
-class Self_Attention(nn.Module):
-    """ Self attention Layer"""
-
-    def __init__(self, in_dim):
-        super(Self_Attention, self).__init__()
-
-        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
-        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
-        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
-        self.gamma = nn.Parameter(torch.zeros(1))
-
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, inFeature):
-        bs, C, w, h = inFeature.size()
-
-        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
-        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
-        energy = torch.bmm(proj_query, proj_key)
-        attention = self.softmax(energy)
-        proj_value = self.vConv(inFeature).view(bs, -1, w * h)
-
-        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
-        out = out.view(bs, C, w, h)
-
-        out = self.gamma * out + inFeature
-
-        return out
-
-
-class MAL(nn.Module):
-    """
-    Multi-view Attention Learning (MAL) module
-    """
-
-    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
-        super().__init__()
-
-        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
-        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention
-
-        # Self attention module for each input feature
-        self.attention_module = nn.ModuleList()
-        for _ in range(feature_num):
-            self.attention_module.append(Self_Attention(in_dim))
-
-        self.feature_num = feature_num
-        self.in_dim = in_dim
-
-    def forward(self, features):
-        feature = torch.tensor([]).cuda()
-        for index, _ in enumerate(features):
-            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
-        features = feature
-
-        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
-        bs, _, _ = input_tensor.shape  # [2, 3072, 784]
-
-        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
-                               c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
-        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768
-
-        in_channel = input_tensor.permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768 * feature_num
-        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28
-
-        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
-                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2  # [2, 3072, 784]
-
-        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)
-
-        return weight_sum_res  # bs, 768, 28 * 28
-
-
-class SaveOutput:
-    def __init__(self):
-        self.outputs = []
-
-    def __call__(self, module, module_in, module_out):
-        self.outputs.append(module_out)
-
-    def clear(self):
-        self.outputs = []
-
-
-# utils
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
-        super().__init__()
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-from functools import partial
-
-
-class MoNet(nn.Module):
-    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
-        super().__init__()
-        self.img_size = img_size
-        self.input_size = img_size // patch_size
-        self.dim_mlp = dim_mlp
-
-        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
-        self.vit.norm = nn.Identity()
-        self.vit.head = nn.Identity()
-
-        self.save_output = SaveOutput()
-
-        # Register Hooks
-        hook_handles = []
-        for layer in self.vit.modules():
-            if isinstance(layer, Block):
-                handle = layer.register_forward_hook(self.save_output)
-                hook_handles.append(handle)
-
-        self.MALs = nn.ModuleList()
-        for _ in range(3):
-            self.MALs.append(MAL())
-
-        # Image Quality Score Regression
-        self.fusion_mal = MAL(feature_num=3)
-        self.block = Block(dim_mlp, 12)
-        self.cnn = nn.Sequential(
-            nn.Conv2d(dim_mlp, 256, 5),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((2, 2)),
-            nn.Conv2d(256, 128, 3),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((2, 2)),
-            nn.Conv2d(128, 128, 3),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d((3, 3)),
-        )
-
-        # self.i_p_fusion = nn.Sequential(
-        #     Block(128, 4),
-        #     Block(128, 4),
-        #     Block(128, 4),
-        # )
-        # self.mlp = nn.Sequential(
-        #     nn.Linear(128, 64),
-        #     nn.GELU(),
-        #     nn.Linear(64, 128),
-        # )
-
-        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
-        self.blocks = nn.Sequential(*[
-            Block(
-                dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
-                attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
-            for i in range(8)])
-        self.norm = nn.LayerNorm(128)
-
-        self.score_block = nn.Sequential(
-            nn.Linear(128, 128 // 2),
-            nn.ReLU(),
-            nn.Dropout(drop),
-            nn.Linear(128 // 2, 1),
-            nn.Sigmoid()
-        )
-
-        self.prompt_feature = {}
-
-    @torch.no_grad()
-    def clear(self):
-        self.prompt_feature = {}
-
-    @torch.no_grad()
-    def inference(self, x, data_type):
-        # prompt_feature = self.prompt_feature[data_type]  # 1, n, 128
-
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
-        # prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128
-
-        fusion = self.blocks(img_feature)  # bs, 1, 128
-        fusion = self.norm(fusion)
-        fusion = self.score_block(fusion)
-
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
-        iq_res = fusion[:, 0].view(-1)
-
-        return iq_res
-
-    @torch.no_grad()
-    def check_prompt(self, data_type):
-        return data_type in self.prompt_feature
-
-    @torch.no_grad()
-    def forward_prompt(self, x, score, data_type):
-        pass
-        # if data_type in self.prompt_feature:
-        #     return
-        # _x = self.vit(x)
-        # x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        # self.save_output.outputs.clear()
-
-        # x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        # x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        # x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # # Different Opinion Features (DOF)
-        # DOF = torch.tensor([]).cuda()
-        # for index, _ in enumerate(self.MALs):
-        #     DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        # DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # # Image Quality Score Regression
-        # fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        # IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        # IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        # img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # # Linearly map the score to a 128-d feature
-        # # score_feature = self.score_projection(score)  # bs, 128
-        # score_feature = score.expand(-1, 128)
-
-        # # Fuse img_feature and score_feature to obtain funsion_feature
-        # funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
-        # funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128
-
-        # # print('Load Prompt For Testing.', funsion_feature.shape)
-        # # self.prompt_feature = funsion_feature.clone()
-        # self.prompt_feature[data_type] = funsion_feature.clone()
-
-    def forward(self, x, score):
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # Linearly map the score to a 128-d feature
-        # score_feature = self.score_projection(score)  # bs, 128
-        # score_feature = score.expand(-1, 128)  # bs, 128
-
-        # # Fuse img_feature and score_feature to obtain funsion_feature
-        # funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
-        # funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
-        # funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
-        # funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128
-
-        fusion = self.blocks(img_feature)  # bs, 1, 128
-        fusion = self.norm(fusion)
-        fusion = self.score_block(fusion)
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
-        iq_res = fusion[:, 0].view(-1)
-
-        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
-        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
-
-        gt_res = score.view(-1)
-        # diff_gt_res = 1 - score.view(-1)
-
-        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
-
-    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
-        x1 = save_output.outputs[block_index[0]][:, 1:]
-        x2 = save_output.outputs[block_index[1]][:, 1:]
-        x3 = save_output.outputs[block_index[2]][:, 1:]
-        x4 = save_output.outputs[block_index[3]][:, 1:]
-        x = torch.cat((x1, x2, x3, x4), dim=2)
-        return x
-
-    def expand(self, A):
-        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
-
-        B = None
-        for index, i in enumerate(A_expanded):
-            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
-            if B is None:
-                B = rmv
-            else:
-                B = torch.cat((B, rmv), dim=0)
-
-        return B
-
-
-if __name__ == '__main__':
-    in_feature = torch.zeros((2, 3, 224, 224)).cuda()
-    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2]], dtype=torch.float).cuda()
-    model = MoNet().cuda()
-
-    iq_res, gt_res = model(in_feature, gt_feature)

-    print(iq_res)
-    print(gt_res.shape)
PromptIQA/models/{monet.py → promptiqa.py} RENAMED
@@ -6,10 +6,8 @@ import torch.nn as nn
6
  import timm
7
 
8
  from timm.models.vision_transformer import Block
 
9
  from einops import rearrange
10
- from itertools import combinations
11
-
12
- from tqdm import tqdm
13
 
14
  class Attention_Block(nn.Module):
15
  def __init__(self, dim, drop=0.1):
@@ -119,20 +117,6 @@ class SaveOutput:
119
  def clear(self):
120
  self.outputs = []
121
 
122
- # utils
123
- @torch.no_grad()
124
- def concat_all_gather(tensor):
125
- """
126
- Performs all_gather operation on the provided tensors.
127
- *** Warning ***: torch.distributed.all_gather has no gradient.
128
- """
129
- tensors_gather = [
130
- torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
131
- ]
132
- torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
133
-
134
- output = torch.cat(tensors_gather, dim=0)
135
- return output
136
  class Attention(nn.Module):
137
  def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
138
  super().__init__()
@@ -160,8 +144,7 @@ class Attention(nn.Module):
160
  x = self.proj_drop(x)
161
  return x
162
 
163
- from functools import partial
164
- class MoNet(nn.Module):
165
  def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
166
  super().__init__()
167
  self.img_size = img_size
@@ -273,15 +256,10 @@ class MoNet(nn.Module):
273
  fusion = self.norm(fusion)
274
  fusion = self.score_block(fusion)
275
 
276
- # iq_res = torch.mean(fusion, dim=1).view(-1)
277
  iq_res = fusion[:, 0].view(-1)
278
 
279
  return iq_res
280
 
281
- @torch.no_grad()
282
- def check_prompt(self, data_type):
283
- return data_type in self.prompt_feature
284
-
285
  @torch.no_grad()
286
  def forward_prompt(self, x, score, data_type):
287
  _x = self.vit(x)
@@ -304,63 +282,13 @@ class MoNet(nn.Module):
304
  IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
305
  img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
306
 
307
- # 分数线性变换为128维
308
- # score_feature = self.score_projection(score) # bs, 128
309
  score_feature = score.expand(-1, 128)
310
 
311
- # img_feature 和 score_feature融合得到 funsion_feature
312
  funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
313
  funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128
314
 
315
- # print('Load Prompt For Testing.', funsion_feature.shape)
316
- # self.prompt_feature = funsion_feature.clone()
317
  self.prompt_feature[data_type] = funsion_feature.clone()
318
 
319
- def forward(self, x, score):
320
- _x = self.vit(x)
321
- x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
322
- self.save_output.outputs.clear()
323
-
324
- x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
325
- x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
326
- x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
327
-
328
- # Different Opinion Features (DOF)
329
- DOF = torch.tensor([]).cuda()
330
- for index, _ in enumerate(self.MALs):
331
- DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
332
- DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
333
-
334
- # Image Quality Score Regression
335
- fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768
336
- IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
337
- IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
338
- img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
339
-
340
- # 分数线性变换为128维
341
- # score_feature = self.score_projection(score) # bs, 128
342
- score_feature = score.expand(-1, 128) # bs, 128
343
-
344
- # img_feature 和 score_feature融合得到 funsion_feature
345
- funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
346
- funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) #bs, 128
347
- funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128
348
- funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128
349
-
350
- fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, 2, 1
351
- fusion = self.norm(fusion)
352
- fusion = self.score_block(fusion)
353
- # iq_res = torch.mean(fusion, dim=1).view(-1)
354
- iq_res = fusion[:, 0].view(-1)
355
-
356
- # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1
357
- # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
358
-
359
- gt_res = score.view(-1)
360
- # diff_gt_res = 1 - score.view(-1)
361
-
362
- return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
363
-
364
  def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
365
  x1 = save_output.outputs[block_index[0]][:, 1:]
366
  x2 = save_output.outputs[block_index[1]][:, 1:]
@@ -381,13 +309,3 @@ class MoNet(nn.Module):
381
  B = torch.cat((B, rmv), dim=0)
382
 
383
  return B
384
-
385
- if __name__ == '__main__':
386
- in_feature = torch.zeros((10, 3, 224, 224)).cuda()
387
- gt_feature = torch.tensor([[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=torch.float).cuda()
388
- model = MoNet().cuda()
389
-
390
- iq_res, gt_res = model(in_feature, gt_feature)
391
-
392
- print(iq_res.shape)
393
- print(gt_res.shape)
 
6
  import timm
7
 
8
  from timm.models.vision_transformer import Block
9
+ from functools import partial
10
  from einops import rearrange
 
 
 
11
 
12
  class Attention_Block(nn.Module):
13
  def __init__(self, dim, drop=0.1):
 
117
  def clear(self):
118
  self.outputs = []
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  class Attention(nn.Module):
121
  def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
122
  super().__init__()
 
144
  x = self.proj_drop(x)
145
  return x
146
 
147
+ class PromptIQA(nn.Module):
 
148
  def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
149
  super().__init__()
150
  self.img_size = img_size
 
256
  fusion = self.norm(fusion)
257
  fusion = self.score_block(fusion)
258
 
 
259
  iq_res = fusion[:, 0].view(-1)
260
 
261
  return iq_res
262
 
263
  @torch.no_grad()
264
  def forward_prompt(self, x, score, data_type):
265
  _x = self.vit(x)
 
282
  IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
283
  img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
284
 
285
  score_feature = score.expand(-1, 128)
286
 
287
  funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
288
  funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128
289
 
290
  self.prompt_feature[data_type] = funsion_feature.clone()
291
 
292
  def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
293
  x1 = save_output.outputs[block_index[0]][:, 1:]
294
  x2 = save_output.outputs[block_index[1]][:, 1:]
 
309
  B = torch.cat((B, rmv), dim=0)
310
 
311
  return B
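
Note: extract_feature above relies on the forward hooks that SaveOutput registers on every transformer Block; the model keeps the patch tokens of four intermediate layers and stacks them channel-wise. A minimal, self-contained sketch of the mechanism (assuming timm's vit_base_patch8_224; an illustration, not code from this repo):

import timm
import torch
from timm.models.vision_transformer import Block

class SaveOutput:
    """Forward hook that records the output of every hooked module."""
    def __init__(self):
        self.outputs = []
    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

vit = timm.create_model('vit_base_patch8_224', pretrained=False)
saver = SaveOutput()
for layer in vit.modules():
    if isinstance(layer, Block):
        layer.register_forward_hook(saver)

_ = vit(torch.zeros(1, 3, 224, 224))
# Drop the CLS token of blocks 2, 5, 8 and 11, then concatenate along channels.
feats = torch.cat([saver.outputs[i][:, 1:] for i in (2, 5, 8, 11)], dim=2)
print(feats.shape)  # torch.Size([1, 784, 3072]), i.e. bs, 28 * 28, 768 * 4
saver.outputs.clear()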
 
PromptIQA/models/vit_base.py DELETED
@@ -1,402 +0,0 @@
1
- """
2
- The completion for Mean-opinion Network(MoNet)
3
- """
4
- import torch
5
- import torch.nn as nn
6
- import timm
7
-
8
- from timm.models.vision_transformer import Block
9
- from einops import rearrange
10
- from itertools import combinations
11
-
12
- from tqdm import tqdm
13
-
14
-
15
- class Attention_Block(nn.Module):
16
- def __init__(self, dim, drop=0.1):
17
- super().__init__()
18
- self.c_q = nn.Linear(dim, dim)
19
- self.c_k = nn.Linear(dim, dim)
20
- self.c_v = nn.Linear(dim, dim)
21
- self.norm_fact = dim ** -0.5
22
- self.softmax = nn.Softmax(dim=-1)
23
- self.proj_drop = nn.Dropout(drop)
24
-
25
- def forward(self, x):
26
- _x = x
27
- B, C, N = x.shape
28
- q = self.c_q(x)
29
- k = self.c_k(x)
30
- v = self.c_v(x)
31
-
32
- attn = q @ k.transpose(-2, -1) * self.norm_fact
33
- attn = self.softmax(attn)
34
- x = (attn @ v).transpose(1, 2).reshape(B, C, N)
35
- x = self.proj_drop(x)
36
- x = x + _x
37
- return x
38
-
39
-
40
- class Self_Attention(nn.Module):
41
- """ Self attention Layer"""
42
-
43
- def __init__(self, in_dim):
44
- super(Self_Attention, self).__init__()
45
-
46
- self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
47
- self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
48
- self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
49
- self.gamma = nn.Parameter(torch.zeros(1))
50
-
51
- self.softmax = nn.Softmax(dim=-1)
52
-
53
- def forward(self, inFeature):
54
- bs, C, w, h = inFeature.size()
55
-
56
- proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
57
- proj_key = self.kConv(inFeature).view(bs, -1, w * h)
58
- energy = torch.bmm(proj_query, proj_key)
59
- attention = self.softmax(energy)
60
- proj_value = self.vConv(inFeature).view(bs, -1, w * h)
61
-
62
- out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
63
- out = out.view(bs, C, w, h)
64
-
65
- out = self.gamma * out + inFeature
66
-
67
- return out
68
-
69
-
70
- class three_cnn(nn.Module):
71
- def __init__(self, in_dim) -> None:
72
- super().__init__()
73
-
74
- self.three_cnn = nn.Sequential(
75
- nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1),
76
- nn.ReLU(inplace=True),
77
- nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1),
78
- nn.ReLU(inplace=True),
79
- nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1),
80
- nn.ReLU(inplace=True),
81
- )
82
-
83
- def forward(self, input):
84
- return self.three_cnn(input)
85
-
86
-
87
- class MAL(nn.Module):
88
- def __init__(self, in_dim=768, feature_num=4, feature_size=28):
89
- super().__init__()
90
- self.attention_module = nn.ModuleList()
91
- for i in range(feature_num):
92
- self.attention_module.append(three_cnn(in_dim))
93
-
94
- self.feature_num = feature_num
95
- self.in_dim = in_dim
96
- self.feature_size = feature_size
97
-
98
- def forward(self, features):
99
- feature = torch.tensor([]).cuda()
100
- for index, _ in enumerate(features):
101
- feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1)
102
- feature = torch.mean(feature, dim=1)
103
- features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size)
104
-
105
- return features # bs, 768, 28 * 28
106
-
107
-
108
- class SaveOutput:
109
- def __init__(self):
110
- self.outputs = []
111
-
112
- def __call__(self, module, module_in, module_out):
113
- self.outputs.append(module_out)
114
-
115
- def clear(self):
116
- self.outputs = []
117
-
118
-
119
- # utils
120
- @torch.no_grad()
121
- def concat_all_gather(tensor):
122
- """
123
- Performs all_gather operation on the provided tensors.
124
- *** Warning ***: torch.distributed.all_gather has no gradient.
125
- """
126
- tensors_gather = [
127
- torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
128
- ]
129
- torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
130
-
131
- output = torch.cat(tensors_gather, dim=0)
132
- return output
133
-
134
-
135
- class Attention(nn.Module):
136
- def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
137
- super().__init__()
138
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
139
- self.num_heads = num_heads
140
- head_dim = dim // num_heads
141
- self.scale = head_dim ** -0.5
142
-
143
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
144
- self.attn_drop = nn.Dropout(attn_drop)
145
- self.proj = nn.Linear(dim, dim)
146
- self.proj_drop = nn.Dropout(proj_drop)
147
-
148
- def forward(self, x):
149
- B, N, C = x.shape
150
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
151
- q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
152
-
153
- attn = (q @ k.transpose(-2, -1)) * self.scale
154
- attn = attn.softmax(dim=-1)
155
- attn = self.attn_drop(attn)
156
-
157
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
158
- x = self.proj(x)
159
- x = self.proj_drop(x)
160
- return x
161
-
162
-
163
- from functools import partial
164
-
165
-
166
- class MoNet(nn.Module):
167
- def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
168
- super().__init__()
169
- self.img_size = img_size
170
- self.input_size = img_size // patch_size
171
- self.dim_mlp = dim_mlp
172
-
173
- self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
174
- self.vit.norm = nn.Identity()
175
- self.vit.head = nn.Identity()
176
-
177
- self.save_output = SaveOutput()
178
-
179
- # Register Hooks
180
- hook_handles = []
181
- for layer in self.vit.modules():
182
- if isinstance(layer, Block):
183
- handle = layer.register_forward_hook(self.save_output)
184
- hook_handles.append(handle)
185
-
186
- self.MALs = nn.ModuleList()
187
- for _ in range(1):
188
- self.MALs.append(MAL())
189
-
190
- # Image Quality Score Regression
191
- self.fusion_mal = MAL(feature_num=1)
192
- self.block = Block(dim_mlp, 12)
193
- self.cnn = nn.Sequential(
194
- nn.Conv2d(dim_mlp, 256, 5),
195
- nn.BatchNorm2d(256),
196
- nn.ReLU(inplace=True),
197
- nn.AvgPool2d((2, 2)),
198
- nn.Conv2d(256, 128, 3),
199
- nn.BatchNorm2d(128),
200
- nn.ReLU(inplace=True),
201
- nn.AvgPool2d((2, 2)),
202
- nn.Conv2d(128, 128, 3),
203
- nn.BatchNorm2d(128),
204
- nn.ReLU(inplace=True),
205
- nn.AvgPool2d((3, 3)),
206
- )
207
-
208
- self.i_p_fusion = nn.Sequential(
209
- Block(128, 4),
210
- Block(128, 4),
211
- Block(128, 4),
212
- )
213
- self.mlp = nn.Sequential(
214
- nn.Linear(128, 64),
215
- nn.GELU(),
216
- nn.Linear(64, 128),
217
- )
218
-
219
- self.prompt_fusion = nn.Sequential(
220
- Block(128, 4),
221
- Block(128, 4),
222
- Block(128, 4),
223
- )
224
-
225
- dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule
226
- self.blocks = nn.Sequential(*[
227
- Block(
228
- dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
229
- attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
230
- for i in range(8)])
231
- self.norm = nn.LayerNorm(128)
232
-
233
- self.score_block = nn.Sequential(
234
- nn.Linear(128, 128 // 2),
235
- nn.ReLU(),
236
- nn.Dropout(drop),
237
- nn.Linear(128 // 2, 1),
238
- nn.Sigmoid()
239
- )
240
-
241
- self.prompt_feature = {}
242
-
243
- @torch.no_grad()
244
- def clear(self):
245
- self.prompt_feature = {}
246
-
247
- @torch.no_grad()
248
- def inference(self, x, data_type):
249
- prompt_feature = self.prompt_feature[data_type] # 1, n, 128
250
-
251
- _x = self.vit(x)
252
- x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
253
- self.save_output.outputs.clear()
254
-
255
- x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
256
- x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
257
- x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
258
-
259
- # Different Opinion Features (DOF)
260
- DOF = torch.tensor([]).cuda()
261
- for index, _ in enumerate(self.MALs):
262
- DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
263
- DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
264
-
265
- # Image Quality Score Regression
266
- fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768
267
- IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
268
- IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
269
- img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
270
-
271
- prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128
272
- prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128
273
-
274
- fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1)) # bs, 1 + n, 128
275
- fusion = self.norm(fusion)
276
- fusion = self.score_block(fusion)
277
-
278
- # iq_res = torch.mean(fusion, dim=1).view(-1)
279
- iq_res = fusion[:, 0].view(-1)
280
-
281
- return iq_res
282
-
283
- @torch.no_grad()
284
- def check_prompt(self, data_type):
285
- return data_type in self.prompt_feature
286
-
287
- @torch.no_grad()
288
- def forward_prompt(self, x, score, data_type):
289
- if data_type in self.prompt_feature:
290
- return
291
- _x = self.vit(x)
292
- x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
293
- self.save_output.outputs.clear()
294
-
295
- x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
296
- x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
297
- x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
298
-
299
- # Different Opinion Features (DOF)
300
- DOF = torch.tensor([]).cuda()
301
- for index, _ in enumerate(self.MALs):
302
- DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
303
- DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
304
-
305
- # Image Quality Score Regression
306
- fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768
307
- IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
308
- IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
309
- img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
310
-
311
- # Linearly transform the score into a 128-dim feature
312
- # score_feature = self.score_projection(score) # bs, 128
313
- score_feature = score.expand(-1, 128)
314
-
315
- # Fuse img_feature and score_feature to obtain funsion_feature
316
- funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
317
- funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128
318
-
319
- # print('Load Prompt For Testing.', funsion_feature.shape)
320
- # self.prompt_feature = funsion_feature.clone()
321
- self.prompt_feature[data_type] = funsion_feature.clone()
322
-
323
- def forward(self, x, score):
324
- _x = self.vit(x)
325
- x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
326
- self.save_output.outputs.clear()
327
-
328
- x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
329
- x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
330
- x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
331
-
332
- # Different Opinion Features (DOF)
333
- DOF = torch.tensor([]).cuda()
334
- for index, _ in enumerate(self.MALs):
335
- DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
336
- DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
337
-
338
- # Image Quality Score Regression
339
- fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768
340
- IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
341
- IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
342
- img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
343
-
344
- # Linearly transform the score into a 128-dim feature
345
- # score_feature = self.score_projection(score) # bs, 128
346
- score_feature = score.expand(-1, 128) # bs, 128
347
-
348
- # Fuse img_feature and score_feature to obtain funsion_feature
349
- # funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1)) # bs, 2, 128
350
- funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
351
- funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) # bs, 128
352
- funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128
353
- funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128
354
-
355
- fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, bs, 128
356
- fusion = self.norm(fusion)
357
- fusion = self.score_block(fusion)
358
- # iq_res = torch.mean(fusion, dim=1).view(-1)
359
- iq_res = fusion[:, 0].view(-1)
360
-
361
- # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1
362
- # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
363
-
364
- gt_res = score.view(-1)
365
- # diff_gt_res = 1 - score.view(-1)
366
-
367
- return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
368
-
369
- def extract_feature(self, save_output, block_index=None):
370
- block_index = block_index or [2, 5, 8, 11]
371
- x1 = save_output.outputs[block_index[0]][:, 1:]
372
- x2 = save_output.outputs[block_index[1]][:, 1:]
373
- x3 = save_output.outputs[block_index[2]][:, 1:]
374
- x4 = save_output.outputs[block_index[3]][:, 1:]
375
- x = torch.cat((x1, x2, x3, x4), dim=2)
376
- return x
377
-
378
- def expand(self, A):
379
- A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
380
-
381
- B = None
382
- for index, i in enumerate(A_expanded):
383
- rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
384
- if B is None:
385
- B = rmv
386
- else:
387
- B = torch.cat((B, rmv), dim=0)
388
-
389
- return B
390
-
391
-
392
- if __name__ == '__main__':
393
- in_feature = torch.zeros((11, 3, 224, 224)).cuda()
394
- gt_feature = torch.tensor(
395
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda()
396
- gt_feature = gt_feature.reshape(-1, 1)
397
- model = MoNet().cuda()
398
-
399
- (iq_res, _), (_, _) = model(in_feature, gt_feature)
400
-
401
- print(iq_res.shape)
402
- # print(gt_res.shape)
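
Note on the expand helper in the deleted file above: for each sample in a batch it builds the prompt set from all other samples (leave-one-out), yielding a (bs, bs - 1, d) tensor. A vectorized equivalent, offered only as a sketch:

import torch

def expand_loo(A: torch.Tensor) -> torch.Tensor:
    """A is (bs, d); row i of the result holds every row of A except row i."""
    bs = A.size(0)
    keep = ~torch.eye(bs, dtype=torch.bool, device=A.device)  # False on the diagonal
    return A.unsqueeze(0).expand(bs, -1, -1)[keep].view(bs, bs - 1, -1)

A = torch.arange(12, dtype=torch.float).view(4, 3)
assert expand_loo(A).shape == (4, 3, 3)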
 
PromptIQA/models/vit_large.py DELETED
@@ -1,405 +0,0 @@
1
- """
2
- The completion for Mean-opinion Network(MoNet)
3
- """
4
- import torch
5
- import torch.nn as nn
6
- import timm
7
-
8
- from timm.models.vision_transformer import Block
9
- from einops import rearrange
10
- from itertools import combinations
11
-
12
- from tqdm import tqdm
13
-
14
-
15
- class Attention_Block(nn.Module):
16
- def __init__(self, dim, drop=0.1):
17
- super().__init__()
18
- self.c_q = nn.Linear(dim, dim)
19
- self.c_k = nn.Linear(dim, dim)
20
- self.c_v = nn.Linear(dim, dim)
21
- self.norm_fact = dim ** -0.5
22
- self.softmax = nn.Softmax(dim=-1)
23
- self.proj_drop = nn.Dropout(drop)
24
-
25
- def forward(self, x):
26
- _x = x
27
- B, C, N = x.shape
28
- q = self.c_q(x)
29
- k = self.c_k(x)
30
- v = self.c_v(x)
31
-
32
- attn = q @ k.transpose(-2, -1) * self.norm_fact
33
- attn = self.softmax(attn)
34
- x = (attn @ v).transpose(1, 2).reshape(B, C, N)
35
- x = self.proj_drop(x)
36
- x = x + _x
37
- return x
38
-
39
-
40
- class Self_Attention(nn.Module):
41
- """ Self attention Layer"""
42
-
43
- def __init__(self, in_dim):
44
- super(Self_Attention, self).__init__()
45
-
46
- self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
47
- self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
48
- self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
49
- self.gamma = nn.Parameter(torch.zeros(1))
50
-
51
- self.softmax = nn.Softmax(dim=-1)
52
-
53
- def forward(self, inFeature):
54
- bs, C, w, h = inFeature.size()
55
-
56
- proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
57
- proj_key = self.kConv(inFeature).view(bs, -1, w * h)
58
- energy = torch.bmm(proj_query, proj_key)
59
- attention = self.softmax(energy)
60
- proj_value = self.vConv(inFeature).view(bs, -1, w * h)
61
-
62
- out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
63
- out = out.view(bs, C, w, h)
64
-
65
- out = self.gamma * out + inFeature
66
-
67
- return out
68
-
69
-
70
- class three_cnn(nn.Module):
71
- def __init__(self, in_dim) -> None:
72
- super().__init__()
73
-
74
- self.three_cnn = nn.Sequential(
75
- nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1),
76
- nn.ReLU(inplace=True),
77
- nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1),
78
- nn.ReLU(inplace=True),
79
- nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1),
80
- nn.ReLU(inplace=True),
81
- )
82
-
83
- def forward(self, input):
84
- return self.three_cnn(input)
85
-
86
-
87
- class MAL(nn.Module):
88
- def __init__(self, in_dim=768, feature_num=4, feature_size=28):
89
- super().__init__()
90
- self.attention_module = nn.ModuleList()
91
- for i in range(feature_num):
92
- self.attention_module.append(three_cnn(in_dim))
93
-
94
- self.feature_num = feature_num
95
- self.in_dim = in_dim
96
- self.feature_size = feature_size
97
-
98
- def forward(self, features):
99
- feature = torch.tensor([]).cuda()
100
- for index, _ in enumerate(features):
101
- feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1)
102
- feature = torch.mean(feature, dim=1)
103
- features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size)
104
-
105
- return features # bs, 768, 28 * 28
106
-
107
-
108
- class SaveOutput:
109
- def __init__(self):
110
- self.outputs = []
111
-
112
- def __call__(self, module, module_in, module_out):
113
- self.outputs.append(module_out)
114
-
115
- def clear(self):
116
- self.outputs = []
117
-
118
-
119
- # utils
120
- @torch.no_grad()
121
- def concat_all_gather(tensor):
122
- """
123
- Performs all_gather operation on the provided tensors.
124
- *** Warning ***: torch.distributed.all_gather has no gradient.
125
- """
126
- tensors_gather = [
127
- torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
128
- ]
129
- torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
130
-
131
- output = torch.cat(tensors_gather, dim=0)
132
- return output
133
-
134
-
135
- class Attention(nn.Module):
136
- def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
137
- super().__init__()
138
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
139
- self.num_heads = num_heads
140
- head_dim = dim // num_heads
141
- self.scale = head_dim ** -0.5
142
-
143
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
144
- self.attn_drop = nn.Dropout(attn_drop)
145
- self.proj = nn.Linear(dim, dim)
146
- self.proj_drop = nn.Dropout(proj_drop)
147
-
148
- def forward(self, x):
149
- B, N, C = x.shape
150
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
151
- q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
152
-
153
- attn = (q @ k.transpose(-2, -1)) * self.scale
154
- attn = attn.softmax(dim=-1)
155
- attn = self.attn_drop(attn)
156
-
157
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
158
- x = self.proj(x)
159
- x = self.proj_drop(x)
160
- return x
161
-
162
-
163
- from functools import partial
164
-
165
-
166
- class MoNet(nn.Module):
167
- def __init__(self, patch_size=32, drop=0.1, dim_mlp=1024, img_size=384):
168
- super().__init__()
169
- self.img_size = img_size
170
- self.input_size = img_size // patch_size
171
- self.dim_mlp = dim_mlp
172
-
173
- self.vit = timm.create_model('vit_large_patch32_384', pretrained=True)
174
- self.vit.norm = nn.Identity()
175
- self.vit.head = nn.Identity()
176
- self.vit.head_drop = nn.Identity()
177
-
178
- self.save_output = SaveOutput()
179
-
180
- # Register Hooks
181
- hook_handles = []
182
- for layer in self.vit.modules():
183
- if isinstance(layer, Block):
184
- handle = layer.register_forward_hook(self.save_output)
185
- hook_handles.append(handle)
186
-
187
- self.MALs = nn.ModuleList()
188
- for _ in range(3):
189
- self.MALs.append(MAL(in_dim=dim_mlp, feature_size=self.input_size))
190
-
191
- # Image Quality Score Regression
192
- self.fusion_mal = MAL(in_dim=dim_mlp, feature_num=3, feature_size=self.input_size)
193
- self.block = Block(dim_mlp, 16)
194
- self.cnn = nn.Sequential(
195
- nn.Conv2d(dim_mlp, 512, 5),
196
- nn.BatchNorm2d(512),
197
- nn.ReLU(inplace=True),
198
- nn.AvgPool2d((2, 2)), # 4
199
- nn.Conv2d(512, 256, 3, 1), # 2
200
- nn.BatchNorm2d(256),
201
- nn.ReLU(inplace=True),
202
- nn.Conv2d(256, 256, 1),
203
- nn.BatchNorm2d(256),
204
- nn.ReLU(inplace=True),
205
- nn.AvgPool2d((2, 2)),
206
- )
207
-
208
- self.i_p_fusion = nn.Sequential(
209
- Block(256, 8),
210
- Block(256, 8),
211
- Block(256, 8),
212
- )
213
- self.mlp = nn.Sequential(
214
- nn.Linear(256, 128),
215
- nn.GELU(),
216
- nn.Linear(128, 256),
217
- )
218
-
219
- self.prompt_fusion = nn.Sequential(
220
- Block(256, 8),
221
- Block(256, 8),
222
- Block(256, 8),
223
- )
224
-
225
- dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule
226
- self.blocks = nn.Sequential(*[
227
- Block(dim=256, num_heads=8, mlp_ratio=4, qkv_bias=True, attn_drop=0, drop_path=dpr[i])
228
- for i in range(8)])
229
- self.norm = nn.LayerNorm(256)
230
-
231
- self.score_block = nn.Sequential(
232
- nn.Linear(256, 256 // 2),
233
- nn.ReLU(),
234
- nn.Dropout(drop),
235
- nn.Linear(256 // 2, 1),
236
- nn.Sigmoid()
237
- )
238
- self.prompt_feature = {}
239
-
240
- @torch.no_grad()
241
- def clear(self):
242
- self.prompt_feature = {}
243
-
244
- @torch.no_grad()
245
- def inference(self, x, data_type):
246
- prompt_feature = self.prompt_feature[data_type] # 1, n, 256
247
-
248
- _x = self.vit(x)
249
- x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
250
- self.save_output.outputs.clear()
251
-
252
- x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
253
- x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
254
- h=self.input_size) # bs, 4, 768, 28, 28
255
- x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 1024, 12, 12
256
-
257
- # Different Opinion Features (DOF)
258
- DOF = torch.tensor([]).cuda()
259
- for index, _ in enumerate(self.MALs):
260
- DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
261
- DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
262
-
263
- # Image Quality Score Regression
264
- fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 12 * 12, 1024
265
- IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
266
- IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
267
- h=self.input_size) # bs, 768, 28, 28
268
- img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 256
269
-
270
- prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 256
271
- prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 256
272
-
273
- fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1)) # bs, 1 + n, 256
274
- fusion = self.norm(fusion)
275
- fusion = self.score_block(fusion)
276
-
277
- # iq_res = torch.mean(fusion, dim=1).view(-1)
278
- iq_res = fusion[:, 0].view(-1)
279
-
280
- return iq_res
281
-
282
- @torch.no_grad()
283
- def check_prompt(self, data_type):
284
- return data_type in self.prompt_feature
285
-
286
- @torch.no_grad()
287
- def forward_prompt(self, x, score, data_type):
288
- if data_type in self.prompt_feature:
289
- return
290
- _x = self.vit(x)
291
- x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
292
- self.save_output.outputs.clear()
293
-
294
- x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
295
- x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
296
- h=self.input_size) # bs, 4, 768, 28, 28
297
- x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
298
-
299
- # Different Opinion Features (DOF)
300
- DOF = torch.tensor([]).cuda()
301
- for index, _ in enumerate(self.MALs):
302
- DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
303
- DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
304
-
305
- # Image Quality Score Regression
306
- fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 12 * 12, 1024
307
- IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
308
- IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
309
- h=self.input_size) # bs, 768, 28, 28
310
- img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 256
311
-
312
- # Linearly transform the score into a 256-dim feature
313
- # score_feature = self.score_projection(score) # bs, 128
314
- score_feature = score.expand(-1, 256)
315
-
316
- # Fuse img_feature and score_feature to obtain funsion_feature
317
- funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 256
318
- funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 256
319
-
320
- # print('Load Prompt For Testing.', funsion_feature.shape)
321
- # self.prompt_feature = funsion_feature.clone()
322
- self.prompt_feature[data_type] = funsion_feature.clone()
323
-
324
- def forward(self, x, score):
325
- _x = self.vit(x)
326
- x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
327
- self.save_output.outputs.clear()
328
-
329
- x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
330
- x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
331
- h=self.input_size) # bs, 4, 768, 28, 28
332
- x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
333
-
334
- # Different Opinion Features (DOF)
335
- DOF = torch.tensor([]).cuda()
336
- for index, _ in enumerate(self.MALs):
337
- DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
338
- DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
339
- # Image Quality Score Regression
340
- fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 12 * 12, 1024
341
- IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
342
- IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
343
- h=self.input_size) # bs, 768, 28, 28
344
- img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 256
345
-
346
- # Linearly transform the score into a 256-dim feature
347
- # score_feature = self.score_projection(score) # bs, 256
348
- score_feature = score.expand(-1, 256) # bs, 256
349
-
350
- # Fuse img_feature and score_feature to obtain funsion_feature
351
- # funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1)) # bs, 2, 256
352
- funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 256
353
- funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) # bs, 256
354
- funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 256
355
- funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 256
356
-
357
- fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, bs, 256
358
- fusion = self.norm(fusion)
359
- fusion = self.score_block(fusion)
360
- # iq_res = torch.mean(fusion, dim=1).view(-1)
361
- iq_res = fusion[:, 0].view(-1)
362
-
363
- # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1
364
- # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
365
-
366
- gt_res = score.view(-1)
367
- # diff_gt_res = 1 - score.view(-1)
368
-
369
- return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
370
-
371
- def extract_feature(self, save_output, block_index=None):
372
- if block_index is None:
373
- block_index = [5, 11, 17, 23]
374
- x1 = save_output.outputs[block_index[0]][:, 1:]
375
- x2 = save_output.outputs[block_index[1]][:, 1:]
376
- x3 = save_output.outputs[block_index[2]][:, 1:]
377
- x4 = save_output.outputs[block_index[3]][:, 1:]
378
- x = torch.cat((x1, x2, x3, x4), dim=2)
379
- return x
380
-
381
- def expand(self, A):
382
- A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
383
-
384
- B = None
385
- for index, i in enumerate(A_expanded):
386
- rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
387
- if B is None:
388
- B = rmv
389
- else:
390
- B = torch.cat((B, rmv), dim=0)
391
-
392
- return B
393
-
394
-
395
- if __name__ == '__main__':
396
- in_feature = torch.zeros((11, 3, 384, 384)).cuda()
397
- gt_feature = torch.tensor(
398
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda()
399
- gt_feature = gt_feature.reshape(-1, 1)
400
- model = MoNet().cuda()
401
-
402
- (iq_res, _), (_, _) = model(in_feature, gt_feature)
403
-
404
- print(iq_res.shape)
405
- # print(gt_res.shape)
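
A quick shape check on the large variant's feature head (a sketch, not repo code): with patch_size=32 and img_size=384 the token grid is 12 x 12, and the convolutional head maps a (bs, 1024, 12, 12) feature to a 256-dim vector, so shape comments for this file should read 12 x 12, 1024 and 256 rather than the 28 x 28, 768 and 128 of the base model:

import torch
import torch.nn as nn

# The cnn head from vit_large.py's MoNet, reproduced for a dimension check.
cnn = nn.Sequential(
    nn.Conv2d(1024, 512, 5), nn.BatchNorm2d(512), nn.ReLU(inplace=True),
    nn.AvgPool2d((2, 2)),
    nn.Conv2d(512, 256, 3, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
    nn.AvgPool2d((2, 2)),
)
x = torch.zeros(2, 1024, 12, 12)  # bs, dim_mlp, 384 // 32, 384 // 32
print(cnn(x).shape)               # torch.Size([2, 256, 1, 1]) -> bs, 1, 256 after squeeze/unsqueeze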
 
PromptIQA/run_promptIQA copy.py DELETED
@@ -1,109 +0,0 @@
1
- import os
2
- import random
3
- import torchvision
4
- import cv2
5
- import torch
6
- from models import monet as MoNet
7
- import numpy as np
8
- from utils.dataset.process import ToTensor, Normalize
9
- from utils.toolkit import *
10
- import warnings
11
- warnings.filterwarnings('ignore')
12
-
13
- import sys
14
- sys.path.append(os.path.dirname(__file__))
15
-
16
- class PromptIQA():
17
- def __init__(self) -> None:
18
- pass
19
-
20
- def load_image(img_path, size=224):
21
- try:
22
- d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
23
- d_img = cv2.resize(d_img, (size, size), interpolation=cv2.INTER_CUBIC)
24
- d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
25
- d_img = np.array(d_img).astype('float32') / 255
26
- d_img = np.transpose(d_img, (2, 0, 1))
27
- except:
28
- print(img_path)
29
-
30
- return d_img
31
-
32
- def load_model(pkl_path):
33
-
34
- model = MoNet.MoNet()
35
- dict_pkl = {}
36
- # prompt_num = torch.load(pkl_path, map_location='cpu').get('prompt_num')
37
- for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items():
38
- dict_pkl[key[7:]] = value
39
- model.load_state_dict(dict_pkl)
40
- print('Load Model From ', pkl_path)
41
-
42
- return model
43
-
44
- def get_an_img_score(img_path, target):
45
- transform=torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
46
- values_to_insert = np.array([0.0, 1.0])
47
- position_to_insert = 0
48
- target = np.insert(target, position_to_insert, values_to_insert)
49
-
50
- sample = load_image(img_path)
51
- samples = {'img': sample, 'gt': target}
52
- samples = transform(samples)
53
-
54
- return samples
55
-
56
- import random
57
- if __name__ == '__main__':
58
- pkl_path = "./checkpoints/best_model_five_22.pth.tar"
59
- model = load_model(pkl_path).cuda()
60
- model.eval()
61
-
62
- img_path = '/mnt/storage/PromptIQA_Demo/CSIQ/dst_src'
63
-
64
- img_tensor, gt_tensor = None, None
65
- img_list = os.listdir(img_path)
66
- random.shuffle(img_list)
67
- for idx, img_name in enumerate(img_list):
68
- if idx == 10:
69
- break
70
-
71
- img_name = os.path.join(img_path, img_name)
72
- score = np.array(idx / 10)
73
- samples = get_an_img_score(img_name, score)
74
-
75
- if img_tensor is None:
76
- img_tensor = samples['img'].unsqueeze(0)
77
- gt_tensor = samples['gt'].type(torch.FloatTensor).unsqueeze(0)
78
- else:
79
- img_tensor = torch.cat((img_tensor, samples['img'].unsqueeze(0)), dim=0)
80
- gt_tensor = torch.cat((gt_tensor, samples['gt'].type(torch.FloatTensor).unsqueeze(0)), dim=0)
81
-
82
- print(img_tensor.shape)
83
- print(gt_tensor.shape)
84
- print(gt_tensor)
85
-
86
- img = img_tensor.squeeze(0).cuda()
87
- label = gt_tensor.squeeze(0).cuda()
88
- reverse = False
89
- if reverse == 2:
90
- label = torch.rand_like(label[:, -1]).cuda()
91
- print(label)
92
- elif reverse == 3:
93
- print('Total Random')
94
- label = torch.rand_like(label[:, -1]).cuda()
95
- img = torch.rand_like(img).cuda()
96
- else:
97
- label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda())
98
- print('input label: ', label)
99
- model.forward_prompt(img, label.reshape(-1, 1), 'livec')
100
-
101
- img_name = '/mnt/storage/PromptIQA_Demo/CSIQ/src_imgs/1600.png'
102
- score = np.array(random.random())
103
- samples = get_an_img_score(img_name, score)
104
-
105
- img = samples['img'].unsqueeze(0).cuda()
106
- print(img.shape)
107
- pred = model.inference(img, 'livec')
108
-
109
- print(pred)
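
Two details of the deleted demo above are worth noting. First, the protocol: forward_prompt encodes a handful of (image, score) pairs into a prompt stored under a data-type key ('livec' here), and inference then scores new images against that prompt. Second, get_an_img_score prepends the anchors [0.0, 1.0] to every ground-truth score, so a target s becomes the vector [0.0, 1.0, s] and downstream code recovers the real score with label[:, -1]. A tiny self-contained check of that np.insert idiom:

import numpy as np

target = np.array(0.37)                              # scalar ground-truth score
target = np.insert(target, 0, np.array([0.0, 1.0]))  # prepend the [0, 1] anchors
print(target)      # [0.   1.   0.37]
print(target[-1])  # 0.37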
 
PromptIQA/run_promptIQA.py CHANGED
@@ -3,7 +3,7 @@ import random
3
  import torchvision
4
  import cv2
5
  import torch
6
- from PromptIQA.models import monet as MoNet
7
  import numpy as np
8
  from PromptIQA.utils.dataset.process import ToTensor, Normalize
9
  from PromptIQA.utils.toolkit import *
@@ -14,7 +14,7 @@ import sys
14
  sys.path.append(os.path.dirname(__file__))
15
 
16
  def load_model(pkl_path):
17
- model = MoNet.MoNet()
18
  dict_pkl = {}
19
  for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items():
20
  dict_pkl[key[7:]] = value
 
3
  import torchvision
4
  import cv2
5
  import torch
6
+ from PromptIQA.models import promptiqa
7
  import numpy as np
8
  from PromptIQA.utils.dataset.process import ToTensor, Normalize
9
  from PromptIQA.utils.toolkit import *
 
14
  sys.path.append(os.path.dirname(__file__))
15
 
16
  def load_model(pkl_path):
17
+ model = promptiqa.PromptIQA()
18
  dict_pkl = {}
19
  for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items():
20
  dict_pkl[key[7:]] = value
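
The key[7:] slice in load_model assumes every checkpoint key starts with the 7-character 'module.' prefix that (Distributed)DataParallel adds. A more defensive variant (a sketch under that assumption, not code from this repo) strips the prefix only when present:

import torch

def load_clean_state_dict(model, pkl_path):
    """Load a checkpoint saved from a DataParallel-wrapped model into a bare
    model, removing any 'module.' key prefix (str.removeprefix needs Python 3.9+)."""
    state = torch.load(pkl_path, map_location='cpu')['state_dict']
    model.load_state_dict({k.removeprefix('module.'): v for k, v in state.items()})
    return model
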
PromptIQA/t.py DELETED
@@ -1,2 +0,0 @@
1
- a = "(1+1)**(2**2)"
2
- print(eval(a))
 
PromptIQA/test.py DELETED
@@ -1,429 +0,0 @@
1
- import sys
2
-
3
- from utils import log_writer
4
-
5
- import argparse
6
- import builtins
7
- import os
8
- import random
9
- import shutil
10
- import time
11
-
12
- import torch
13
- import torch.distributed as dist
14
- import torch.multiprocessing as mp
15
- import torch.nn as nn
16
- import torch.nn.parallel
17
- import torch.optim
18
- import torch.utils.data
19
- import torch.utils.data.distributed
20
- # from models import monet as MoNet
21
- from torch.utils.data import ConcatDataset
22
- from utils.dataset import data_loader
23
-
24
- from utils.toolkit import *
25
-
26
- loger_path = None
27
-
28
-
29
- def init(config):
30
- global loger_path
31
- if config.dist_url == "env://" and config.world_size == -1:
32
- config.world_size = int(os.environ["WORLD_SIZE"])
33
-
34
- config.distributed = config.world_size > 1 or config.multiprocessing_distributed
35
-
36
- print("config.distributed", config.distributed)
37
-
38
- loger_path = os.path.join(config.save_path, "inference_log")
39
- if not os.path.isdir(loger_path):
40
- os.makedirs(loger_path)
41
-
42
- print("----------------------------------")
43
- print(
44
- "Begin Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
45
- )
46
- printArgs(config, loger_path)
47
- # os.environ["CUDA_VISIBLE_DEVICES"] = '2,3,4,5,6,7'
48
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
49
- # os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5'
50
- # os.environ["CUDA_VISIBLE_DEVICES"] = '6,7'
51
- # os.environ["CUDA_VISIBLE_DEVICES"] = '6'
52
- # setup_seed(config.seed)
53
-
54
-
55
- def main(config):
56
- init(config)
57
- ngpus_per_node = torch.cuda.device_count()
58
- if config.multiprocessing_distributed:
59
- config.world_size = ngpus_per_node * config.world_size
60
-
61
- print(config.world_size, ngpus_per_node, ngpus_per_node)
62
- mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config))
63
- else:
64
- # Simply call main_worker function
65
- main_worker(config.gpu, ngpus_per_node, config)
66
-
67
- print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))
68
-
69
-
70
- @torch.no_grad()
71
- def gather_together(data): # gather data from every process and return the results as a list
72
- dist.barrier()
73
- world_size = dist.get_world_size()
74
- gather_data = [None for _ in range(world_size)]
75
- dist.all_gather_object(gather_data, data)
76
- return gather_data
77
-
78
- import importlib.util
79
- def main_worker(gpu, ngpus_per_node, args):
80
- models_path = os.path.join(args.save_path, "training_files", 'models', 'monet.py')
81
- spec = importlib.util.spec_from_file_location("monet_module", models_path)
82
- monet_module = importlib.util.module_from_spec(spec)
83
- spec.loader.exec_module(monet_module)
84
- MoNet = monet_module
85
-
86
- loger_path = os.path.join(args.save_path, "inference_log")
87
- if gpu == 0:
88
- sys.stdout = log_writer.Logger(os.path.join(loger_path, f"inference_log_{args.prompt_type}_{args.reverse}.log"))
89
- args.gpu = gpu
90
-
91
- # suppress printing if not master
92
- if args.multiprocessing_distributed and args.gpu != 0:
93
- def print_pass(*args):
94
- pass
95
-
96
- builtins.print = print_pass
97
-
98
- if args.gpu is not None:
99
- print("Use GPU: {} for testing".format(args.gpu))
100
-
101
- if args.distributed:
102
- if args.dist_url == "env://" and args.rank == -1:
103
- args.rank = int(os.environ["RANK"])
104
- if args.multiprocessing_distributed:
105
- args.rank = args.rank * ngpus_per_node + gpu
106
- dist.init_process_group(
107
- backend=args.dist_backend,
108
- init_method=args.dist_url,
109
- world_size=args.world_size,
110
- rank=args.rank,
111
- )
112
-
113
- # create model
114
- model = MoNet.MoNet()
115
- dict_pkl = {}
116
- prompt_num = torch.load(args.pkl_path, map_location='cpu').get('prompt_num')
117
- for key, value in torch.load(args.pkl_path, map_location='cpu')['state_dict'].items():
118
- dict_pkl[key[7:]] = value
119
- model.load_state_dict(dict_pkl)
120
- print('Load Model From ', args.pkl_path)
121
-
122
- if args.distributed:
123
- if args.gpu is not None:
124
- torch.cuda.set_device(args.gpu)
125
- model.cuda(args.gpu)
126
- args.batch_size = int(args.batch_size / ngpus_per_node)
127
- args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
128
- model = torch.nn.parallel.DistributedDataParallel(
129
- model, device_ids=[args.gpu]
130
- )
131
- print("Model Distribute.")
132
- else:
133
- model.cuda()
134
- model = torch.nn.parallel.DistributedDataParallel(model)
135
-
136
- if prompt_num is None:
137
- prompt_num = args.batch_size - 1
138
- prompt_num = 10
139
- print('prompt_num', prompt_num)
140
-
141
- test_prompt_list, test_data_list = {}, []
142
- # fix_prompt = None
143
- for dataset in args.dataset:
144
- print('---Load ', dataset)
145
- path, train_index, test_index = get_data(dataset=dataset, split_seed=args.seed)
146
- # if dataset == 'spaq' and False:
147
- if dataset == 'spaq':
148
- for column in range(2, 8):
149
- print('sapq column train', column)
150
- test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, test_index, istrain=False, column=column)
151
- test_data_list.append(test_dataset.get_samples())
152
-
153
- train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, column=column)
154
- test_prompt_list[dataset+f'_{column}'] = train_dataset.get_prompt(prompt_num, args.prompt_type)
155
- else:
156
- test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, test_index, istrain=False, types=args.types)
157
- test_data_list.append(test_dataset.get_samples())
158
-
159
- train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, types=args.types)
160
- test_prompt_list[dataset] = train_dataset.get_prompt(prompt_num, args.prompt_type)
161
- print('args.prompt_type', args.prompt_type)
162
-
163
- combined_test_samples = ConcatDataset(test_data_list)
164
- print("test_dataset", len(combined_test_samples))
165
- test_sampler = torch.utils.data.distributed.DistributedSampler(combined_test_samples)
166
-
167
- test_loader = torch.utils.data.DataLoader(
168
- combined_test_samples,
169
- batch_size=1,
170
- shuffle=(test_sampler is None),
171
- num_workers=args.workers,
172
- sampler=test_sampler,
173
- drop_last=False,
174
- pin_memory=True,
175
- )
176
-
177
- if args.distributed:
178
- test_sampler.set_epoch(0)
179
-
180
- for idxsa in range(1):
181
- test_srocc, test_plcc, pred_scores, gt_scores, path = test(
182
- test_loader, model, test_prompt_list, reverse=args.reverse
183
- )
184
- print('gt_scores', len(pred_scores), len(gt_scores))
185
- print('Summary---')
186
-
187
- gt_scores = gather_together(gt_scores) # gather from all processes into one list
188
- pred_scores = gather_together(pred_scores) # gather from all processes into one list
189
-
190
- gt_score_dict, pred_score_dict = {}, {}
191
- for sublist in gt_scores:
192
- for k, v in sublist.items():
193
- if k not in gt_score_dict:
194
- gt_score_dict[k] = v
195
- else:
196
- gt_score_dict[k] = gt_score_dict[k] + v
197
-
198
- for sublist in pred_scores:
199
- for k, v in sublist.items():
200
- if k not in pred_score_dict:
201
- pred_score_dict[k] = v
202
- else:
203
- pred_score_dict[k] = pred_score_dict[k] + v
204
-
205
- gt_score_dict = dict(sorted(gt_score_dict.items()))
206
- test_srocc, test_plcc = 0, 0
207
- for k, v in gt_score_dict.items():
208
- test_srocc_, test_plcc_ = cal_srocc_plcc(gt_score_dict[k], pred_score_dict[k])
209
- print('\t{} Test SROCC: {}, PLCC: {}'.format(k, round(test_srocc_, 4), round(test_plcc_, 4)))
210
- # print('Pred: ', pred_score_dict[k][:10])
211
- # print('GT: ', gt_score_dict[k][:10])
212
- # print('-----')
213
-
214
- with open('{}_{}.csv'.format(idxsa, k), 'w') as f:
215
- for i, j in zip(gt_score_dict[k], pred_score_dict[k]):
216
- f.write('{},{}\n'.format(i, j))
217
- test_srocc += test_srocc_
218
- test_plcc += test_plcc_
219
-
220
-
221
- def test(test_loader, MoNet, promt_data_loader, reverse=False):
222
- """Training"""
223
- pred_scores = {}
224
- gt_scores = {}
225
- path = []
226
-
227
- batch_time = AverageMeter("Time", ":6.3f")
228
- srocc = AverageMeter("SROCC", ":6.2f")
229
- plcc = AverageMeter("PLCC", ":6.2f")
230
- progress = ProgressMeter(
231
- len(test_loader),
232
- [batch_time, srocc, plcc],
233
- prefix="Testing ",
234
- )
235
-
236
- print('reverse ----', reverse)
237
- MoNet.train(False)
238
- with torch.no_grad():
239
- for index, (img_or, label_or, paths, dataset_type) in enumerate(test_loader):
240
- # print(dataset_type)
241
- t = time.time()
242
- dataset_type = dataset_type[0]
243
-
244
- has_prompt = False
245
- if hasattr(MoNet.module, 'check_prompt'):
246
- has_prompt = MoNet.module.check_prompt(dataset_type)
247
-
248
- if not has_prompt:
249
- print('Load Prompt For ', dataset_type)
250
- prompt_dataset = promt_data_loader[dataset_type]
251
- for img, label in prompt_dataset:
252
- img = img.squeeze(0).cuda()
253
- label = label.squeeze(0).cuda()
254
- if reverse == 2:
255
- # label = torch.tensor([random.random() for i in range(len(label[:, -1]))]).cuda()
256
- #
257
- label = torch.rand_like(label[:, -1]).cuda()
258
- print(label)
259
- elif reverse == 3:
260
- print('Total Random')
261
- label = torch.rand_like(label[:, -1]).cuda()
262
- img = torch.rand_like(img).cuda()
263
- else:
264
- label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda())
265
- MoNet.module.forward_prompt(img, label.reshape(-1, 1), dataset_type)
266
-
267
- img = img_or.squeeze(0).cuda()
268
- label = label_or.squeeze(0).cuda()[:, 2]
269
-
270
- # print(img.shape)
271
-
272
- pred = MoNet.module.inference(img, dataset_type)
273
-
274
- if dataset_type not in pred_scores:
275
- pred_scores[dataset_type] = []
276
-
277
- if dataset_type not in gt_scores:
278
- gt_scores[dataset_type] = []
279
-
280
- pred_scores[dataset_type] = pred_scores[dataset_type] + pred.cpu().tolist()
281
- gt_scores[dataset_type] = gt_scores[dataset_type] + label.cpu().tolist()
282
- path = path + list(paths)
283
-
284
- batch_time.update(time.time() - t)
285
-
286
- if index % 100 == 0:
287
- for k, v in pred_scores.items():
288
- test_srocc, test_plcc = cal_srocc_plcc(pred_scores[k], gt_scores[k])
289
- # print('\t{}, SROCC: {}, PLCC: {}'.format(k, round(test_srocc, 4), round(test_plcc, 4)))
290
- srocc.update(test_srocc)
291
- plcc.update(test_plcc)
292
-
293
- progress.display(index)
294
-
295
- MoNet.module.clear()
296
- # MoNet.train(True)
297
- return 'test_srocc', 'test_plcc', pred_scores, gt_scores, path
298
-
299
- if __name__ == "__main__":
300
- parser = argparse.ArgumentParser()
301
- parser.add_argument(
302
- "--seed",
303
- dest="seed",
304
- type=int,
305
- default=570908,
306
- help="Random seeds for result reproduction.",
307
- )
308
-
309
- parser.add_argument(
310
- "--mal_num",
311
- dest="mal_num",
312
- type=int,
313
- default=2,
314
- help="The number of the MAL modules.",
315
- )
316
-
317
- # data related
318
- parser.add_argument(
319
- "--dataset",
320
- dest="dataset",
321
- nargs='+', default=None,
322
- help="Support datasets: livec|koniq10k|bid|spaq",
323
- )
324
-
325
- # training related
326
- parser.add_argument(
327
- "--queue_ratio",
328
- dest="queue_ratio",
329
- type=float,
330
- default=0.6,
331
- help="Ratio of queue length used in GC loss to training set length.",
332
- )
333
-
334
- parser.add_argument(
335
- "--loss",
336
- dest="loss",
337
- type=str,
338
- default="MSE",
339
- help="Loss function to use. Support losses: GC|MAE|MSE.",
340
- )
341
-
342
- parser.add_argument(
343
- "--lr", dest="lr", type=float, default=1e-5, help="Learning rate"
344
- )
345
-
346
- parser.add_argument(
347
- "--weight_decay",
348
- dest="weight_decay",
349
- type=float,
350
- default=1e-5,
351
- help="Weight decay",
352
- )
353
- parser.add_argument(
354
- "--batch_size", dest="batch_size", type=int, default=11, help="Batch size"
355
- )
356
- parser.add_argument(
357
- "--epochs", dest="epochs", type=int, default=50, help="Epochs for training"
358
- )
359
- parser.add_argument(
360
- "--T_max",
361
- dest="T_max",
362
- type=int,
363
- default=50,
364
- help="Hyper-parameter for CosineAnnealingLR",
365
- )
366
- parser.add_argument(
367
- "--eta_min",
368
- dest="eta_min",
369
- type=int,
370
- default=0,
371
- help="Hyper-parameter for CosineAnnealingLR",
372
- )
373
-
374
- parser.add_argument(
375
- "-j",
376
- "--workers",
377
- default=32,
378
- type=int,
379
- metavar="N",
380
- help="number of data loading workers (default: 32)",
381
- )
382
-
383
- # result related
384
- parser.add_argument(
385
- "--save_path",
386
- dest="save_path",
387
- type=str,
388
- default="./save_logs/Matrix_Comparation_Koniq_bs_25",
389
- help="The path where the model and logs will be saved.",
390
- )
391
-
392
- parser.add_argument(
393
- "--world-size",
394
- default=-1,
395
- type=int,
396
- help="number of nodes for distributed training",
397
- )
398
- parser.add_argument(
399
- "--rank", default=-1, type=int, help="node rank for distributed training"
400
- )
401
- parser.add_argument(
402
- "--dist-url",
403
- default="tcp://224.66.41.62:23456",
404
- type=str,
405
- help="url used to set up distributed training",
406
- )
407
- parser.add_argument(
408
- "--dist-backend", default="nccl", type=str, help="distributed backend"
409
- )
410
- parser.add_argument(
411
- "--multiprocessing-distributed",
412
- action="store_true",
413
- help="Use multi-processing distributed training to launch "
414
- "N processes per node, which has N GPUs. This is the "
415
- "fastest way to use PyTorch for either single node or "
416
- "multi node data parallel training",
417
- )
418
-
419
- parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
420
- parser.add_argument("--pkl_path", required=True, type=str)
421
- parser.add_argument("--prompt_type", required=True, type=str)
422
- parser.add_argument("--reverse", required=True, type=int)
423
- parser.add_argument("--types", default='SSIM', type=str)
424
-
425
- config = parser.parse_args()
426
-
427
- config.save_path = os.path.dirname(config.pkl_path)
428
-
429
- main(config)
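
cal_srocc_plcc is imported from utils.toolkit and is not part of this diff; an implementation consistent with how it is called here (an assumption, not the repo's code) is:

from scipy import stats

def cal_srocc_plcc(pred, gt):
    """Spearman rank-order and Pearson linear correlation between
    predicted and ground-truth quality scores."""
    srocc = stats.spearmanr(pred, gt)[0]
    plcc = stats.pearsonr(pred, gt)[0]
    return srocc, plcc

print(cal_srocc_plcc([0.1, 0.4, 0.8], [0.2, 0.5, 0.7]))  # (1.0, 0.98...)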
 
PromptIQA/test.sh DELETED
@@ -1,9 +0,0 @@
1
- # python test.py --dist-url 'tcp://localhost:10055' --dataset spaq tid2013 livec bid spaq flive --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 0 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar
2
- # python test.py --dist-url 'tcp://localhost:12755' --dataset csiq --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 3 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Training_log/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar
3
- python test.py --dist-url 'tcp://localhost:12755' --dataset livec bid csiq --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar
4
- # reverse: 0 = labels as-is, 1 = reversed (1 - label), 2 = random labels, 3 = random labels and images
5
-
6
- python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar --types 'SSIM'
7
-
8
-
9
- CUDA_VISIBLE_DEVICES="0" python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Publication/PromptIQA_2024_WO_Norm_Score/best_model_five_22.pth.tar --types 'SSIM'
 
best_model.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:993555b9efaeae660d2dd6f4056f13c6957628ca592a2ce74ff2e8eb5a4a2280
3
+ size 1272842308
get_examplt.py DELETED
@@ -1,27 +0,0 @@
1
- import os
2
- from copy import deepcopy
3
-
4
-
5
- isp_json = []
6
- path = './Examples'
7
- for img_dir in sorted(os.listdir(path)):
8
- if os.path.isdir(os.path.join(path, img_dir)):
9
- ISPP = os.path.join(path, img_dir, 'ISPP')
10
-
11
- ispp = {}
12
- ispp['Example_id'] = img_dir
13
- ispp['ISPP'] = []
14
- img_list = []
15
- for idx, img in enumerate(sorted(os.listdir(ISPP))):
16
- ispp['ISPP'].append([os.path.join(ISPP, img), idx / 10 if '1' in img_dir else 1 - idx / 10])
17
-
18
- for file in os.listdir(os.path.join(path, img_dir)):
19
- if os.path.isfile(os.path.join(path, img_dir, file)):
20
- img_list.append(file)
21
- ispp['Image'] = [os.path.join(path, img_dir, file), 7]
22
- ispp['Remark'] = []
23
- isp_json.append(deepcopy(ispp))
24
-
25
- with open('example2.json', 'w') as f:
26
- import json
27
- json.dump(isp_json, f, indent=4)