Spaces:
Runtime error
Runtime error
Add VQA
Browse files- app_caption.py +1 -1
- app_vqa.py +2 -2
- prismer_model.py +1 -1
app_caption.py
CHANGED
@@ -15,7 +15,7 @@ def create_demo():
|
|
15 |
with gr.Row():
|
16 |
with gr.Column():
|
17 |
image = gr.Image(label='Input', type='filepath')
|
18 |
-
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
|
19 |
run_button = gr.Button('Run')
|
20 |
with gr.Column(scale=1.5):
|
21 |
caption = gr.Text(label='Model Prediction')
|
|
|
15 |
with gr.Row():
|
16 |
with gr.Column():
|
17 |
image = gr.Image(label='Input', type='filepath')
|
18 |
+
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
|
19 |
run_button = gr.Button('Run')
|
20 |
with gr.Column(scale=1.5):
|
21 |
caption = gr.Text(label='Model Prediction')
|
app_vqa.py
CHANGED
@@ -44,9 +44,9 @@ def create_demo():
|
|
44 |
gr.Examples(examples=examples,
|
45 |
inputs=inputs,
|
46 |
outputs=outputs,
|
47 |
-
fn=model.
|
48 |
|
49 |
-
run_button.click(fn=model.
|
50 |
|
51 |
|
52 |
if __name__ == '__main__':
|
|
|
44 |
gr.Examples(examples=examples,
|
45 |
inputs=inputs,
|
46 |
outputs=outputs,
|
47 |
+
fn=model.run_vqa)
|
48 |
|
49 |
+
run_button.click(fn=model.run_vqa, inputs=inputs, outputs=outputs)
|
50 |
|
51 |
|
52 |
if __name__ == '__main__':
|
prismer_model.py
CHANGED
@@ -145,7 +145,7 @@ class Model:
|
|
145 |
test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
|
146 |
experts, _ = next(iter(test_loader))
|
147 |
question = pre_question(question)
|
148 |
-
answer = self.model(experts, question, train=False, inference='generate')
|
149 |
answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
|
150 |
answer = answer.to(experts['rgb'].device)[0]
|
151 |
answer = self.tokenizer.decode(answer, skip_special_tokens=True)
|
|
|
145 |
test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
|
146 |
experts, _ = next(iter(test_loader))
|
147 |
question = pre_question(question)
|
148 |
+
answer = self.model(experts, [question], train=False, inference='generate')
|
149 |
answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
|
150 |
answer = answer.to(experts['rgb'].device)[0]
|
151 |
answer = self.tokenizer.decode(answer, skip_special_tokens=True)
|