KeerthiVM committed on
Commit f46a35d · 1 Parent(s): 167ea92
Files changed (2)
  1. SkinGPT.py +2 -0
  2. test.py +39 -30
SkinGPT.py CHANGED
@@ -41,6 +41,7 @@ class Blip2QFormer(nn.Module):
         )
         self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
         self._init_weights()
+        self.last_attention = None
 
     def _init_weights(self):
         nn.init.normal_(self.query_tokens, std=0.02)
@@ -71,6 +72,7 @@ class Blip2QFormer(nn.Module):
             output_attentions=True,
             return_dict=True
         )
+        self.last_attention = outputs.attentions[-1]
         return outputs.last_hidden_state[:, :self.num_query_tokens]
 
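For readers skimming the diff: the new `last_attention` attribute caches the final attention tensor that `transformers` returns when `output_attentions=True`, so downstream code can inspect it after a forward pass. Below is a minimal, self-contained sketch of that caching pattern; `TinyQFormer` and its dimensions are hypothetical stand-ins, not the repository's actual `Blip2QFormer`.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

class TinyQFormer(nn.Module):
    """Hypothetical stand-in for Blip2QFormer, showing only the attention-caching pattern."""

    def __init__(self, vision_width=64, num_query_tokens=8):
        super().__init__()
        self.bert_config = BertConfig(hidden_size=64, num_hidden_layers=2,
                                      num_attention_heads=4, intermediate_size=128)
        self.bert = BertModel(self.bert_config)
        self.num_query_tokens = num_query_tokens
        self.query_tokens = nn.Parameter(
            torch.zeros(1, num_query_tokens, self.bert_config.hidden_size))
        self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
        self.last_attention = None  # refreshed on every forward pass

    def forward(self, image_feats):
        # Prepend the learned query tokens to the projected image features.
        proj = self.vision_proj(image_feats)
        queries = self.query_tokens.expand(proj.size(0), -1, -1)
        outputs = self.bert(inputs_embeds=torch.cat([queries, proj], dim=1),
                            output_attentions=True,  # without this, outputs.attentions is None
                            return_dict=True)
        # Cache the deepest layer's attention: [batch, heads, seq_len, seq_len].
        self.last_attention = outputs.attentions[-1]
        return outputs.last_hidden_state[:, :self.num_query_tokens]

qf = TinyQFormer()
_ = qf(torch.randn(1, 16, 64))  # batch of 1, 16 "patch" features
print(qf.last_attention.shape)  # torch.Size([1, 4, 24, 24])
```

`outputs.attentions` is a tuple with one entry per encoder layer, so indexing with `[-1]` (as the commit does) reads the last layer.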
test.py CHANGED
@@ -27,37 +27,46 @@ class SkinGPTTester:
         with torch.no_grad():
             # Get attention maps
             _ = self.classifier.model.encode_image(image_tensor)
+
+            # Get attention from Q-Former
+            if self.classifier.model.q_former.last_attention is None:
+                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
+                return
+
+            # Get the last layer's attention
             attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
-
-            # Reshape attention to image size
-            h = w = int(math.sqrt(attention.shape[1]))
-            attention = attention.reshape(h, w)
-
-            # Plot
-            plt.figure(figsize=(15, 5))
-
-            # Original image
-            plt.subplot(1, 3, 1)
-            plt.imshow(image)
-            plt.title('Original Image')
-            plt.axis('off')
-
-            # Attention map
-            plt.subplot(1, 3, 2)
-            plt.imshow(attention, cmap='hot')
-            plt.title('Attention Map')
-            plt.axis('off')
-
-            # Overlay
-            plt.subplot(1, 3, 3)
-            plt.imshow(image)
-            plt.imshow(attention, alpha=0.5, cmap='hot')
-            plt.title('Attention Overlay')
-            plt.axis('off')
-
-            plt.tight_layout()
-            plt.savefig('attention_visualization.png')
-            plt.close()
+
+            # Reshape attention to image size
+            h = w = int(math.sqrt(attention.shape[1]))
+            attention = attention.reshape(h, w)
+
+            # Plot
+            plt.figure(figsize=(15, 5))
+
+            # Original image
+            plt.subplot(1, 3, 1)
+            plt.imshow(image)
+            plt.title('Original Image')
+            plt.axis('off')
+
+            # Attention map
+            plt.subplot(1, 3, 2)
+            plt.imshow(attention, cmap='hot')
+            plt.title('Attention Map')
+            plt.axis('off')
+
+            # Overlay
+            plt.subplot(1, 3, 3)
+            plt.imshow(image)
+            plt.imshow(attention, alpha=0.5, cmap='hot')
+            plt.title('Attention Overlay')
+            plt.axis('off')
+
+            plt.tight_layout()
+            plt.savefig('attention_visualization.png')
+            plt.close()
+
+            print(f"Attention visualization saved as 'attention_visualization.png'")
 
     def check_feature_similarity(self, image_path1, image_path2):
         """Compare embeddings of two images"""