KeerthiVM committed
Commit 04bab96 · 1 Parent(s): 975c276
Files changed (1)
  1. SkinGPT.py +10 -11
SkinGPT.py CHANGED
@@ -22,6 +22,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 class Blip2QFormer(nn.Module):
     def __init__(self, num_query_tokens=32, vision_width=1408):
         super().__init__()
+        self.num_query_tokens = num_query_tokens
         # Load pre-trained Q-Former config
         self.bert_config = BertConfig(
             vocab_size=30522,
@@ -74,26 +75,24 @@ class Blip2QFormer(nn.Module):
         visual_embeds = self.vision_proj(visual_features.float())
         print(f"Projected embeds stats - min: {visual_embeds.min().item():.4f}, max: {visual_embeds.max().item():.4f}")
         # visual_embeds = self.vision_proj(visual_features.float())
-        visual_attention_mask = torch.ones(
-            visual_embeds.size()[:-1],
-            dtype=torch.long,
-            device=visual_embeds.device
-        )

         # Expand query tokens
         query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1)
+        combined_input = torch.cat([query_tokens, visual_embeds], dim=1)
+        attention_mask = torch.ones(
+            combined_input.size()[:-1],
+            dtype=torch.long,
+            device=combined_input.device
+        )

         # Forward through BERT
         outputs = self.bert(
-            input_ids=None,  # No text input
-            attention_mask=None,
-            inputs_embeds=query_tokens,
-            encoder_hidden_states=visual_embeds,
-            encoder_attention_mask=visual_attention_mask,
+            attention_mask=attention_mask,
+            inputs_embeds=combined_input,
             return_dict=True
         )

-        return outputs.last_hidden_state
+        return outputs.last_hidden_state[:, :self.num_query_tokens]
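
In effect, the commit changes how the Q-Former attends to the image: the old code fed the query tokens as inputs_embeds and the projected visual features through encoder_hidden_states (cross-attention), while the new code concatenates query tokens and visual embeddings into a single self-attention sequence and keeps only the query-token positions of the output, which is why __init__ now stores self.num_query_tokens. A minimal, shape-level sketch of the new forward path, assuming a hidden size of 768 and a hypothetical 257 visual tokens (neither value appears in the diff):

    import torch

    # Stand-in shapes: 768 is the usual BERT hidden size (an assumption here),
    # 257 is a hypothetical count of projected visual tokens.
    batch_size, num_query_tokens, hidden = 2, 32, 768
    num_patches = 257

    query_tokens = torch.zeros(batch_size, num_query_tokens, hidden)  # like self.query_tokens.expand(B, -1, -1)
    visual_embeds = torch.zeros(batch_size, num_patches, hidden)      # like self.vision_proj(visual_features)

    # New behaviour: queries and visual tokens form one sequence with an all-ones mask.
    combined_input = torch.cat([query_tokens, visual_embeds], dim=1)           # (B, 32 + N, H)
    attention_mask = torch.ones(combined_input.size()[:-1], dtype=torch.long)  # (B, 32 + N)

    # self.bert(inputs_embeds=combined_input, attention_mask=attention_mask, return_dict=True)
    # keeps the sequence length, so slicing the first num_query_tokens rows recovers the queries.
    last_hidden_state = combined_input            # stand-in for outputs.last_hidden_state
    query_output = last_hidden_state[:, :num_query_tokens]
    print(query_output.shape)                     # torch.Size([2, 32, 768])

The slice at the end is the reason the constructor now records num_query_tokens; without it the caller would also receive the visual positions of the BERT output.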