import os
import tempfile
import time

import cv2
import nltk
import pandas as pd
import streamlit as st
import torch
from nltk.corpus import wordnet
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    pipeline,
)
# Fetch the WordNet data that get_synonyms() queries.
nltk.download('wordnet')
nltk.download('omw-1.4')

# Captioning checkpoint together with its image processor and tokenizer.
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-85k-09"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Summarization model used to condense the per-frame captions of a video.
model_sum_name = "google-t5/t5-base"
tokenizer_sum = AutoTokenizer.from_pretrained(model_sum_name)
model_sum = AutoModelForSeq2SeqLM.from_pretrained(model_sum_name)
# Hand the already-loaded objects to the pipeline; passing the checkpoint
# name instead would load a second copy of the same weights.
summarize_pipe = pipeline("summarization", model=model_sum, tokenizer=tokenizer_sum)
def generate_caption(image):
    """Generate a text caption for a single image.

    Args:
        image: The image to describe. Callers in this file pass PIL images.

    Returns:
        The decoded caption string with special tokens removed.
    """
    # Files opened with PIL can be RGBA, palette or grayscale (common for
    # PNGs); convert so the captioning model always gets three-channel RGB.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def get_synonyms(word):
    """Collect the WordNet lemma names found across all synsets of a word.

    Args:
        word: The word to look up in WordNet.

    Returns:
        A set of lemma-name strings; empty when WordNet yields no synsets.
    """
    return {
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
def search_captions(query, captions):
    """Return the captions that share a word with the query or its synonyms.

    Matching is case-insensitive and ignores punctuation attached to a word,
    so a query of "TV" matches a caption containing "tv.".

    Args:
        query: Free-text search string; it is split on whitespace into words.
        captions: Mapping of media path to caption text.

    Returns:
        A list of ``(path, caption)`` tuples, in the mapping's iteration
        order, for every caption containing at least one search term. An
        empty or whitespace-only query yields an empty list.
    """
    punctuation = ".,;:!?\"'()[]"

    def _normalize(token):
        # Fall back to the raw token when it consists only of punctuation so
        # that such tokens still compare equal to themselves.
        stripped = token.strip(punctuation)
        return (stripped or token).lower()

    search_terms = set()
    for word in query.split():
        term = _normalize(word)
        search_terms.add(term)
        search_terms.update(synonym.lower() for synonym in get_synonyms(term))

    results = []
    for path, caption in captions.items():
        # Build the caption's word set once instead of re-splitting the
        # caption for every search term.
        caption_words = {_normalize(token) for token in caption.split()}
        if not search_terms.isdisjoint(caption_words):
            results.append((path, caption))
    return results
def convert_frame_to_pil(frame):
    """Turn an OpenCV frame in BGR channel order into an RGB PIL image."""
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_frame)
def process_video(video_path, frame_interval):
    """Sample frames from a video at a fixed interval and caption each one.

    Args:
        video_path: Path of the video file to open with OpenCV.
        frame_interval: Keep one frame out of every ``frame_interval``
            decoded frames, starting with the first. Must be non-zero.

    Returns:
        A tuple ``(frames, captions_df)``. ``frames`` is a list of
        ``(sample_index, frame)`` pairs with a zero-based sample index.
        ``captions_df`` has one row per sampled frame with the columns
        ``Frame_ID`` (one-based sample index) and ``Caption``. If the video
        cannot be opened, an error is shown in the Streamlit UI and
        ``([], pd.DataFrame())`` is returned.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            st.error("Error: Could not open video file.")
            return [], pd.DataFrame()

        frames = []
        count = 0
        # Read until the decoder reports that no frame is available. The
        # container's advertised frame count is deliberately not used as a
        # stop condition: it may be missing or inaccurate for some files,
        # which would cut sampling short.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                frames.append((len(frames), frame))
            count += 1
    finally:
        # Release the capture even if decoding raises.
        cap.release()

    captions_data = [
        {
            'Frame_ID': sample_index + 1,
            'Caption': generate_caption(convert_frame_to_pil(frame)),
        }
        for sample_index, frame in frames
    ]
    # Naming the columns keeps the schema stable when no frame was sampled.
    captions_df = pd.DataFrame(captions_data, columns=['Frame_ID', 'Caption'])
    return frames, captions_df
def image_captioning_page():
    """Render the image gallery page.

    Captions every PNG/JPG/JPEG file in a user-supplied folder, shows the
    images in a four-column grid, lists the images whose caption matches the
    sidebar search query, and can export all captions to ``captions.xlsx``.
    """
    st.title("Image Gallery with Captioning and Search")
    with st.sidebar:
        query = st.text_input("Search images by caption:")
    folder_path = st.text_input("Enter the folder path containing images:")
    if not (folder_path and os.path.isdir(folder_path)):
        return

    image_files = sorted(
        f for f in os.listdir(folder_path)
        if f.lower().endswith(('png', 'jpg', 'jpeg'))
    )

    # Streamlit re-executes this function on every widget interaction, so
    # captions are kept in session state and only generated for files that
    # are new or whose modification time has changed.
    if "image_caption_cache" not in st.session_state:
        st.session_state["image_caption_cache"] = {}
    caption_cache = st.session_state["image_caption_cache"]

    captions = {}
    for image_file in image_files:
        image_path = os.path.join(folder_path, image_file)
        try:
            cache_key = (image_path, os.path.getmtime(image_path))
            if cache_key not in caption_cache:
                # The context manager closes the file handle; converting to
                # RGB covers RGBA, palette and grayscale files.
                with Image.open(image_path) as image:
                    caption_cache[cache_key] = generate_caption(image.convert("RGB"))
        except OSError as err:
            # One unreadable or corrupt file should not take down the page.
            st.warning(f"Skipping {image_file}: {err}")
            continue
        captions[image_path] = caption_cache[cache_key]

    cols = st.columns(4)
    for idx, (image_path, caption) in enumerate(captions.items()):
        with cols[idx % 4]:
            st.image(image_path, caption=caption)

    if query:
        results = search_captions(query, captions)
        st.write("Search Results:")
        for image_path, caption in results:
            st.image(image_path, caption=caption)

    # The path input must live outside the button branch: a button is only
    # True for the single rerun triggered by its click, so a widget created
    # inside that branch disappears as soon as the user edits it.
    save_path = st.text_input("Enter the path to save the Excel file:", folder_path)
    if st.button("Save captions to excel"):
        if save_path:
            df = pd.DataFrame(list(captions.items()), columns=['Image', 'Caption'])
            os.makedirs(save_path, exist_ok=True)
            excel_file_path = os.path.join(save_path, "captions.xlsx")
            df.to_excel(excel_file_path, index=False)
            st.success(f"Captions saved to {excel_file_path}")
        else:
            st.warning("Please enter a path to save the Excel file.")
def live_camera_captioning_page():
    """Render the webcam page: live frames with a continuously updated caption.

    While the "Run" checkbox is ticked, frames are read from camera 0, shown
    in place, captioned, and the caption text is refreshed in place.
    """
    st.title("Live Captioning with Webcam")
    run = st.checkbox('Run')
    frame_window = st.image([])
    # Single placeholder so each new caption replaces the previous one
    # instead of appending a new line to the page on every iteration.
    caption_slot = st.empty()
    if not run:
        # Do not grab the camera device unless capturing was requested.
        return

    camera = cv2.VideoCapture(0)
    try:
        # The loop ends when a read fails or when Streamlit interrupts this
        # script run for a rerun (e.g. the checkbox is unticked).
        while True:
            ret, frame = camera.read()
            if not ret:
                st.write("Failed to capture image.")
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_window.image(frame)
            caption = generate_caption(Image.fromarray(frame))
            caption_slot.write(f"Caption: {caption}")
            # Plain sleep as the throttle: cv2.waitKey depends on an active
            # HighGUI window, which a Streamlit app does not create.
            time.sleep(0.5)
    finally:
        # Runs even when the script is interrupted, so the device is freed.
        camera.release()
def video_captioning_page():
    """Render the video page.

    Summarizes every MP4/AVI/MOV/MKV file in a user-supplied folder from the
    captions of its sampled frames, shows the videos in a four-column grid
    with their summaries, lists the videos whose summary matches the sidebar
    search query, and can export all summaries to ``captions.xlsx``.
    """
    st.title("Video Captioning")
    with st.sidebar:
        query = st.text_input("Search videos by caption:")
    folder_path = st.text_input("Enter the folder path containing videos:")
    if not (folder_path and os.path.isdir(folder_path)):
        return

    video_files = sorted(
        f for f in os.listdir(folder_path)
        if f.lower().endswith(('mp4', 'avi', 'mov', 'mkv'))
    )

    # Streamlit re-executes this function on every widget interaction, so
    # summaries are kept in session state and only computed for files that
    # are new or whose modification time has changed.
    if "video_summary_cache" not in st.session_state:
        st.session_state["video_summary_cache"] = {}
    summary_cache = st.session_state["video_summary_cache"]

    captions = {}
    for video_file in video_files:
        video_path = os.path.join(folder_path, video_file)
        try:
            cache_key = (video_path, os.path.getmtime(video_path))
        except OSError as err:
            st.warning(f"Skipping {video_file}: {err}")
            continue
        if cache_key not in summary_cache:
            frames, captions_df = process_video(video_path, frame_interval=20)
            if not frames or captions_df.empty:
                continue
            generated_captions = ' '.join(captions_df['Caption'])
            summary_cache[cache_key] = summarize_pipe(generated_captions)[0]['summary_text']
        captions[video_path] = summary_cache[cache_key]

    cols = st.columns(4)
    for idx, (video_path, summary) in enumerate(captions.items()):
        with cols[idx % 4]:
            # st.video takes no caption argument; show the summary beneath.
            st.video(video_path)
            st.caption(summary)

    if query:
        results = search_captions(query, captions)
        st.write("Search Results:")
        for video_path, summary in results:
            st.video(video_path)
            st.caption(summary)

    # The path input must live outside the button branch: a button is only
    # True for the single rerun triggered by its click, so a widget created
    # inside that branch disappears as soon as the user edits it.
    save_path = st.text_input("Enter the path to save the Excel file:", folder_path)
    if st.button("Save captions to excel"):
        if save_path:
            df = pd.DataFrame(list(captions.items()), columns=['Video', 'Caption'])
            os.makedirs(save_path, exist_ok=True)
            excel_file_path = os.path.join(save_path, "captions.xlsx")
            df.to_excel(excel_file_path, index=False)
            st.success(f"Captions saved to {excel_file_path}")
        else:
            st.warning("Please enter a path to save the Excel file.")
def main():
    """Show the navigation sidebar and render the page the user selected."""
    # Page title -> function that renders that page.
    pages = {
        "Image Captioning": image_captioning_page,
        "Live Camera Captioning": live_camera_captioning_page,
        "Video Captioning": video_captioning_page,
    }
    st.sidebar.title("Navigation")
    page = st.sidebar.selectbox("Select a page", list(pages))
    render_page = pages.get(page)
    if render_page is not None:
        render_page()


if __name__ == "__main__":
    main()