ShreyMehra committed on
Commit 702d337 · unverified · 1 Parent(s): c7ccceb

Add files via upload

Files changed (1)
  1. app2.py +73 -0
app2.py ADDED
@@ -0,0 +1,73 @@
+ import streamlit as st
+ import random
+ import requests
+ import io
+ from PIL import Image
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
+ from peft import PeftModel, PeftConfig
+ import torch
+
+
+ # Module-level handles populated by Model.load_model()
+ model = None
+ processor = None
+
+ st.title("Image Captioner - Caption the images")
+ st.markdown("Link to the model - [Image-to-Caption-App on 🤗 Spaces](https://huggingface.co/spaces/Shrey23/Image-Captioning)")
+
+
+ class UI:
+     def __init__(self):
+         # Loading the model fills in the global `model` and `processor`.
+         loader = Model()
+         loader.load_model()
+
+     def displayUI(self):
+         image = st.file_uploader(label="Upload your image here", type=['png', 'jpg', 'jpeg'])
+         if image is not None:
+             input_image = Image.open(image)  # read image
+             st.image(input_image)  # display image
+
+             with st.spinner("🤖 AI is at Work! "):
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+                 # Pass the decoded PIL image (not the raw upload object) to the processor.
+                 inputs = processor(images=input_image, return_tensors="pt").to(device, torch.float16)
+                 pixel_values = inputs.pixel_values
+
+                 generated_ids = model.generate(pixel_values=pixel_values, max_length=25)
+                 generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+                 st.write(generated_caption)
+
+             st.success("Here you go!")
+             st.balloons()
+         else:
+             st.write("Upload an Image")
+
+         st.caption("Made with ❤️ by @1littlecoder. Credits to 🤗 Spaces for Hosting this ")
+
+
+ class Model:
+     def load_model(self):
+         peft_model_id = "Shrey23/Image-Captioning"
+         config = PeftConfig.from_pretrained(peft_model_id)
+         global model
+         global processor
+         # Model runs in float16, so a GPU is recommended.
+         model = Blip2ForConditionalGeneration.from_pretrained(config.base_model_name_or_path, torch_dtype=torch.float16)  # , device_map="auto", load_in_8bit=True
+         model = PeftModel.from_pretrained(model, peft_model_id)
+         # Move the model to the same device the inputs will be sent to.
+         model.to("cuda" if torch.cuda.is_available() else "cpu")
+         processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
+
+     # Helpers for querying a hosted inference endpoint; unused in this app and
+     # they expect `self.API_URL` and `self.headers` to be set before calling.
+     def query(self, payload):
+         response = requests.post(self.API_URL, headers=self.headers, json=payload)
+         return response.content
+
+     def generate_response(self, prompt):
+         image_bytes = self.query({"inputs": prompt})
+         return io.BytesIO(image_bytes)
+
+
+ def main():
+     ui = UI()
+     ui.displayUI()
+
+ if __name__ == "__main__":
+     main()
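
Usage note: assuming streamlit, transformers, peft, torch, and Pillow are installed, the app can be started locally with Streamlit's CLI (on a Space it is launched the same way by the runtime):

    streamlit run app2.py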