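# Streamlit app: live image captioning with a ViT-GPT2 model, caption search via
# spaCy lemmas and WordNet synonyms, and T5-based summary reports exported to Excel.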
import streamlit as st
import cv2
import pandas as pd
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
import nltk
import tempfile
import io
from nltk.corpus import wordnet
import spacy
from spacy.cli import download
import base64
import numpy as np
import datetime
from streamlit_option_menu import option_menu
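# Assumed dependencies (not pinned here): streamlit, streamlit-option-menu, opencv-python,
# pandas, pillow, transformers, torch, nltk, spacy, numpy, and openpyxl (needed by
# pandas.DataFrame.to_excel for the .xlsx reports below).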
# Download necessary NLP resources (WordNet corpora and the spaCy English model)
nltk.download('wordnet')
nltk.download('omw-1.4')
download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
# Load the pre-trained models for image captioning and summarization
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-85k-09"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPT-2 only has bos/eos tokens but no decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# Update the model config to match the tokenizer
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

model_sum_name = "google-t5/t5-base"
tokenizer_sum = AutoTokenizer.from_pretrained(model_sum_name)
model_sum = AutoModelForSeq2SeqLM.from_pretrained(model_sum_name)
# Reuse the already-loaded T5 model and tokenizer instead of downloading them a second time
summarize_pipe = pipeline("summarization", model=model_sum, tokenizer=tokenizer_sum)
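# Note: Streamlit re-runs this script on every interaction, so the model loads above are
# repeated unless cached. A minimal sketch, assuming Streamlit >= 1.18 with st.cache_resource:
#
#     @st.cache_resource
#     def load_caption_model(name=model_name):
#         m = VisionEncoderDecoderModel.from_pretrained(name)
#         fe = ViTImageProcessor.from_pretrained(name)
#         tok = AutoTokenizer.from_pretrained(name)
#         return m, fe, tok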
if 'captured_images' not in st.session_state:
    st.session_state.captured_images = []
def generate_caption(image):
    # Encode the image and autoregressively generate a caption with the ViT-GPT2 model
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption
def get_synonyms(word):
    # Collect the lemma names from every WordNet synset of the word
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonyms.add(lemma.name())
    return synonyms
def preprocess_query(query):
    # Expand the query into surface forms, lemmas, and WordNet synonyms
    doc = nlp(query)
    tokens = set()
    for token in doc:
        tokens.add(token.text)
        tokens.add(token.lemma_)
        tokens.update(get_synonyms(token.text))
    return tokens
def search_captions(query, captions):
    # A caption matches when its expanded token set intersects the expanded query tokens
    query_tokens = preprocess_query(query)
    results = []
    for img_str, caption, capture_time in captions:
        caption_tokens = preprocess_query(caption)
        if query_tokens & caption_tokens:
            results.append((img_str, caption, capture_time))
    return results
def add_image_to_state(image, caption, capture_time):
    # The array comes from PIL in RGB order, but OpenCV encodes assuming BGR,
    # so convert before storing the JPEG as base64
    img_str = base64.b64encode(cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))[1]).decode()
    # Keep at most 20 captures in the session to bound memory use
    if len(st.session_state.captured_images) < 20:
        st.session_state.captured_images.append((img_str, caption, capture_time))
def page_image_captioning():
    st.title("Image Captioning")
    st.write("Your image captioning code here")


def page_video_captioning():
    st.title("Video Captioning")
    st.write("Your video captioning code here")
def page_webcam_capture():
    st.title("Live Captioning with Webcam")
    img_file = st.camera_input("Capture an image")
    if img_file:
        img = Image.open(img_file)
        img_array = np.array(img)
        caption = generate_caption(img)
        capture_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        add_image_to_state(img_array, caption, capture_time)
        st.image(img, caption=f"Caption: {caption}")

    if st.button('Stop'):
        st.write("Camera stopped.")
        if st.session_state.captured_images:
            df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
            st.table(df[['Capture Time', 'Caption']])
        else:
            st.write("No images captured.")

    st.sidebar.title("Search Captions")
    query = st.sidebar.text_input("Enter a word to search in captions:")
    if st.sidebar.button("Search"):
        results = search_captions(query, st.session_state.captured_images)
        if results:
            st.subheader("Search Results:")
            cols = st.columns(4)
            for idx, (img_str, caption, capture_time) in enumerate(results):
                col = cols[idx % 4]
                with col:
                    # Decode the stored base64 JPEG directly into a PIL image
                    img_data = base64.b64decode(img_str)
                    img = Image.open(io.BytesIO(img_data))
                    st.image(img, caption=f"{caption}\n\n*{capture_time}*", width=150)
        else:
            st.write("No matching captions found.")

    if st.sidebar.button("Generate Report"):
        if st.session_state.captured_images:
            st.subheader("Captured Images and Captions:")
            cols = st.columns(4)
            for idx, (img_str, caption, capture_time) in enumerate(st.session_state.captured_images):
                col = cols[idx % 4]
                with col:
                    img_data = base64.b64decode(img_str)
                    img = Image.open(io.BytesIO(img_data))
                    st.image(img, caption=f"{caption}\n\n*{capture_time}*", width=150)

            df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
            df['Image'] = df['Image'].apply(lambda x: f'<img src="data:image/jpeg;base64,{x}"/>')
            # Writing .xlsx files with pandas requires openpyxl to be installed
            excel_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
            df.to_excel(excel_file.name, index=False)
            st.sidebar.download_button(label="Download Captions as Excel",
                                       data=open(excel_file.name, 'rb').read(),
                                       file_name="camera_captions.xlsx",
                                       mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")

            # Summarize the captions in batches of 10, keyed by the first capture time of each batch
            summaries = []
            for i in range(0, len(st.session_state.captured_images), 10):
                batch_captions = " ".join([st.session_state.captured_images[j][1]
                                           for j in range(i, min(i + 10, len(st.session_state.captured_images)))])
                summary = summarize_pipe(batch_captions)[0]['summary_text']
                summaries.append((st.session_state.captured_images[i][2], summary))
            df_summary = pd.DataFrame(summaries, columns=['Capture Time', 'Summary'])
            summary_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
            df_summary.to_excel(summary_file.name, index=False)
            st.sidebar.download_button(label="Download Summary Report",
                                       data=open(summary_file.name, 'rb').read(),
                                       file_name="camera_summary_report.xlsx",
                                       mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
def main():
    st.session_state.active_page = st.session_state.get("active_page", "Image Captioning")

    with st.sidebar:
        selected = option_menu(
            menu_title="Main Menu",
            options=["Image Captioning", "Video Captioning", "Webcam Captioning"],
            icons=["image", "caret-right-square", "camera"],
            menu_icon="cast",
            default_index=0,
        )

    if selected != st.session_state.active_page:
        handle_page_switch(selected)

    if selected == "Image Captioning":
        page_image_captioning()
    elif selected == "Video Captioning":
        page_video_captioning()
    elif selected == "Webcam Captioning":
        page_webcam_capture()


def handle_page_switch(selected_page):
    st.session_state.active_page = selected_page


if __name__ == "__main__":
    main()
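# To run locally (assuming this file is saved as app.py):
#     streamlit run app.py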