spuun committed on
Commit ce9d052 · 1 Parent(s): ea64237

Make min and max configurable on VQA

Files changed (1)
  1. models/blip_vqa.py +3 -3
models/blip_vqa.py CHANGED
@@ -34,7 +34,7 @@ class BLIP_VQA(nn.Module):
         self.text_decoder = BertLMHeadModel(config=decoder_config)


-    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
+    def forward(self, image, question, answer=None, n=None, weights=None, train=True, maxa_len=10, mina_len=1, inference='rank', k_test=128):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
@@ -97,8 +97,8 @@ class BLIP_VQA(nn.Module):
        bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)

        outputs = self.text_decoder.generate(input_ids=bos_ids,
-                                             max_length=10,
-                                             min_length=1,
+                                             max_length=maxa_len,
+                                             min_length=mina_len,
                                              num_beams=num_beams,
                                              eos_token_id=self.tokenizer.sep_token_id,
                                              pad_token_id=self.tokenizer.pad_token_id,
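
Usage sketch (not part of the commit): assuming a constructed BLIP_VQA instance `model`, a preprocessed image batch, and tokenizable question strings, and assuming the decoder branch containing this generate() call is selected with inference='generate' as in the upstream BLIP implementation, the new arguments bound the generated answer length. The values 20 and 3 below are illustrative, not defaults from the change.

    # Hypothetical call: maxa_len / mina_len are passed through forward() and
    # forwarded to text_decoder.generate() as max_length / min_length.
    answers = model(image, question, train=False, inference='generate',
                    maxa_len=20, mina_len=3)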