AAOBA commited on
Commit
f7e3261
·
1 Parent(s): baa1359

Updated RM.md

Browse files
Files changed (1) hide show
  1. score_it.py +165 -0
score_it.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as transforms
5
+
6
+ import timm
7
+
8
+ from PIL import Image
9
+
10
+ import matplotlib.pyplot as plt
11
+
12
+ import os
13
+
14
+ # Thanks to ( ), proxy can be essentail :)
15
+ # os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:10809'
16
+ # os.environ['HTTP_PROXY'] = 'http://127.0.0.1:10809'
17
+ # os.environ['ALL_PROXY'] = 'socks5://127.0.0.1:10808'
18
+
19
+ IMG_FILE_LIST = [
20
+ './testcases/14.jpg',
21
+ './testcases/15.jpg',
22
+ './testcases/16.jpg',
23
+ './testcases/17.jpg',
24
+ './testcases/18.jpg',
25
+ './testcases/19.jpg'
26
+ ]
27
+
28
+ TANH_SCALE = 1
29
+
30
+
31
+ class Scorer(nn.Module):
32
+ def __init__(
33
+ self,
34
+ model_name,
35
+ pretrained=False,
36
+ features_only=True,
37
+ embedding_dim=128
38
+ ):
39
+ super(Scorer, self).__init__()
40
+ self.model = timm.create_model(model_name, pretrained=pretrained, features_only=features_only)
41
+ pooled_dim = 128 + 256 + 512 + 1024
42
+ self.layer_norms = nn.ModuleList([
43
+ nn.LayerNorm(128),
44
+ nn.LayerNorm(256),
45
+ nn.LayerNorm(512),
46
+ nn.LayerNorm(1024)
47
+ ])
48
+ self.mlp = nn.Sequential(
49
+ nn.Linear(pooled_dim, pooled_dim),
50
+ nn.BatchNorm1d(pooled_dim),
51
+ nn.GELU(),
52
+ )
53
+ # Probably a BYOL-accidental BatchNorm could help ?
54
+ self.mlp_1 = nn.Sequential(
55
+ nn.Linear(pooled_dim, pooled_dim // 4),
56
+ nn.BatchNorm1d(pooled_dim // 4),
57
+ nn.GELU(),
58
+ nn.Linear(pooled_dim // 4, 3),
59
+ nn.Tanh()
60
+ )
61
+ self.mlp_2 = nn.Sequential(
62
+ nn.Linear(pooled_dim, pooled_dim // 4),
63
+ nn.GELU(),
64
+ nn.Linear(pooled_dim // 4, 1),
65
+ )
66
+
67
+ def forward(self, x, upload_date=None, freeze_backbone=False):
68
+ if freeze_backbone:
69
+ with torch.no_grad():
70
+ out_features = self.model(x)
71
+ else:
72
+ out_features = self.model(x)
73
+ # out_features: List [
74
+ # torch.Size([1, 128, x, x])
75
+ # torch.Size([1, 256, x, x])
76
+ # torch.Size([1, 512, x, x])
77
+ # torch.Size([1, 1024, x, x])
78
+ # ]
79
+ # Pool the output features from each layer on the channel dimension
80
+ pooled_features = [F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1) for x in out_features]
81
+ # Normalize the pooled features
82
+ pooled_features = [self.layer_norms[i](x) for i, x in enumerate(pooled_features)]
83
+ # Embed the upload date
84
+ # date_embedding_features = self.embedding(upload_date)
85
+ # Concatenate the pooled features
86
+ out = torch.cat(pooled_features, dim=-1)
87
+ # Concatenate the date embedding features
88
+ # out = torch.cat([out, date_embedding_features], dim=-1)
89
+ out = self.mlp(out)
90
+ rl_out = self.mlp_1(out) * TANH_SCALE
91
+ ai_out = self.mlp_2(out).squeeze(-1)
92
+ return rl_out[:, 0], rl_out[:, 1], F.sigmoid(ai_out), rl_out[:, 2]
93
+
94
+
95
+ BACKBONE = 'convnextv2_base.fcmae'
96
+ RESOLUTION = 640
97
+ SHOW_GRAD = False
98
+ GRAD_SCALE = 50
99
+
100
+ MORE_LIKE = False
101
+ MORE_COLLECTION = False
102
+ LESS_AI = False
103
+ MORE_RELATIVE_POP = True
104
+
105
+ WEIGHT_PATH = './scorer.pt'
106
+
107
+ DECIVE = 'cuda'
108
+
109
+
110
+ def main():
111
+ model = Scorer(BACKBONE)
112
+ transform = transforms.Compose([
113
+ transforms.Resize((RESOLUTION, RESOLUTION)),
114
+ transforms.ToTensor(),
115
+ transforms.Normalize(
116
+ mean=[0.485, 0.456, 0.406],
117
+ std=[0.229, 0.224, 0.225]
118
+ )
119
+ ])
120
+ model.load_state_dict(torch.load(WEIGHT_PATH))
121
+ model.eval()
122
+ model.to(DECIVE)
123
+
124
+ # Show all the images in pyplot horizontally, and mark the predicted values under each image
125
+ fig = plt.figure(figsize=(20, 20))
126
+ for i, img_file in enumerate(IMG_FILE_LIST):
127
+ img = Image.open(img_file, 'r').convert('RGB')
128
+ transformed_img = transform(img).unsqueeze(0).to(DECIVE)
129
+ transformed_img.requires_grad = True
130
+ liking_pred, collection_pred, ai_pred, relative_pop = model(transformed_img, torch.tensor([1]), False)
131
+ ax = fig.add_subplot(1, len(IMG_FILE_LIST), i + 1)
132
+
133
+ backwardee = 0
134
+ if MORE_LIKE:
135
+ backwardee -= liking_pred
136
+ if MORE_COLLECTION:
137
+ backwardee -= collection_pred
138
+ if LESS_AI:
139
+ backwardee += ai_pred
140
+ if MORE_RELATIVE_POP:
141
+ backwardee -= relative_pop
142
+ if SHOW_GRAD:
143
+ model.zero_grad()
144
+ # Figure out which part of the image is the most important to popularity
145
+ backwardee.backward()
146
+ # Get the gradients of the image, and normalize them
147
+ gradients = transformed_img.grad
148
+ # squeeze the batch dimension
149
+ gradients = gradients.squeeze(0).detach()
150
+ # resize the gradients to the same size as the image
151
+ gradients = transforms.Resize((img.height, img.width))(gradients)
152
+ # add the gradients to the image
153
+ img = transforms.ToTensor()(img)
154
+ img = img + gradients.cpu() * GRAD_SCALE
155
+ img = transforms.ToPILImage()(img.cpu())
156
+ ax.imshow(img)
157
+ del img
158
+ ax.set_title(
159
+ f'Liking: {liking_pred.item():.3f}\nCollection: {collection_pred.item():.3f}\nAI: {ai_pred.item() * 100:.3f}%\nPopularity: {relative_pop.item():.3f}')
160
+ plt.show()
161
+ pass
162
+
163
+
164
+ if __name__ == '__main__':
165
+ main()