Make min and max configurable on VQA
models/blip_vqa.py  (+3 -3)
@@ -34,7 +34,7 @@ class BLIP_VQA(nn.Module):
         self.text_decoder = BertLMHeadModel(config=decoder_config)
 
 
-    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
+    def forward(self, image, question, answer=None, n=None, weights=None, train=True, maxa_len=10, mina_len=1, inference='rank', k_test=128):
 
         image_embeds = self.visual_encoder(image)
         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
@@ -97,8 +97,8 @@ class BLIP_VQA(nn.Module):
                 bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
 
                 outputs = self.text_decoder.generate(input_ids=bos_ids,
-                                                     max_length=10,
-                                                     min_length=1,
+                                                     max_length=maxa_len,
+                                                     min_length=mina_len,
                                                      num_beams=num_beams,
                                                      eos_token_id=self.tokenizer.sep_token_id,
                                                      pad_token_id=self.tokenizer.pad_token_id,
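For context, a minimal sketch of how a caller might use the new maxa_len / mina_len arguments in generate mode. The blip_vqa factory and forward signature come from the BLIP repo; the checkpoint path, the random image tensor, and the chosen values 20 / 2 are placeholders for illustration, not part of this commit.

import torch
from models.blip_vqa import blip_vqa  # model factory from the BLIP repo

# Sketch only: the checkpoint path and the random tensor below are placeholders.
model = blip_vqa(pretrained='path/to/vqa_checkpoint.pth', image_size=480, vit='base')
model.eval()

image = torch.randn(1, 3, 480, 480)      # stands in for a preprocessed image batch
question = ['where is the dog?']

with torch.no_grad():
    # maxa_len / mina_len now bound the generated answer length; before this
    # change the decoder always used the hard-coded max_length=10, min_length=1.
    answers = model(image, question, train=False, inference='generate',
                    maxa_len=20, mina_len=2)

print(answers)

Keeping the defaults at maxa_len=10 and mina_len=1 preserves the previous behaviour for existing callers, so only code that wants longer or shorter answers needs to pass the new arguments.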