James332 committed on
Commit
9b57d0e
·
1 Parent(s): bda1385

Upload app (1).py

Browse files
Files changed (1) hide show
  1. app (1).py +67 -0
app (1).py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Gradio demo: visual question answering on bundled OK-VQA example images,
# served by a GIT model fine-tuned on OK-VQA and loaded from a local
# checkpoint directory.
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
from PIL import Image
import torch

# Provenance of the demo's data and weights (kept for reference).
dataset_name = "Multimodal-Fatima/OK-VQA_train"
original_model_name = "microsoft/git-base-vqav2"
model_name = "hyo37009/git-vqa-finetuned-on-ok-vqa"
# Local directory the processor/model are actually loaded from.
model_path = "git-vqa-finetuned-on-ok-vqa"

# One question per bundled example image: questions[i] pairs with image{i+1}.jpg.
questions = [
    "What can happen the objects shown are thrown on the ground?",
    "What was the machine beside the bowl used for?",
    "What kind of cars are in the photo?",
    "What is the hairstyle of the blond called?",
    "How old do you have to be in canada to do this?",
    "Can you guess the place where the man is playing?",
    "What loony tune character is in this photo?",
    "Whose birthday is being celebrated?",
    "Where can that toilet seat be bought?",
    "What do you call the kind of pants that the man on the right is wearing?",
]

# Load processor and model once at import time from the local checkpoint.
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForVisualQuestionAnswering.from_pretrained(model_path)
25
+
26
+
27
def main(select_exemple_num):
    """Answer the question paired with the selected bundled example image.

    Parameters
    ----------
    select_exemple_num : int
        1-based example index from the UI slider; selects ``image{n}.jpg``
        and ``questions[n - 1]``.

    Returns
    -------
    tuple
        ``(image_path, question, output_str)`` where ``output_str`` lists the
        top-5 predicted labels with their scores and the final answer.
    """
    selectednum = select_exemple_num
    exemple_img = f"image{selectednum}.jpg"
    img = Image.open(exemple_img)
    question = questions[selectednum - 1]

    encoding = processor(img, question, return_tensors='pt')

    # Inference only -- no gradients needed.
    with torch.no_grad():
        outputs = model(**encoding)
    logits = outputs.logits

    # ---
    # NOTE(review): this assumes a classification-style VQA head exposing
    # `logits` over answer classes and `config.id2label` -- confirm for the
    # loaded checkpoint.
    output_str = 'predicted : \n'
    predicted_classes = torch.sigmoid(logits)

    probs, classes = torch.topk(predicted_classes, 5)
    ans = ''

    for prob, class_idx in zip(probs.squeeze().tolist(), classes.squeeze().tolist()):
        label = model.config.id2label[class_idx]
        print(prob, label)
        output_str += f"{prob} {label}\n"
        # First (highest-scoring) label becomes the final answer.
        if not ans:
            ans = label

    print(ans)
    # ---
    output_str += f"\nso I think the answer is : \n{ans}"

    return exemple_img, question, output_str
59
+
60
+
61
# Wire the UI: a 1..N slider picks an example; the outputs display the
# example image, its question, and the model's formatted prediction text.
example_slider = gr.Slider(1, len(questions), step=1)
demo = gr.Interface(
    fn=main,
    inputs=[example_slider],
    outputs=["image", "text", "text"],
)

demo.launch(share=True)