Irpan committed on
Commit
290c136
·
1 Parent(s): b063473
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import ViltProcessor, ViltForQuestionAnswering
3
  import torch
4
 
5
  import httpcore
@@ -11,9 +11,6 @@ import re
11
 
12
  torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
13
 
14
- processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
15
- model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
16
-
17
  # List of acceptable languages
18
  acceptable_languages = set(L.split()[0] for L in LANGCODES)
19
  acceptable_languages.add("mandarin")
@@ -59,14 +56,16 @@ def remove_language_phrase(sentence):
59
  return cleaned_sentence
60
 
61
  def vqa(image, text):
62
- encoding = processor(image, text, return_tensors="pt")
 
 
63
  # forward pass
64
  with torch.no_grad():
65
- outputs = model(**encoding)
66
 
67
  logits = outputs.logits
68
  idx = logits.argmax(-1).item()
69
- predicted_answer = model.config.id2label[idx]
70
 
71
  return predicted_answer
72
 
@@ -78,9 +77,9 @@ def main(image, text):
78
  return vqa_answer
79
 
80
 
81
- image = gr.inputs.Image(type="pil")
82
- question = gr.inputs.Textbox(label="Question")
83
- answer = gr.outputs.Textbox(label="Predicted answer")
84
  examples = [["monkeys.jpg", "How many monkeys are there, in French?"]]
85
 
86
  title = "Cross-lingual VQA"
 
1
  import gradio as gr
2
+ from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
3
  import torch
4
 
5
  import httpcore
 
11
 
12
  torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
13
 
 
 
 
14
  # List of acceptable languages
15
  acceptable_languages = set(L.split()[0] for L in LANGCODES)
16
  acceptable_languages.add("mandarin")
 
56
  return cleaned_sentence
57
 
58
  def vqa(image, text):
59
+ vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
60
+ vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
61
+ encoding = vqa_processor(image, text, return_tensors="pt")
62
  # forward pass
63
  with torch.no_grad():
64
+ outputs = vqa_model(**encoding)
65
 
66
  logits = outputs.logits
67
  idx = logits.argmax(-1).item()
68
+ predicted_answer = vqa_model.config.id2label[idx]
69
 
70
  return predicted_answer
71
 
 
77
  return vqa_answer
78
 
79
 
80
+ image = gr.Image(type="pil")
81
+ question = gr.Textbox(label="Question")
82
+ answer = gr.Textbox(label="Predicted answer")
83
  examples = [["monkeys.jpg", "How many monkeys are there, in French?"]]
84
 
85
  title = "Cross-lingual VQA"