Testing
This commit caches the Q-Former's last-layer attention weights in SkinGPT.py and uses them in test.py to visualize where the model attends on an input image.

- SkinGPT.py +2 -0
- test.py +39 -30
SkinGPT.py

@@ -41,6 +41,7 @@ class Blip2QFormer(nn.Module):
         )
         self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
         self._init_weights()
+        self.last_attention = None
 
     def _init_weights(self):
         nn.init.normal_(self.query_tokens, std=0.02)
@@ -71,6 +72,7 @@ class Blip2QFormer(nn.Module):
             output_attentions=True,
             return_dict=True
         )
+        self.last_attention = outputs.attentions[-1]
         return outputs.last_hidden_state[:, :self.num_query_tokens]
 
 
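For reference, the caching above relies on standard Hugging Face transformers behavior: when an encoder is called with output_attentions=True, outputs.attentions is a tuple with one tensor per layer, each shaped (batch, num_heads, query_len, key_len), so outputs.attentions[-1] is the last layer's attention map. A minimal sketch of that contract, using a stock BertModel rather than this repo's Blip2QFormer (the small config values are arbitrary):

# Sketch only: a plain BertModel, not this repo's Blip2QFormer; config sizes are arbitrary.
import torch
from transformers import BertConfig, BertModel

config = BertConfig(vocab_size=100, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=128)
model = BertModel(config)

input_ids = torch.randint(0, 100, (1, 16))  # (batch, seq_len)
outputs = model(input_ids, output_attentions=True, return_dict=True)

print(len(outputs.attentions))       # 2 -- one tensor per layer
print(outputs.attentions[-1].shape)  # torch.Size([1, 4, 16, 16]) = (batch, heads, seq, seq)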
test.py

@@ -27,37 +27,46 @@ class SkinGPTTester:
         with torch.no_grad():
             # Get attention maps
             _ = self.classifier.model.encode_image(image_tensor)
+
+            # Get attention from Q-Former
+            if self.classifier.model.q_former.last_attention is None:
+                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
+                return
+
+            # Get the last layer's attention
             attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            # Reshape attention to image size
+            h = w = int(math.sqrt(attention.shape[1]))
+            attention = attention.reshape(h, w)
+
+            # Plot
+            plt.figure(figsize=(15, 5))
+
+            # Original image
+            plt.subplot(1, 3, 1)
+            plt.imshow(image)
+            plt.title('Original Image')
+            plt.axis('off')
+
+            # Attention map
+            plt.subplot(1, 3, 2)
+            plt.imshow(attention, cmap='hot')
+            plt.title('Attention Map')
+            plt.axis('off')
+
+            # Overlay
+            plt.subplot(1, 3, 3)
+            plt.imshow(image)
+            plt.imshow(attention, alpha=0.5, cmap='hot')
+            plt.title('Attention Overlay')
+            plt.axis('off')
+
+            plt.tight_layout()
+            plt.savefig('attention_visualization.png')
+            plt.close()
+
+            print(f"Attention visualization saved as 'attention_visualization.png'")
 
     def check_feature_similarity(self, image_path1, image_path2):
         """Compare embeddings of two images"""
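One caveat worth flagging: after last_attention[0].mean(dim=0) the tensor is still 2-D, (num_query_tokens, num_keys), assuming the cached tensor has the usual (batch, num_heads, num_query_tokens, num_keys) layout, so attention.reshape(h, w) with h = w = sqrt(attention.shape[1]) only succeeds when there is a single query token. The sketch below is a hedged, more robust variant that also upsamples the patch grid to the image resolution so the alpha overlay actually aligns; the extra mean over query tokens, the perfect-square key count, and the 224x224 size are assumptions, not taken from the repo:

# Sketch under stated assumptions; not the repo's committed code.
import math
import torch
import torch.nn.functional as F

def attention_to_heatmap(last_attention, image_hw=(224, 224)):
    attn = last_attention[0].mean(dim=0)   # (num_query_tokens, num_keys), heads averaged
    attn = attn.mean(dim=0)                # assumption: also average over query tokens -> (num_keys,)
    h = w = int(math.sqrt(attn.shape[0]))  # assumption: num_keys is a perfect-square patch grid
    grid = attn.reshape(h, w)
    # Bilinearly upsample the patch grid to the image size so imshow(..., alpha=0.5) lines up.
    return F.interpolate(grid[None, None], size=image_hw,
                         mode='bilinear', align_corners=False)[0, 0]

dummy = torch.rand(1, 8, 32, 256)          # batch=1, 8 heads, 32 queries, 256 = 16*16 keys
print(attention_to_heatmap(dummy).shape)   # torch.Size([224, 224])

If the key sequence carries an extra token (e.g., a CLS position), num_keys will not be a perfect square and the reshape needs a corresponding slice first.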