KeerthiVM committed on
Commit 0e15733 · 1 Parent(s): f46a35d
Files changed (1)
  1. test.py +26 -5
test.py CHANGED
@@ -36,9 +36,30 @@ class SkinGPTTester:
         # Get the last layer's attention
         attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
 
-        # Reshape attention to image size
-        h = w = int(math.sqrt(attention.shape[1]))
-        attention = attention.reshape(h, w)
+        # Print attention shape for debugging
+        print(f"Attention shape: {attention.shape}")
+
+        # Reshape attention to match image dimensions
+        # The attention shape should be [num_query_tokens + num_patches, num_query_tokens + num_patches]
+        # We want to visualize the attention from query tokens to image patches
+        num_query_tokens = self.classifier.model.q_former.num_query_tokens
+        attention_to_patches = attention[num_query_tokens:, :num_query_tokens].mean(dim=1)
+
+        # Calculate the number of patches
+        num_patches = attention_to_patches.shape[0]
+        h = w = int(math.sqrt(num_patches))
+
+        if h * w != num_patches:
+            print(f"Warning: Number of patches ({num_patches}) is not a perfect square")
+            # Use the closest square dimensions
+            h = w = int(math.ceil(math.sqrt(num_patches)))
+            # Pad the attention map to make it square
+            padded_attention = torch.zeros(h * w, device=attention_to_patches.device)
+            padded_attention[:num_patches] = attention_to_patches
+            attention_to_patches = padded_attention
+
+        # Reshape to 2D
+        attention_map = attention_to_patches.reshape(h, w)
 
         # Plot
         plt.figure(figsize=(15, 5))
@@ -51,14 +72,14 @@ class SkinGPTTester:
 
         # Attention map
         plt.subplot(1, 3, 2)
-        plt.imshow(attention, cmap='hot')
+        plt.imshow(attention_map.cpu().numpy(), cmap='hot')
        plt.title('Attention Map')
         plt.axis('off')
 
         # Overlay
         plt.subplot(1, 3, 3)
         plt.imshow(image)
-        plt.imshow(attention, alpha=0.5, cmap='hot')
+        plt.imshow(attention_map.cpu().numpy(), alpha=0.5, cmap='hot')
         plt.title('Attention Overlay')
         plt.axis('off')
 
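
For reference, below is a minimal, self-contained sketch of the reshape-and-pad logic this commit adds, run on a synthetic attention matrix instead of the real Q-Former output; the num_query_tokens and num_patches values are illustrative assumptions, not SkinGPT's actual configuration.

import math
import torch

# Synthetic stand-in for q_former.last_attention[0].mean(dim=0):
# a square attention matrix over (num_query_tokens + num_patches) tokens.
num_query_tokens = 32   # assumed value, for illustration only
num_patches = 257       # deliberately not a perfect square (e.g. a 16x16 grid plus one extra token)
size = num_query_tokens + num_patches
attention = torch.rand(size, size).softmax(dim=-1)

# Same slicing as the commit: keep the rows after the query tokens and the
# query-token columns, then average over the query-token dimension so that
# one score remains per image patch.
attention_to_patches = attention[num_query_tokens:, :num_query_tokens].mean(dim=1)

# Pad with zeros up to the next perfect square when the patch count is not one,
# so the vector can still be reshaped into an h x w grid.
num = attention_to_patches.shape[0]
h = w = int(math.sqrt(num))
if h * w != num:
    h = w = int(math.ceil(math.sqrt(num)))
    padded = torch.zeros(h * w, device=attention_to_patches.device)
    padded[:num] = attention_to_patches
    attention_to_patches = padded

attention_map = attention_to_patches.reshape(h, w)
print(attention_map.shape)  # torch.Size([17, 17]) for 257 patches

The zero-padding is a pragmatic choice: if an extra token (for example a CLS token) leaves the patch count short of a square grid, padding the tail keeps reshape(h, w) valid at the cost of a few blank cells in the bottom-right corner of the map.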