Irpan committed
Commit 190650c · 1 parent: 205a1e3

app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import ViltProcessor, ViltForQuestionAnswering
+from transformers import ViltProcessor, ViltForQuestionAnswering, AutoModelForSeq2SeqLM, AutoTokenizer
 import torch
 
 import httpcore
@@ -9,8 +9,6 @@ from googletrans import Translator
 from googletrans import LANGCODES
 import re
 
-torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
-
 # List of acceptable languages
 acceptable_languages = set(L.split()[0] for L in LANGCODES)
 acceptable_languages.add("mandarin")
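The context lines above build acceptable_languages from googletrans' LANGCODES, a dict mapping lowercase language names to ISO codes; split()[0] keeps only the first word of each name, so multi-word entries collapse to their head word and common aliases such as "mandarin" have to be added by hand. A minimal illustration (dict contents abbreviated):

from googletrans import LANGCODES

# LANGCODES looks like {'afrikaans': 'af', ..., 'chinese (simplified)': 'zh-cn', ...}
acceptable_languages = set(name.split()[0] for name in LANGCODES)
# 'chinese (simplified)' and 'chinese (traditional)' both contribute just
# 'chinese', so the script adds the alias 'mandarin' explicitly.
acceptable_languages.add("mandarin")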
@@ -56,10 +54,7 @@ def remove_language_phrase(sentence):
     return cleaned_sentence
 
 def vqa(image, text):
-    vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-    vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
     encoding = vqa_processor(image, text, return_tensors="pt")
-    # forward pass
     with torch.no_grad():
         outputs = vqa_model(**encoding)
 
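The visible change in this hunk is a refactor: the processor and model are no longer re-loaded on every vqa() call; they move to module level (see the next hunk) so they load once at startup. The untouched lines elided between this hunk and the next (old lines 66-68) are where predicted_answer is decoded from the ViLT outputs. That code is not shown in the diff; for the dandelin/vilt-b32-finetuned-vqa checkpoint the conventional decoding is an argmax over the classifier logits mapped through id2label, roughly:

# Plausible shape of the elided decoding step (not shown in this commit):
idx = outputs.logits.argmax(-1).item()
predicted_answer = vqa_model.config.id2label[idx]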
@@ -69,12 +64,39 @@ def vqa(image, text):
 
     return predicted_answer
 
+
+def llm(cleaned_sentence, vqa_answer):
+    # Prepare the input prompt
+    prompt = (
+        f"A question: {cleaned_sentence}\n"
+        f"An answer: {vqa_answer}.\n"
+        f"Based on these, answer the question with a complete sentence without extra information."
+    )
+    inputs = flan_tokenizer(prompt, return_tensors="pt")
+    outputs = flan_model.generate(**inputs, max_length=50)
+    response = flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    print("T5 prompt: " + prompt)
+    print("T5 response: " + response)
+    return response
+
+
+torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
+
+vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+
+flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+
+
+
 def main(image, text):
     en_question, question_src_lang = google_translate(text, dest='en')
     dest_lang = find_dest_language(en_question, question_src_lang)
     cleaned_sentence = remove_language_phrase(en_question)
     vqa_answer = vqa(image, cleaned_sentence)
-
+    llm_answer = llm(cleaned_sentence, vqa_answer)
+    final_answer, _ = google_translate(llm_answer, dest=dest_lang)
     return final_answer
 
 
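The new llm() step is a plain seq2seq rewrite: FLAN-T5 turns the short VQA label into a full sentence before it is translated back. A self-contained sketch of the same pattern, using the checkpoint named in the commit (the printed output is illustrative, not a recorded run):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

prompt = ("A question: How many monkeys are there?\n"
          "An answer: 2.\n"
          "Based on these, answer the question with a complete sentence without extra information.")
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# e.g. "There are 2 monkeys." (exact wording depends on the model)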
@@ -84,7 +106,7 @@ answer = gr.Textbox(label="Predicted answer")
 examples = [["monkeys.jpg", "How many monkeys are there, in French?"]]
 
 title = "Cross-lingual VQA"
-description = "
+description = "Visual Question Answering (VQA) across languages"
 
 interface = gr.Interface(fn=main,
                          inputs=[image, question],
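main() also relies on a google_translate() helper defined elsewhere in app.py and untouched by this commit. Given the googletrans imports at the top of the file and the way main() unpacks its result, it plausibly wraps Translator along these lines (a sketch under that assumption, not the file's actual code):

from googletrans import Translator

translator = Translator()

def google_translate(text, dest='en'):
    # googletrans returns a Translated object carrying the translated
    # text and the auto-detected source language code.
    result = translator.translate(text, dest=dest)
    return result.text, result.src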
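The diff view cuts off inside the gr.Interface call. The remainder of app.py sits outside the hunks shown; given the objects defined just above (image, question, answer, examples, title, description), it presumably completes the call and launches the app, roughly:

interface = gr.Interface(fn=main,
                         inputs=[image, question],
                         outputs=answer,
                         examples=examples,
                         title=title,
                         description=description)
interface.launch()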