Irpan committed
Commit 190650c · 1 Parent(s): 205a1e3
Files changed (1)
  1. app.py +30 -8
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
+from transformers import ViltProcessor, ViltForQuestionAnswering, AutoModelForSeq2SeqLM, AutoTokenizer
 import torch
 
 import httpcore
@@ -9,8 +9,6 @@ from googletrans import Translator
 from googletrans import LANGCODES
 import re
 
-torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
-
 # List of acceptable languages
 acceptable_languages = set(L.split()[0] for L in LANGCODES)
 acceptable_languages.add("mandarin")
@@ -56,10 +54,7 @@ def remove_language_phrase(sentence):
     return cleaned_sentence
 
 def vqa(image, text):
-    vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-    vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
     encoding = vqa_processor(image, text, return_tensors="pt")
-    # forward pass
     with torch.no_grad():
         outputs = vqa_model(**encoding)
 
@@ -69,12 +64,39 @@ def vqa(image, text):
 
     return predicted_answer
 
+
+def llm(cleaned_sentence, vqa_answer):
+    # Prepare the input prompt
+    prompt = (
+        f"A question: {cleaned_sentence}\n"
+        f"An answer: {vqa_answer}.\n"
+        f"Based on these, answer the question with a complete sentence without extra information."
+    )
+    inputs = flan_tokenizer(prompt, return_tensors="pt")
+    outputs = flan_model.generate(**inputs, max_length=50)
+    response = flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    print("T5 prompt: " + prompt)
+    print("T5 response: " + response)
+    return response
+
+
+torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
+
+vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+
+flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+
+
+
+
 def main(image, text):
     en_question, question_src_lang = google_translate(text, dest='en')
     dest_lang = find_dest_language(en_question, question_src_lang)
     cleaned_sentence = remove_language_phrase(en_question)
     vqa_answer = vqa(image, cleaned_sentence)
-    final_answer, _ = google_translate(vqa_answer, dest=dest_lang)
+    llm_answer = llm(cleaned_sentence, vqa_answer)
+    final_answer, _ = google_translate(llm_answer, dest=dest_lang)
     return final_answer
 
 
@@ -84,7 +106,7 @@ answer = gr.Textbox(label="Predicted answer")
 examples = [["monkeys.jpg", "How many monkeys are there, in French?"]]
 
 title = "Cross-lingual VQA"
-description = "ViLT (Vision and Language Transformer), fine-tuned on VQAv2 "
+description = "Visual Question Answering (VQA) across languages"
 
 interface = gr.Interface(fn=main,
                          inputs=[image, question],
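
For reference, below is a minimal, self-contained sketch of the two-stage pipeline this commit introduces: ViLT predicts a short VQA answer, and FLAN-T5 then rewrites it as a complete sentence. This is not the committed code. The logits-to-id2label decoding follows the standard usage of dandelin/vilt-b32-finetuned-vqa (the body of vqa() between the hunks is not visible in this diff), the helper names are invented here, and the example inputs and outputs are illustrative.

import torch
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    ViltForQuestionAnswering,
    ViltProcessor,
)

# Load the same checkpoints the commit now loads once at module level.
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

def short_vqa_answer(image, question):
    # ViLT classifies over a fixed answer vocabulary; the label of the top
    # logit is the short answer (standard usage for this checkpoint).
    encoding = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        logits = vqa_model(**encoding).logits
    return vqa_model.config.id2label[logits.argmax(-1).item()]

def sentence_answer(question, short_answer):
    # Same prompt shape as the new llm() helper added in this commit.
    prompt = (
        f"A question: {question}\n"
        f"An answer: {short_answer}.\n"
        "Based on these, answer the question with a complete sentence without extra information."
    )
    inputs = flan_tokenizer(prompt, return_tensors="pt")
    outputs = flan_model.generate(**inputs, max_length=50)
    return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)

image = Image.open("monkeys.jpg")  # downloaded at startup by app.py
short = short_vqa_answer(image, "How many monkeys are there?")   # e.g. "2"
print(sentence_answer("How many monkeys are there?", short))     # e.g. "There are 2 monkeys."

In the committed main(), the resulting sentence is then translated back to the requested language via google_translate(llm_answer, dest=dest_lang), a helper defined elsewhere in app.py and not shown in this diff.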