Sravanth committed
Commit 74cf048 · 1 Parent(s): d91fe5f

Update app.py

Files changed (1):
  1. app.py +40 -3
app.py CHANGED

@@ -2,6 +2,8 @@ import torch
 import re
 import gradio as gr
 from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
+from transformers import AutoProcessor, AutoTokenizer, BlipForConditionalGeneration
+from huggingface_hub import hf_hub_download
 
 device='cpu'
 encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
@@ -18,12 +20,47 @@ def predict(image,max_length=64, num_beams=4):
     clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
     caption_ids = model.generate(image, max_length = max_length)[0]
     caption_text = clean_text(tokenizer.decode(caption_ids))
-    return caption_text
+    caption_text2 = generate_captions(image)
+    return caption_text, caption_text2
+
+blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+blip_model_large.to(device)
+
+
+def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
+    inputs = processor(images=image, return_tensors="pt").to(device)
+
+    if use_float_16:
+        inputs = inputs.to(torch.float16)
+
+    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
+
+    if tokenizer is not None:
+        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    else:
+        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    return generated_caption
+
+
+def generate_captions(image):
+    caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)
+    return caption_blip_large
 
 
 input = gr.inputs.Image(label="Upload your Image", type = 'pil', optional=True)
-output = gr.outputs.Textbox(type="text",label="Captions")
+# Two output boxes
+output_1 = gr.outputs.Textbox(type="text",label="Caption - 1")
+output_2 = gr.outputs.Textbox(type="text",label="Caption - 2")
 examples = [f"example{i}.png" for i in range(1,4)]
 
 description= "Image caption Generator"
@@ -35,7 +72,7 @@ interface = gr.Interface(
     fn=predict,
     inputs = input,
     theme="grass",
-    outputs=output,
+    outputs = [output_1,output_2],
     examples = examples,
     title=title,
     description=description,
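
For reference, here is a minimal standalone sketch of the BLIP captioning path this commit adds. It reuses the Salesforce/blip-image-captioning-large checkpoint and the generate()/decode steps from the diff above; the test image example1.png is assumed to be one of the Space's bundled example files.

from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

device = "cpu"

# Same checkpoint the commit loads as blip_processor_large / blip_model_large
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

# example1.png is one of the Space's example inputs; any RGB image works
image = Image.open("example1.png").convert("RGB")

# Mirrors generate_caption(): preprocess, generate up to 50 tokens, decode
inputs = processor(images=image, return_tensors="pt").to(device)
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

Since predict() now returns a (caption_text, caption_text2) tuple, gr.Interface maps the two values onto outputs = [output_1,output_2] in order. Note that gr.inputs.Image and gr.outputs.Textbox are the legacy Gradio constructors; on current Gradio releases the equivalents are gr.Image and gr.Textbox.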