mouaddb committed
Commit 914a76b · 1 Parent(s): 365c9f0

Create app.py

Files changed (1)
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
+ import gradio as gr
+ from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, VisionEncoderDecoderModel
+ import torch
+
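+ # Load the processor/model pair for each of the five captioning models compared in this demo.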
+ git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-coco")
+ git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
+
+ git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+ git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
+
+ blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+ blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+ vitgpt_processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ vitgpt_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ vitgpt_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
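+ # Run inference on a GPU when one is available; all five models share the same device.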
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ git_model_base.to(device)
+ blip_model_base.to(device)
+ git_model_large.to(device)
+ blip_model_large.to(device)
+ vitgpt_model.to(device)
+
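+ # Caption a single image with one processor/model pair. The ViT+GPT-2 pipeline decodes
+ # with its own tokenizer; the GIT and BLIP processors decode the generated ids directly.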
+ def generate_caption(processor, model, image, tokenizer=None):
+     inputs = processor(images=image, return_tensors="pt").to(device)
+
+     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
+
+     if tokenizer is not None:
+         generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     else:
+         generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     return generated_caption
+
+
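+ # Caption the same image with all five models; the return order matches the output
+ # textboxes defined below.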
+ def generate_captions(image):
+     caption_git_base = generate_caption(git_processor_base, git_model_base, image)
+
+     caption_git_large = generate_caption(git_processor_large, git_model_large, image)
+
+     caption_blip_base = generate_caption(blip_processor_base, blip_model_base, image)
+
+     caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)
+
+     caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
+
+     return caption_git_base, caption_git_large, caption_blip_base, caption_blip_large, caption_vitgpt
+
+
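+ # Example images bundled with the Space and one output textbox per model.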
+ examples = [["test-1.jpeg"], ["test-2.jpeg"], ["test-3.jpeg"], ["test-4.jpeg"], ["test-5.jpeg"], ["test-6.jpg"]]
+ outputs = [gr.outputs.Textbox(label="Caption generated by GIT-base"),
+            gr.outputs.Textbox(label="Caption generated by GIT-large"),
+            gr.outputs.Textbox(label="Caption generated by BLIP-base"),
+            gr.outputs.Textbox(label="Caption generated by BLIP-large"),
+            gr.outputs.Textbox(label="Caption generated by ViT+GPT-2")]
+
+
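+ # Note: gr.inputs / gr.outputs and enable_queue belong to the legacy Gradio 3.x-era API;
+ # newer Gradio releases use gr.Image / gr.Textbox and interface.queue() instead.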
+ interface = gr.Interface(fn=generate_captions,
+                          inputs=gr.inputs.Image(type="pil"),
+                          outputs=outputs,
+                          examples=examples,
+                          enable_queue=True)
+ interface.launch(debug=True)