Testing
Browse files
test.py
CHANGED
@@ -36,9 +36,30 @@ class SkinGPTTester:
|
|
36 |
# Get the last layer's attention
|
37 |
attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
# Plot
|
44 |
plt.figure(figsize=(15, 5))
|
@@ -51,14 +72,14 @@ class SkinGPTTester:
|
|
51 |
|
52 |
# Attention map
|
53 |
plt.subplot(1, 3, 2)
|
54 |
-
plt.imshow(
|
55 |
plt.title('Attention Map')
|
56 |
plt.axis('off')
|
57 |
|
58 |
# Overlay
|
59 |
plt.subplot(1, 3, 3)
|
60 |
plt.imshow(image)
|
61 |
-
plt.imshow(
|
62 |
plt.title('Attention Overlay')
|
63 |
plt.axis('off')
|
64 |
|
|
|
36 |
# Get the last layer's attention
|
37 |
attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
|
38 |
|
39 |
+
# Print attention shape for debugging
|
40 |
+
print(f"Attention shape: {attention.shape}")
|
41 |
+
|
42 |
+
# Reshape attention to match image dimensions
|
43 |
+
# The attention shape should be [num_query_tokens + num_patches, num_query_tokens + num_patches]
|
44 |
+
# We want to visualize the attention from query tokens to image patches
|
45 |
+
num_query_tokens = self.classifier.model.q_former.num_query_tokens
|
46 |
+
attention_to_patches = attention[num_query_tokens:, :num_query_tokens].mean(dim=1)
|
47 |
+
|
48 |
+
# Calculate the number of patches
|
49 |
+
num_patches = attention_to_patches.shape[0]
|
50 |
+
h = w = int(math.sqrt(num_patches))
|
51 |
+
|
52 |
+
if h * w != num_patches:
|
53 |
+
print(f"Warning: Number of patches ({num_patches}) is not a perfect square")
|
54 |
+
# Use the closest square dimensions
|
55 |
+
h = w = int(math.ceil(math.sqrt(num_patches)))
|
56 |
+
# Pad the attention map to make it square
|
57 |
+
padded_attention = torch.zeros(h * w, device=attention_to_patches.device)
|
58 |
+
padded_attention[:num_patches] = attention_to_patches
|
59 |
+
attention_to_patches = padded_attention
|
60 |
+
|
61 |
+
# Reshape to 2D
|
62 |
+
attention_map = attention_to_patches.reshape(h, w)
|
63 |
|
64 |
# Plot
|
65 |
plt.figure(figsize=(15, 5))
|
|
|
72 |
|
73 |
# Attention map
|
74 |
plt.subplot(1, 3, 2)
|
75 |
+
plt.imshow(attention_map.cpu().numpy(), cmap='hot')
|
76 |
plt.title('Attention Map')
|
77 |
plt.axis('off')
|
78 |
|
79 |
# Overlay
|
80 |
plt.subplot(1, 3, 3)
|
81 |
plt.imshow(image)
|
82 |
+
plt.imshow(attention_map.cpu().numpy(), alpha=0.5, cmap='hot')
|
83 |
plt.title('Attention Overlay')
|
84 |
plt.axis('off')
|
85 |
|