Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
import cv2 | |
import tempfile | |
import zipfile | |
from PIL import Image | |
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer, pipeline | |
import torch | |
import pandas as pd | |
from nltk.corpus import wordnet | |
import nltk | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from datetime import datetime | |
import base64 | |
import io | |
# Streamlit re-executes this whole script on every widget interaction, so
# expensive one-time work is wrapped in st.cache_resource when it exists.
# Older Streamlit versions without it fall back to no caching.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _ensure_nltk_data():
    """Download the WordNet corpora used by get_synonyms (quietly, once per process when cached)."""
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
    return True


@_cache_resource
def _load_models(caption_model_name, summary_model_name):
    """Load the captioning and summarization models.

    Args:
        caption_model_name: Hub id of the VisionEncoderDecoder captioning model.
        summary_model_name: Hub id of the seq2seq model used for summarization.

    Returns:
        Tuple ``(model, feature_extractor, tokenizer, tokenizer_sum, model_sum,
        summarize_pipe)``.
    """
    caption_model = VisionEncoderDecoderModel.from_pretrained(caption_model_name)
    caption_features = ViTFeatureExtractor.from_pretrained(caption_model_name)
    caption_tokenizer = AutoTokenizer.from_pretrained(caption_model_name)
    sum_tokenizer = AutoTokenizer.from_pretrained(summary_model_name)
    sum_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model_name)
    # Hand the already-loaded objects to the pipeline; passing only the name
    # makes pipeline() load a second copy of the same weights.
    sum_pipe = pipeline("summarization", model=sum_model, tokenizer=sum_tokenizer)
    return caption_model, caption_features, caption_tokenizer, sum_tokenizer, sum_model, sum_pipe


_ensure_nltk_data()

# Pre-trained model for image captioning and the model used for summaries.
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-85k-09"
model_sum_name = "google-t5/t5-base"
(model, feature_extractor, tokenizer,
 tokenizer_sum, model_sum, summarize_pipe) = _load_models(model_name, model_sum_name)

# NOTE(review): Streamlit re-runs this script on each interaction, so this
# module-level list is likely reset every rerun; st.session_state is the
# durable place for per-session data.
captured_images = []
def generate_caption(image):
    """Generate a text caption for one image with the ViT-GPT2 captioning model.

    Args:
        image: A PIL image. Images whose mode is not "RGB" (e.g. RGBA PNGs or
            greyscale files) are converted to RGB first, so the feature
            extractor always receives three-channel pixel data.

    Returns:
        The caption for the image, decoded with special tokens removed.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    # Only one image is captioned, so only the first generated sequence is used.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def get_synonyms(word):
    """Collect the lemma names of every WordNet synset that contains *word*.

    Names are returned exactly as WordNet reports them via ``lemma.name()``.
    The result is an empty set when WordNet has no synsets for *word*.
    """
    return {
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
def search_captions(query, captions):
    """Return the captions that mention a query word or one of its WordNet synonyms.

    Matching is case-insensitive and word-based. Each caption is lower-cased
    and split on whitespace, and leading/trailing punctuation is stripped from
    each word. A caption matches when any of those words equals a query word
    or a synonym returned by ``get_synonyms`` for a query word.

    Args:
        query: Free-text search string; split on whitespace into words.
        captions: Mapping of key (e.g. a file path) to caption text.

    Returns:
        A list of ``(key, caption)`` pairs for the matching entries, in the
        mapping's iteration order. A blank query matches nothing.
    """
    import string

    query_words = query.lower().split()
    wanted = set(query_words)
    for word in query_words:
        wanted.update(synonym.lower() for synonym in get_synonyms(word))

    results = []
    for key, caption in captions.items():
        # Tokenize each caption once, not once per candidate word.
        caption_words = {token.strip(string.punctuation) for token in caption.lower().split()}
        if not wanted.isdisjoint(caption_words):
            results.append((key, caption))
    return results
def _find_folder_images(folder_path):
    """Return ``(path, display_name)`` pairs for the png/jpg/jpeg files directly in *folder_path*."""
    image_exts = ('png', 'jpg', 'jpeg')
    pairs = []
    for name in os.listdir(folder_path):
        if name.lower().endswith(image_exts):
            full_path = os.path.join(folder_path, name)
            # For folder input the full path doubles as the display/export name.
            pairs.append((full_path, full_path))
    return pairs


def _save_uploaded_images(uploaded_files):
    """Write uploaded images, and images inside uploaded zips, to disk.

    Returns ``(path_on_disk, display_name)`` pairs. The display name is the
    uploaded file name, or the archive member name for zip contents.
    """
    image_exts = ('png', 'jpg', 'jpeg')
    pairs = []
    for uploaded_file in uploaded_files or []:
        name = uploaded_file.name
        lowered = name.lower()
        if lowered.endswith('.zip'):
            with zipfile.ZipFile(uploaded_file, 'r') as zip_ref:
                zip_ref.extractall("uploaded_images")
                for member in zip_ref.namelist():
                    if member.lower().endswith(image_exts):
                        pairs.append((os.path.join("uploaded_images", member), member))
        elif lowered.endswith(image_exts):
            # The context manager flushes and closes the file before PIL
            # reopens it by name; an open, unflushed handle can leave it empty.
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(name)[1]) as tmp:
                tmp.write(uploaded_file.getvalue())
            pairs.append((tmp.name, name))
    return pairs


def _render_image_grid(captions):
    """Show each ``path -> caption`` entry as an inline image with its caption in a 4-column grid."""
    import html

    cols = st.columns(4)
    for idx, (image_path, caption) in enumerate(captions.items()):
        with cols[idx % 4]:
            try:
                with open(image_path, "rb") as img_file:
                    encoded_image = base64.b64encode(img_file.read()).decode()
            except OSError as e:
                st.error(f"Error displaying {image_path}: {e}")
                continue
            mime = "image/png" if image_path.lower().endswith(".png") else "image/jpeg"
            # Escape the caption: it is injected into raw HTML.
            markup = (
                "<div style='text-align: center;'>"
                f"<img src='data:{mime};base64,{encoded_image}' width='100%'>"
                f"<p>{html.escape(caption)}</p>"
                "</div>"
            )
            st.markdown(markup, unsafe_allow_html=True)


def image_captioning_page():
    """Render the image gallery page: caption a set of images, then browse, search and export them.

    Images come from a local folder or from uploads (single images or zip
    archives). Captions are stored in ``st.session_state`` so they survive
    the reruns Streamlit performs when another widget, such as the search
    box, changes.
    """
    st.title("Image Gallery with Captioning and Search")

    # Sidebar for search functionality
    with st.sidebar:
        query = st.text_input("Search images by caption:")

    option = st.selectbox("Select input method:", ["Folder Path", "Upload Images"])
    if option == "Folder Path":
        folder_path = st.text_input("Enter the folder path containing images:")
        uploaded_files = None
    else:
        folder_path = None
        uploaded_files = st.file_uploader("Upload images or a zip file containing images:",
                                          type=['png', 'jpg', 'jpeg', 'zip'],
                                          accept_multiple_files=True)

    if st.button("Generate Captions"):
        # Files are only written to disk on an explicit request, not on every rerun.
        if option == "Folder Path":
            if folder_path and os.path.isdir(folder_path):
                entries = _find_folder_images(folder_path)
            else:
                entries = []
                if folder_path:
                    st.warning(f"Not a directory: {folder_path}")
        else:
            entries = _save_uploaded_images(uploaded_files)

        new_captions, new_names = {}, {}
        for image_path, display_name in entries:
            try:
                with Image.open(image_path) as image:
                    caption = generate_caption(image)
            except Exception as e:  # one unreadable image must not abort the batch
                st.error(f"Error processing {image_path}: {e}")
                continue
            # Keyed by the real on-disk path so display and search can reopen it.
            new_captions[image_path] = caption
            new_names[image_path] = display_name
        st.session_state["image_captions"] = new_captions
        st.session_state["image_display_names"] = new_names

    captions = st.session_state.get("image_captions", {})
    names = st.session_state.get("image_display_names", {})

    # Display images in a 4-column grid
    st.subheader("Images and Captions:")
    _render_image_grid(captions)

    if query:
        results = search_captions(query, captions)
        st.write("Search Results:")
        for image_path, caption in results:
            try:
                with open(image_path, "rb") as img_file:
                    img_bytes = img_file.read()
            except OSError as e:
                st.error(f"Error displaying search result {image_path}: {e}")
                continue
            st.image(img_bytes, caption=caption, width=150)
            st.write(caption)

    # Save captions to Excel (under their human-readable names) and offer a download.
    rows = [(names.get(path, path), caption) for path, caption in captions.items()]
    df = pd.DataFrame(rows, columns=['Image', 'Caption'])
    excel_file = io.BytesIO()
    df.to_excel(excel_file, index=False)
    excel_file.seek(0)
    st.download_button(label="Download captions as Excel",
                       data=excel_file,
                       file_name="captions.xlsx",
                       mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
def _dataframe_to_xlsx_bytes(df):
    """Serialize *df* (without its index) to an in-memory .xlsx file and return the bytes."""
    buffer = io.BytesIO()
    df.to_excel(buffer, index=False)
    return buffer.getvalue()


def _run_camera_loop(frame_window, captures, interval_seconds=0.5):
    """Stream webcam frames into *frame_window*, captioning and recording each one.

    The loop only breaks itself when a frame cannot be read. Otherwise it
    presumably keeps running until Streamlit interrupts the script run on the
    next widget interaction.

    Args:
        frame_window: Streamlit element that shows the latest frame.
        captures: List that receives ``(rgb_frame, caption, capture_time)`` tuples.
        interval_seconds: Pause between consecutive captures.
    """
    import time

    if 'camera' not in st.session_state:
        st.session_state.camera = cv2.VideoCapture(0)
    camera = st.session_state.camera
    caption_slot = st.empty()  # overwritten per frame instead of appending lines forever
    while True:
        ret, frame = camera.read()
        if not ret:
            st.write("Failed to capture image.")
            # Drop the broken handle so the next run tries to reopen the camera.
            camera.release()
            del st.session_state.camera
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_window.image(frame)
        caption = generate_caption(Image.fromarray(frame))
        capture_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # NOTE(review): full-resolution frames accumulate in session state;
        # consider storing thumbnails if sessions run for long.
        captures.append((frame, caption, capture_time))
        caption_slot.write("Caption: ", caption)
        # cv2.waitKey() only waits while a HighGUI window exists, which this
        # app never creates, so sleep explicitly.
        time.sleep(interval_seconds)


def _render_capture_search(captures):
    """Sidebar search over the captured frames' captions, showing matches in a 4-column grid."""
    st.sidebar.title("Search Captions")
    query = st.sidebar.text_input("Enter a word to search in captions:")
    if not st.sidebar.button("Search"):
        return
    # search_captions expects a {key: caption} mapping; index the captures by position.
    indexed = {idx: caption for idx, (_, caption, _) in enumerate(captures)}
    matches = [captures[idx] for idx, _ in search_captions(query, indexed)]
    if not matches:
        st.write("No matching captions found.")
        return
    st.subheader("Search Results:")
    cols = st.columns(4)
    for idx, (image, caption, capture_time) in enumerate(matches):
        with cols[idx % 4]:
            st.image(image, caption=f"{caption}\n\n*{capture_time}*", width=150)


def _render_capture_report(captures, batch_size=10):
    """On request, show all captures and offer caption and batch-summary Excel downloads."""
    if not st.button("Generate Report"):
        return
    if not captures:
        st.write("No images captured yet.")
        return

    # Display captured images in a 4-column grid
    st.subheader("Captured Images and Captions:")
    cols = st.columns(4)
    for idx, (image, caption, capture_time) in enumerate(captures):
        with cols[idx % 4]:
            st.image(image, caption=f"{caption}\n\n*{capture_time}*", width=150)

    # Frames are numpy arrays and cannot be written into Excel cells, so the
    # Image column holds a frame label instead.
    rows = [(f"frame_{idx}", caption, capture_time)
            for idx, (_, caption, capture_time) in enumerate(captures)]
    df = pd.DataFrame(rows, columns=['Image', 'Caption', 'Capture Time'])
    st.download_button(label="Download Captions as Excel",
                       data=_dataframe_to_xlsx_bytes(df),
                       file_name="camera_captions.xlsx",
                       mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")

    # Summarize captions in consecutive batches, stamped with each batch's first capture time.
    summaries = []
    for start in range(0, len(captures), batch_size):
        batch = captures[start:start + batch_size]
        batch_text = " ".join(caption for _, caption, _ in batch)
        summary = summarize_pipe(batch_text, truncation=True)[0]['summary_text']
        summaries.append((batch[0][2], summary))
    df_summary = pd.DataFrame(summaries, columns=['Capture Time', 'Summary'])
    st.download_button(label="Download Summary Report",
                       data=_dataframe_to_xlsx_bytes(df_summary),
                       file_name="camera_summary_report.xlsx",
                       mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")


def live_camera_captioning_page():
    """Render the live webcam page.

    Frames are captioned while "Run" is ticked. Afterwards the collected
    captions can be searched or exported as a report. Captures are kept in
    ``st.session_state["captured_images"]`` so they persist across Streamlit
    reruns, such as the one triggered by clicking "Generate Report".
    """
    st.title("Live Captioning with Webcam")
    if "captured_images" not in st.session_state:
        st.session_state["captured_images"] = []
    captures = st.session_state["captured_images"]

    run = st.checkbox('Run')
    frame_window = st.image([])
    if run:
        _run_camera_loop(frame_window, captures)
    elif 'camera' in st.session_state:
        # Free the device as soon as capturing is switched off.
        st.session_state.camera.release()
        del st.session_state.camera

    _render_capture_search(captures)
    _render_capture_report(captures)
def _process_video(video_path, frame_interval=20):
    """Caption every ``frame_interval``-th frame of the video at *video_path*.

    Args:
        video_path: Path of a video file readable by OpenCV.
        frame_interval: Sample one frame out of this many (values below 1 are treated as 1).

    Returns:
        ``(frame_indices, captions_df)``. ``frame_indices`` lists the sampled
        frame numbers that were captioned. ``captions_df`` has ``Frame`` and
        ``Caption`` columns, one row per captioned frame.
    """
    frame_interval = max(1, int(frame_interval))
    capture = cv2.VideoCapture(video_path)
    frame_indices, rows = [], []
    try:
        index = 0
        # grab() advances without decoding; only sampled frames are retrieved.
        while capture.grab():
            if index % frame_interval == 0:
                ok, frame = capture.retrieve()
                if ok:
                    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    rows.append((index, generate_caption(Image.fromarray(rgb))))
                    frame_indices.append(index)
            index += 1
    finally:
        capture.release()
    return frame_indices, pd.DataFrame(rows, columns=["Frame", "Caption"])


def video_captioning_page():
    """Render the video page: summarize each video in a folder, then display, search and export.

    Each video's sampled-frame captions are joined and summarized. Summaries
    are cached in ``st.session_state`` per (path, modification time), so
    reruns (e.g. typing a search query) do not re-process every video.
    """
    st.title("Video Captioning")

    # Sidebar for search functionality
    with st.sidebar:
        query = st.text_input("Search videos by caption:")

    folder_path = st.text_input("Enter the folder path containing videos:")
    if not (folder_path and os.path.isdir(folder_path)):
        return

    if "video_summaries" not in st.session_state:
        st.session_state["video_summaries"] = {}
    cache = st.session_state["video_summaries"]

    video_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('mp4', 'avi', 'mov', 'mkv'))]
    captions = {}
    for video_file in video_files:
        video_path = os.path.join(folder_path, video_file)
        cache_key = (video_path, os.path.getmtime(video_path))
        if cache_key not in cache:
            try:
                frames, captions_df = _process_video(video_path, frame_interval=20)
                summary = None
                if frames and not captions_df.empty:
                    generated_captions = ' '.join(captions_df['Caption'])
                    # Truncate long caption text to the summarizer's input limit.
                    summary = summarize_pipe(generated_captions, truncation=True)[0]['summary_text']
            except Exception as e:  # one bad video must not abort the whole folder
                st.error(f"Error processing {video_path}: {e}")
                continue
            cache[cache_key] = summary
        if cache[cache_key] is not None:
            captions[video_path] = cache[cache_key]

    # Display videos in a 4-column grid; st.video() takes no caption argument,
    # so the summary is shown underneath it.
    cols = st.columns(4)
    for idx, (video_path, summary) in enumerate(captions.items()):
        with cols[idx % 4]:
            st.video(video_path)
            st.caption(summary)

    if query:
        results = search_captions(query, captions)
        st.write("Search Results:")
        for video_path, summary in results:
            st.video(video_path)
            st.caption(summary)

    # Save captions to CSV and provide a download button
    if st.button("Generate CSV"):
        df = pd.DataFrame(list(captions.items()), columns=['Video', 'Caption'])
        csv = df.to_csv(index=False)
        st.download_button(label="Download captions as CSV",
                           data=csv,
                           file_name="captions.csv",
                           mime="text/csv")
def main():
    """Entry point: pick a page from the sidebar navigation box and render it."""
    page_renderers = {
        "Image Captioning": image_captioning_page,
        "Live Camera Captioning": live_camera_captioning_page,
        "Video Captioning": video_captioning_page,
    }
    st.sidebar.title("Navigation")
    # The select box lists the pages in the dict's insertion order.
    selected = st.sidebar.selectbox("Select a page", list(page_renderers))
    render = page_renderers[selected]
    render()


if __name__ == "__main__":
    main()