Updated RM.md
score_it.py (ADDED, +165 -0)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import timm

from PIL import Image

import matplotlib.pyplot as plt

import os

# Thanks to ( ), a proxy can be essential :)
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['ALL_PROXY'] = 'socks5://127.0.0.1:10808'

IMG_FILE_LIST = [
    './testcases/14.jpg',
    './testcases/15.jpg',
    './testcases/16.jpg',
    './testcases/17.jpg',
    './testcases/18.jpg',
    './testcases/19.jpg'
]

# Scale applied to the tanh-bounded regression heads
TANH_SCALE = 1


class Scorer(nn.Module):
    def __init__(
        self,
        model_name,
        pretrained=False,
        features_only=True,
        embedding_dim=128  # unused; likely intended for the upload-date embedding (see commented code below)
    ):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=features_only)
        # Channel widths of the four backbone feature stages
        pooled_dim = 128 + 256 + 512 + 1024
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(128),
            nn.LayerNorm(256),
            nn.LayerNorm(512),
            nn.LayerNorm(1024)
        ])
        self.mlp = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.BatchNorm1d(pooled_dim),
            nn.GELU(),
        )
        # Perhaps a BYOL-style "accidental" BatchNorm could help here?
        self.mlp_1 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.BatchNorm1d(pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 3),
            nn.Tanh()
        )
        self.mlp_2 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 1),
        )

    def forward(self, x, upload_date=None, freeze_backbone=False):
        if freeze_backbone:
            with torch.no_grad():
                out_features = self.model(x)
        else:
            out_features = self.model(x)
        # out_features is a list of four feature maps:
        #   torch.Size([N, 128, h, w])
        #   torch.Size([N, 256, h, w])
        #   torch.Size([N, 512, h, w])
        #   torch.Size([N, 1024, h, w])
        # Average-pool each stage's feature map over its spatial dimensions
        pooled_features = [F.adaptive_avg_pool2d(feat, 1).squeeze(-1).squeeze(-1) for feat in out_features]
        # Normalize the pooled features
        pooled_features = [self.layer_norms[i](feat) for i, feat in enumerate(pooled_features)]
        # Embed the upload date
        # date_embedding_features = self.embedding(upload_date)
        # Concatenate the pooled features
        out = torch.cat(pooled_features, dim=-1)
        # Concatenate the date embedding features
        # out = torch.cat([out, date_embedding_features], dim=-1)
        out = self.mlp(out)
        # Three tanh-bounded heads (liking, collection, relative popularity) and one AI-likelihood logit
        rl_out = self.mlp_1(out) * TANH_SCALE
        ai_out = self.mlp_2(out).squeeze(-1)
        return rl_out[:, 0], rl_out[:, 1], torch.sigmoid(ai_out), rl_out[:, 2]
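
# --- Optional sanity check (not part of the original file) ---
# A minimal sketch of how the Scorer can be exercised with random weights,
# assuming the backbone exposes four feature stages with 128/256/512/1024
# channels, as the LayerNorms above expect. The 224x224 input size and the
# helper name are illustrative assumptions, not values used by the script.
def _smoke_test(model_name='convnextv2_base.fcmae'):
    model = Scorer(model_name, pretrained=False)
    model.eval()  # keep BatchNorm1d on running statistics
    with torch.no_grad():
        dummy = torch.randn(2, 3, 224, 224)
        liking, collection, ai_prob, rel_pop = model(dummy)
    # Each head returns one value per image in the batch
    print(liking.shape, collection.shape, ai_prob.shape, rel_pop.shape)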


BACKBONE = 'convnextv2_base.fcmae'
RESOLUTION = 640
SHOW_GRAD = False
GRAD_SCALE = 50

# Objectives to favour in the gradient overlay (used when SHOW_GRAD is True)
MORE_LIKE = False
MORE_COLLECTION = False
LESS_AI = False
MORE_RELATIVE_POP = True

WEIGHT_PATH = './scorer.pt'

DEVICE = 'cuda'


def main():
    model = Scorer(BACKBONE)
    transform = transforms.Compose([
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    model.load_state_dict(torch.load(WEIGHT_PATH))
    model.eval()
    model.to(DEVICE)

    # Show all the images in pyplot horizontally, and mark the predicted values under each image
    fig = plt.figure(figsize=(20, 20))
    for i, img_file in enumerate(IMG_FILE_LIST):
        img = Image.open(img_file, 'r').convert('RGB')
        transformed_img = transform(img).unsqueeze(0).to(DEVICE)
        transformed_img.requires_grad = True
        liking_pred, collection_pred, ai_pred, relative_pop = model(transformed_img, torch.tensor([1]), False)
        ax = fig.add_subplot(1, len(IMG_FILE_LIST), i + 1)

        # Combine the selected objectives into a single scalar to backpropagate
        backwardee = 0
        if MORE_LIKE:
            backwardee -= liking_pred
        if MORE_COLLECTION:
            backwardee -= collection_pred
        if LESS_AI:
            backwardee += ai_pred
        if MORE_RELATIVE_POP:
            backwardee -= relative_pop
        if SHOW_GRAD:
            model.zero_grad()
            # Figure out which part of the image is the most important to popularity
            backwardee.backward()
            # Get the gradients of the image
            gradients = transformed_img.grad
            # Squeeze the batch dimension
            gradients = gradients.squeeze(0).detach()
            # Resize the gradients to the same size as the image
            gradients = transforms.Resize((img.height, img.width))(gradients)
            # Overlay the scaled gradients on the image
            img = transforms.ToTensor()(img)
            img = img + gradients.cpu() * GRAD_SCALE
            img = img.clamp(0, 1)  # keep values in a valid image range before converting back
            img = transforms.ToPILImage()(img.cpu())
        ax.imshow(img)
        del img
        ax.set_title(
            f'Liking: {liking_pred.item():.3f}\n'
            f'Collection: {collection_pred.item():.3f}\n'
            f'AI: {ai_pred.item() * 100:.3f}%\n'
            f'Popularity: {relative_pop.item():.3f}')
    plt.show()


if __name__ == '__main__':
    main()
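
Usage note: with a trained checkpoint at ./scorer.pt and the six images under ./testcases/ in place, the script can be run directly with python score_it.py; it scores each image and plots them side by side with the predicted liking, collection, AI-likelihood and relative-popularity values. DEVICE is set to 'cuda', so a CPU-only machine presumably needs it changed to 'cpu'. Setting SHOW_GRAD = True additionally overlays input gradients (scaled by GRAD_SCALE) to hint at which image regions drive the selected objectives.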