|
import os

import streamlit as st
import torch
from PIL import Image
from transformers import pipeline
|
|
|
|
|
|
|
@st.cache_resource
def _load_caption_pipeline():
    """Build the BLIP image-captioning pipeline.

    Streamlit re-executes this script from top to bottom on every widget
    interaction; ``st.cache_resource`` keeps the loaded pipeline alive so
    the model weights are only loaded once instead of on every rerun.

    Returns:
        The ``image-to-text`` pipeline backed by
        ``Salesforce/blip-image-captioning-large``.
    """
    return pipeline('image-to-text', model="Salesforce/blip-image-captioning-large")


initial_caption_pipe = _load_caption_pipeline()
|
|
|
|
|
# --- Step 1: receive the image and produce a first-pass caption ---------------
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if uploaded_image is None:
    # Nothing to caption yet: end this script run here so the code further
    # down never executes with `image` / `initial_caption` undefined.
    st.stop()

# Open the upload once and use the same PIL image for preview and captioning.
image = Image.open(uploaded_image)
st.image(image, caption="Uploaded Image", use_column_width=True)

# The pipeline returns a list of dicts; the caption text is stored under the
# 'generated_text' key of the first entry.
initial_caption = initial_caption_pipe(image)[0]['generated_text']
|
|
|
|
|
# --- Step 2: zero-shot scene classification with CLIP -------------------------
from transformers import CLIPProcessor, CLIPModel

model_id = "openai/clip-vit-large-patch14"


@st.cache_resource
def _load_clip(clip_model_id):
    """Load the CLIP processor and model for ``clip_model_id``.

    Cached with ``st.cache_resource`` so the checkpoint is loaded once rather
    than on every Streamlit rerun.

    Args:
        clip_model_id: Hugging Face model identifier of the CLIP checkpoint.

    Returns:
        A ``(processor, model)`` tuple.
    """
    clip_processor = CLIPProcessor.from_pretrained(clip_model_id)
    clip_model = CLIPModel.from_pretrained(clip_model_id)
    return clip_processor, clip_model


processor, model = _load_clip(model_id)

# Candidate scene descriptions scored against the image.
# NOTE(review): 'burglary ' carries a trailing space in the original list; it
# is kept unchanged here because the label text is fed to the model and into
# the prompt context. Confirm whether it should be trimmed.
scene_labels = [
    'Arrest',
    'Arson',
    'Explosion',
    'public fight',
    'Normal',
    'Road Accident',
    'Robbery',
    'Shooting',
    'Stealing',
    'Vandalism',
    'Suspicious activity',
    'Tailgating',
    'Unauthorized entry',
    'Protest/Demonstration',
    'Drone suspicious activity',
    'Fire/Smoke detection',
    'Medical emergency',
    'Suspicious package/object',
    'Threatening',
    'Attack',
    'Shoplifting',
    'burglary ',
    'distress',
    'assault',
]

# `image` is the PIL image opened from the upload earlier in the script, so
# the upload does not need to be opened again here.
inputs = processor(text=scene_labels, images=image, return_tensors="pt", padding=True)

# Pure inference: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
# Softmax across the label dimension turns the scores into a distribution
# over `scene_labels`.
probs = logits_per_image.softmax(dim=1)

# Convert the single-element argmax tensor to a plain int before using it as
# a list index.
best_label_index = probs.argmax(dim=-1).item()
context_raw = scene_labels[best_label_index]
context = 'the image is depicting scene of ' + context_raw
|
|
|
|
|
# --- Step 3: refine the caption with Gemini -----------------------------------
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory

GOOGLE_API_KEY = st.text_input("Please enter your GOOGLE GEMINI API KEY", type="password")

# Prompt sent to the LLM; {initial_caption} and {context} are filled in from
# the values computed earlier in the script.
template = """You are an advanced image captioning AI assistant for surveillance related images.
Your task is to refine and improve an initial image caption using relevant contextual information provided.
You will receive two inputs:
Input 1: {initial_caption} - This is the initial caption for the image, most likely grammatically incorrect
and incomplete sentence, generated by a separate not so good image captioning model.
Input 2: {context} - This is the contextual information that provides more details about the background
Your goal is to take the initial caption and the additional context, and produce a new, refined caption that
incorporates the contextual details.
Please do not speculate things which are not provided. The final caption should be grammatically correct.
Please output only the final caption."""

prompt_template = PromptTemplate(
    template=template,
    input_variables=["initial_caption", "context"],
)

# Only contact the Gemini API when the user explicitly asks for a caption.
# Streamlit reruns the whole script on every interaction, so an ungated call
# would fire on each rerun, including before any key has been typed.
if st.button("Generate Caption"):
    if not GOOGLE_API_KEY:
        st.warning("Please enter your Google Gemini API key to generate a caption.")
    else:
        os.environ['GOOGLE_API_KEY'] = GOOGLE_API_KEY

        llm = ChatGoogleGenerativeAI(
            model="gemini-1.0-pro-latest",
            google_api_key=GOOGLE_API_KEY,
            temperature=0.2,
            top_p=1,
            top_k=1,
            # NOTE(review): every safety filter is set to BLOCK_NONE,
            # presumably because captions of violent or criminal scenes would
            # otherwise be rejected. Confirm this is intended.
            safety_settings={
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            },
        )

        prompt = prompt_template.format(initial_caption=initial_caption, context=context)
        response = llm.invoke(prompt)
        final_caption = response.content

        st.write(final_caption)
|
|