KeerthiVM commited on
Commit
863dd32
·
1 Parent(s): 0e15733
Files changed (2) hide show
  1. 1.jpg +0 -0
  2. test.py +2 -2
1.jpg CHANGED
test.py CHANGED
@@ -34,7 +34,7 @@ class SkinGPTTester:
34
  return
35
 
36
  # Get the last layer's attention
37
- attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
38
 
39
  # Print attention shape for debugging
40
  print(f"Attention shape: {attention.shape}")
@@ -124,7 +124,7 @@ class SkinGPTTester:
124
  # Extract visual features using attention
125
  with torch.no_grad():
126
  image_tensor = self.transform(image).unsqueeze(0).to(self.device)
127
- attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
128
 
129
  # Get regions with high attention
130
  attention = attention.reshape(int(math.sqrt(attention.shape[1])), -1)
 
34
  return
35
 
36
  # Get the last layer's attention
37
+ attention = self.classifier.model.q_former.last_attention[0][0] # shape: [num_tokens,]
38
 
39
  # Print attention shape for debugging
40
  print(f"Attention shape: {attention.shape}")
 
124
  # Extract visual features using attention
125
  with torch.no_grad():
126
  image_tensor = self.transform(image).unsqueeze(0).to(self.device)
127
+ attention = self.classifier.model.q_former.last_attention[0][0]
128
 
129
  # Get regions with high attention
130
  attention = attention.reshape(int(math.sqrt(attention.shape[1])), -1)