3-page-app / app.py
NourFakih's picture
Create app.py
da03ea9 verified
raw
history blame
9.07 kB
import base64
import io
import tempfile
import time
from datetime import datetime

import cv2
import nltk
import pandas as pd
import spacy
import streamlit as st
import torch
from nltk.corpus import wordnet
from PIL import Image
from spacy.cli import download
from streamlit_option_menu import option_menu
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Streamlit re-executes this whole script on every widget interaction, so the
# expensive setup below lives in st.cache_resource functions: corpora are
# fetched and models are loaded once per server process instead of on every
# rerun. The module-level names used by the rest of the file are unchanged.


@st.cache_resource
def _load_nlp_resources():
    """Fetch the NLTK WordNet corpora and return the spaCy English pipeline."""
    # Download necessary NLTK data
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
    try:
        return spacy.load("en_core_web_sm")
    except OSError:
        # spaCy raises OSError when the model package is not installed;
        # fetch it once and retry instead of crashing at startup.
        download("en_core_web_sm")
        return spacy.load("en_core_web_sm")


@st.cache_resource
def _load_caption_model(name):
    """Load the ViT-GPT2 captioner and wrap the configured objects in a pipeline.

    Returns:
        (model, image_processor, tokenizer, image_to_text_pipeline)
    """
    caption_model = VisionEncoderDecoderModel.from_pretrained(name)
    processor = ViTImageProcessor.from_pretrained(name)
    caption_tokenizer = AutoTokenizer.from_pretrained(name)
    # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
    caption_tokenizer.pad_token = caption_tokenizer.eos_token
    # update the model config
    caption_model.config.eos_token_id = caption_tokenizer.eos_token_id
    caption_model.config.decoder_start_token_id = caption_tokenizer.bos_token_id
    caption_model.config.pad_token_id = caption_tokenizer.pad_token_id
    # Build the pipeline from the objects configured above. Passing the model
    # name instead would load a second copy that lacks these token settings.
    captioner = pipeline(
        'image-to-text',
        model=caption_model,
        tokenizer=caption_tokenizer,
        image_processor=processor,
    )
    return caption_model, processor, caption_tokenizer, captioner


@st.cache_resource
def _load_summarizer(name):
    """Load the T5 tokenizer/model and a summarization pipeline built from them.

    Returns:
        (tokenizer, model, summarization_pipeline)
    """
    sum_tokenizer = AutoTokenizer.from_pretrained(name)
    sum_model = AutoModelForSeq2SeqLM.from_pretrained(name)
    # Reuse the loaded objects rather than loading the checkpoint a second time.
    summarizer = pipeline("summarization", model=sum_model, tokenizer=sum_tokenizer)
    return sum_tokenizer, sum_model, summarizer


# Load the models
nlp = _load_nlp_resources()
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-115k-12"
model, feature_extractor, tokenizer, image_captioner = _load_caption_model(model_name)
model_sum_name = "google-t5/t5-base"
tokenizer_sum, model_sum, summarize_pipe = _load_summarizer(model_sum_name)

# Per-session list of (base64_jpeg, caption, capture_time) tuples.
if 'captured_images' not in st.session_state:
    st.session_state.captured_images = []
def generate_caption(image, captioner=None):
    """Return a single caption string for *image*.

    The image-to-text pipeline returns a list of ``{"generated_text": ...}``
    dicts. Every caller in this app uses the caption as text (f-string display,
    spaCy parsing in search, ``" ".join`` when summarizing), so the first
    generated text is extracted here instead of returning the raw list.

    Args:
        image: image passed to the captioning pipeline (a PIL image in this app).
        captioner: optional callable with the pipeline's call signature;
            defaults to the module-level ``image_captioner``.

    Returns:
        The caption with surrounding whitespace stripped, or ``""`` when the
        pipeline produced no output.
    """
    if captioner is None:
        captioner = image_captioner
    outputs = captioner(image)
    if not outputs:
        return ""
    return outputs[0]["generated_text"].strip()
def get_synonyms(word):
    """Return the set of lemma names from every WordNet synset of *word*.

    Lemma names are returned exactly as WordNet reports them; no case folding or
    other normalization is applied here.
    """
    return {
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
def preprocess_query(query):
    """Expand free text into a lowercase bag of search tokens.

    For every content token the result receives the token text, its spaCy
    lemma, and all WordNet synonyms of the text, all lowercased. Stop words,
    punctuation and whitespace tokens are skipped. Otherwise words such as
    "a" or "is" (and their many WordNet senses) appear in nearly every caption
    and make any query match any caption. Lowercasing makes the set
    intersection in search_captions case-insensitive.

    Args:
        query: free text (a user query or an image caption).

    Returns:
        set[str]: expanded tokens; empty when *query* contains no content words.
    """
    tokens = set()
    for token in nlp(query):
        if token.is_stop or token.is_punct or token.is_space:
            continue
        tokens.add(token.text.lower())
        tokens.add(token.lemma_.lower())
        tokens.update(synonym.lower() for synonym in get_synonyms(token.text))
    return tokens
def search_captions(query, captions):
    """Select the captured entries whose caption shares a token with *query*.

    Both the query and each caption are expanded with preprocess_query, so a hit
    may come from a lemma or a WordNet synonym rather than the literal word.

    Args:
        query: the search text.
        captions: iterable of (img_str, caption, capture_time) triples.

    Returns:
        list of (img_str, caption, capture_time) tuples, in input order.
    """
    wanted = preprocess_query(query)
    return [
        (img_str, caption, capture_time)
        for img_str, caption, capture_time in captions
        if not wanted.isdisjoint(preprocess_query(caption))
    ]
def add_image_to_state(image, caption, capture_time, max_images=20):
    """Append a base64 JPEG copy of *image* plus its caption to session state.

    Args:
        image: frame as an RGB array. The webcam page converts frames to RGB
            before captioning, while cv2.imencode expects BGR channel order. The
            frame is converted back here so stored JPEGs do not have red and blue
            swapped.
        caption: caption text for the frame.
        capture_time: timestamp string for the frame.
        max_images: maximum number of stored entries; further frames are ignored.

    Nothing is stored if the cap is already reached or JPEG encoding fails.
    """
    captured = st.session_state.captured_images
    if len(captured) >= max_images:
        # Check the cap first so frames that would be discarded are not encoded.
        return
    ok, encoded = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    if not ok:
        return
    img_str = base64.b64encode(encoded.tobytes()).decode()
    captured.append((img_str, caption, capture_time))
def page_image_captioning():
    """Render the Image Captioning page; for now it only shows the heading."""
    heading = "Image Captioning"
    st.title(heading)
    # TODO: add the image captioning UI here.
def page_video_captioning():
    """Render the Video Captioning page; for now it only shows the heading."""
    heading = "Video Captioning"
    st.title(heading)
    # TODO: add the video captioning UI here.
_XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"


def _release_camera():
    """Release the webcam handle stored in session state, if any, and forget it."""
    if 'camera' in st.session_state:
        st.session_state.camera.release()
        del st.session_state.camera


def _show_image_grid(entries, n_cols=4):
    """Render (img_str, caption, capture_time) entries as a grid of thumbnails."""
    cols = st.columns(n_cols)
    for idx, (img_str, caption, capture_time) in enumerate(entries):
        with cols[idx % n_cols]:
            # Decode in memory. The old temp-file version handed the int returned
            # by file.write() to Image.open and leaked one file per image.
            img = Image.open(io.BytesIO(base64.b64decode(img_str)))
            st.image(img, caption=f"{caption}\n\n*{capture_time}*", width=150)


def _excel_bytes(df):
    """Serialize *df* to .xlsx bytes in memory so no temporary files are left on disk."""
    buffer = io.BytesIO()
    df.to_excel(buffer, index=False)
    return buffer.getvalue()


def _summarize_captions(entries, batch_size=10):
    """Summarize captions in consecutive batches; returns (first_capture_time, summary) pairs."""
    summaries = []
    for start in range(0, len(entries), batch_size):
        batch = entries[start:start + batch_size]
        text = " ".join(caption for _, caption, _ in batch)
        summary = summarize_pipe(text)[0]['summary_text']
        # Use the capture time of the first image in the batch
        summaries.append((batch[0][2], summary))
    return summaries


def _capture_loop(frame_window, caption_slot, interval=0.5):
    """Read, show, caption and store webcam frames until a frame read fails.

    Opens the camera lazily. Apart from a failed read, the loop has no exit of
    its own; it is expected to be interrupted when Streamlit reruns the script
    after a widget interaction.
    """
    if 'camera' not in st.session_state:
        st.session_state.camera = cv2.VideoCapture(0)
    while True:
        ret, frame = st.session_state.camera.read()
        if not ret:
            st.write("Failed to capture image.")
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_window.image(frame)
        caption = generate_caption(Image.fromarray(frame))
        capture_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        add_image_to_state(frame, caption, capture_time)
        # Overwrite a single placeholder instead of appending one line per frame.
        caption_slot.write(f"Caption: {caption}")
        # cv2.waitKey only waits for keys when an OpenCV HighGUI window is active,
        # and a Streamlit app never creates one, so sleep to throttle instead.
        time.sleep(interval)


def page_webcam_capture():
    """Render the live webcam captioning page.

    Controls:
      * "Run" checkbox: capture, caption and store frames continuously.
      * "Stop" button: release the camera and skip capturing on this rerun.
      * Sidebar "Search": list stored captures whose caption matches the query.
      * Sidebar "Generate Report": release the camera, show all captures, and
        offer Excel downloads of the captions and of batched summaries.

    All widgets are created, and all clicks are handled, before deciding
    whether to enter the capture loop. The loop does not return on its own, so
    code placed after it never ran while "Run" was checked. That made Stop
    ineffective and hid the sidebar controls. A click on Stop, Search or
    Generate Report therefore pauses capture for that rerun so its result can
    be rendered.
    """
    st.title("Live Captioning with Webcam")
    run = st.checkbox('Run')
    stop = st.button('Stop')
    frame_window = st.image([])
    caption_slot = st.empty()

    st.sidebar.title("Search Captions")
    query = st.sidebar.text_input("Enter a word to search in captions:")
    search_clicked = st.sidebar.button("Search")
    report_clicked = st.sidebar.button("Generate Report")

    if stop or report_clicked:
        _release_camera()
    if stop:
        st.write("Camera stopped.")

    if run and not (stop or search_clicked or report_clicked):
        _capture_loop(frame_window, caption_slot)

    captured = st.session_state.captured_images

    # Display the collected data
    if captured:
        df = pd.DataFrame(captured, columns=['Image', 'Caption', 'Capture Time'])
        st.table(df[['Capture Time', 'Caption']])
    else:
        st.write("No images captured.")

    if search_clicked:
        results = search_captions(query, captured)
        if results:
            st.subheader("Search Results:")
            _show_image_grid(results)
        else:
            st.write("No matching captions found.")

    if report_clicked and captured:
        # Display captured images in a 4-column grid
        st.subheader("Captured Images and Captions:")
        _show_image_grid(captured)

        # Save captions to Excel and provide a download button. The base64 JPEG
        # column is left out: Excel limits a cell to 32,767 characters, which an
        # encoded webcam frame can exceed.
        df = pd.DataFrame(captured, columns=['Image', 'Caption', 'Capture Time'])
        st.sidebar.download_button(label="Download Captions as Excel",
                                   data=_excel_bytes(df[['Caption', 'Capture Time']]),
                                   file_name="camera_captions.xlsx",
                                   mime=_XLSX_MIME)

        # Summarize captions in groups of 10 and offer them as a second workbook.
        df_summary = pd.DataFrame(_summarize_captions(captured), columns=['Capture Time', 'Summary'])
        st.sidebar.download_button(label="Download Summary Report",
                                   data=_excel_bytes(df_summary),
                                   file_name="camera_summary_report.xlsx",
                                   mime=_XLSX_MIME)
def main():
    """Render the sidebar navigation menu and then the selected page.

    The last active page is kept in ``st.session_state.active_page``. When the
    selection changes, handle_page_switch is called first, which releases the
    webcam when leaving the webcam page.
    """
    st.session_state.active_page = st.session_state.get("active_page", "Image Captioning")
    pages = {
        "Image Captioning": page_image_captioning,
        "Video Captioning": page_video_captioning,
        "Webcam Captioning": page_webcam_capture,
    }
    # Sidebar for navigation
    with st.sidebar:
        selected = option_menu(
            menu_title="Main Menu",
            options=list(pages),
            # Bootstrap icon names are lowercase; "Caret-right-square" does not resolve.
            icons=["image", "caret-right-square", "camera"],
            menu_icon="cast",
            default_index=0,
        )
    if selected != st.session_state.active_page:
        handle_page_switch(selected)
    render_page = pages.get(selected)
    if render_page is not None:
        render_page()
def handle_page_switch(selected_page):
    """Make *selected_page* the active page, releasing the webcam if leaving the webcam page."""
    leaving_webcam = st.session_state.active_page == "Webcam Captioning"
    if leaving_webcam and "camera" in st.session_state:
        st.session_state.camera.release()
        del st.session_state.camera
    st.session_state.active_page = selected_page
if __name__ == "__main__":
    # Script entry point: render the app.
    main()