kartikay24 commited on
Commit
4dca880
·
1 Parent(s): 56a5c69

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import requests
3
+ from PIL import Image
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ import gradio as gr
6
+
7
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
8
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16).to("cuda")
9
+
10
+ # Function to process the image and generate captions
11
+ def generate_caption(image, caption_type, text):
12
+ raw_image = Image.fromarray(image.astype('uint8'), 'RGB')
13
+
14
+ if caption_type == "Conditional":
15
+ caption = conditional_image_captioning(raw_image, text)
16
+ else:
17
+ caption = unconditional_image_captioning(raw_image)
18
+
19
+ return caption
20
+
21
+ # Conditional image captioning
22
+ def conditional_image_captioning(raw_image, text):
23
+ inputs = processor(raw_image, text, return_tensors="pt").to("cuda", torch.float16)
24
+ out = model.generate(**inputs)
25
+ caption = processor.decode(out[0], skip_special_tokens=True)
26
+ return caption
27
+
28
+ # Unconditional image captioning
29
+ def unconditional_image_captioning(raw_image):
30
+ inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
31
+ out = model.generate(**inputs)
32
+ caption = processor.decode(out[0], skip_special_tokens=True)
33
+ return caption
34
+
35
+ # Interface setup
36
+ input_image = gr.inputs.Image()
37
+ input_text = gr.inputs.Textbox(label="Enter Text (for Conditional Captioning)")
38
+
39
+ choices = ["Conditional", "Unconditional"]
40
+ radio_button = gr.inputs.Radio(choices, label="Captioning Type")
41
+
42
+ output_text = gr.outputs.Textbox(label="Caption")
43
+
44
+ # Create the interface
45
+ gr.Interface(fn=generate_caption, inputs=[input_image, radio_button, input_text], outputs=output_text, title="Image Captioning",debug=True).launch()