# image-analysis / app.py
# NOTE: Hugging Face page chrome ("raw / history blame / 4.97 kB" etc.) was
# pasted above the code; it was not valid Python and broke the file, so it
# has been reduced to this comment header.
# Standard library
import io

# Third-party
import streamlit as st
import torch
from PIL import Image
from transformers import pipeline
from transformers import DetrImageProcessor, DetrForObjectDetection
from transformers import CLIPProcessor, CLIPModel
from transformers import BlipProcessor, BlipForQuestionAnswering
#from transformers import YolosImageProcessor, YolosForObjectDetection

# Local
from functions import *
#load models
@st.cache_resource
def load_models():
    """Load and cache every model the app needs.

    Returns a dict keyed exactly as the rest of the app expects:
        "detector" / "processor"        DETR object-detection model + image processor
        "clip" / "clip process"         CLIP model + processor
        "story_teller"                  tiny text-generation pipeline
        "sales process" / "sales model" BLIP VQA processor + model (moved to `device`)
        "device"                        "cuda" if available, else "cpu"

    Decorated with @st.cache_resource so the slow downloads/loads happen only
    once per Streamlit server process.

    Note: requires `torch` to be imported at module level (it was previously
    only available implicitly via `from functions import *`).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # DETR object detector; the "no_timm" revision avoids the timm dependency.
    detector = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50", revision="no_timm"
    )
    detector_processor = DetrImageProcessor.from_pretrained(
        "facebook/detr-resnet-50", revision="no_timm"
    )

    # CLIP image/text model + matching processor.
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # BLIP visual question answering; half precision only when on GPU.
    vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    vqa_model = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    return {
        "detector": detector,
        "processor": detector_processor,
        "clip": clip_model,
        "clip process": clip_processor,
        "story_teller": pipeline("text-generation", model="nickypro/tinyllama-15M"),
        "sales process": vqa_processor,
        "sales model": vqa_model,
        "device": device,
    }
def main():
    """Render the Streamlit UI.

    Flow: upload an image, then either detect objects (DETR), caption it,
    generate a short story from the caption, or chat about it (BLIP VQA).
    Chat history lives in `st.session_state.messages` and survives reruns.
    """
    st.header("📱 Nano AI Image Analyzer")

    uploaded_file = st.file_uploader("upload image")
    models = load_models()
    st.write('models loaded')

    if uploaded_file is not None:
        st.write("Filename:", uploaded_file.name)
        # Normalize to RGB so downstream models never see RGBA/grayscale.
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image", width=200)

        # Three action buttons side by side.
        col1, col2, col3 = st.columns([0.33, 0.33, 0.33])
        with col1:
            detect = st.button("🔍 Detect Objects", key="btn1")
        with col2:
            describe = st.button("📝 Describe Image", key="btn2")
        with col3:
            story = st.button("📖 Generate Story", key="btn3",
                              help="story is generated based on caption")

        if detect:
            with st.spinner("Detecting objects..."):
                # Best-effort: detection failures show a message instead of a
                # traceback. `except Exception` (not bare `except:`) so
                # KeyboardInterrupt/SystemExit still propagate.
                try:
                    detections = detect_objects(image.copy(), models)
                    annotated_image = draw_bounding_boxes(image, detections)
                    # use_container_width replaces the deprecated use_column_width.
                    st.image(annotated_image, caption="Detected Objects",
                             use_container_width=True)
                    show_detection_table(detections)
                except Exception:
                    st.write("some error!! try another image")
        elif describe:
            with st.spinner("trying to describe..."):
                description = get_image_description(image.copy(), models)
                st.write(description)
        elif story:
            with st.spinner("getting a story..."):
                description = get_image_description(image.copy(), models)
                # Distinct name: don't shadow the `story` button flag above.
                generated_story = generate_story(description, models)
                st.write(generated_story)

        # --- Chat interface (BLIP VQA over the uploaded image) ---
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Replay the conversation so far inside a scrollable container.
        chat_container = st.container(height=400)
        with chat_container:
            for message in st.session_state.messages:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        if prompt := st.chat_input("Ask about the image"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = answer_question(image,
                                               prompt,
                                               models["sales process"],
                                               models["sales model"],
                                               models["device"])
                    st.markdown(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
# Run the app only when executed as a script (not when imported).
if __name__ == "__main__":
    main()