seanshahkarami committed on
Commit
c638c97
·
1 Parent(s): f7336b5
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ViltProcessor, ViltForQuestionAnswering
2
+ import torch
3
+ import gradio as gr
4
+
5
+
6
# Pretrained ViLT checkpoint fine-tuned for visual question answering;
# one constant so processor and model are guaranteed to match.
_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"

processor = ViltProcessor.from_pretrained(_CHECKPOINT)
model = ViltForQuestionAnswering.from_pretrained(_CHECKPOINT)
10
def question_answer(image, questions):
    """Answer one or more questions about an image using the ViLT VQA model.

    Parameters
    ----------
    image :
        Image from the gradio "image" input (presumably a PIL image or
        numpy array — whatever the processor accepts; TODO confirm).
    questions : str
        Newline-separated questions from the gradio "textarea" input.

    Returns
    -------
    str
        "Question: ...\nAnswer: ..." pairs, separated by blank lines.
    """
    # Skip blank lines so stray newlines in the textarea don't get sent
    # to the model as empty questions.
    question_list = [line.strip() for line in questions.splitlines() if line.strip()]

    answers = []

    # One no_grad context around the whole loop — this is pure inference,
    # and re-entering the context per question is wasted work.
    with torch.no_grad():
        for question in question_list:
            # prepare inputs
            encoding = processor(image, question, return_tensors="pt")
            # forward pass
            outputs = model(**encoding)
            # Highest-scoring answer class, mapped back to its label string.
            idx = outputs.logits.argmax(-1).item()
            answers.append(
                "Question: " + question + "\nAnswer: " + model.config.id2label[idx]
            )

    return "\n\n".join(answers)
28
+
29
+
30
# Wire the VQA function into a simple web demo: an image plus a textarea
# of questions in, a textarea of answers out.
demo_inputs = ["image", "textarea"]
demo_outputs = ["textarea"]

iface = gr.Interface(fn=question_answer, inputs=demo_inputs, outputs=demo_outputs)
iface.launch()