NourFakih committed
Commit da03ea9 · verified · 1 Parent(s): 5af131e

Create app.py

Files changed (1)
  app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
+ import streamlit as st
+ import cv2
+ from PIL import Image
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
+ from datetime import datetime
+ import pandas as pd
+ import tempfile
+ import io
+ import time
+ from nltk.corpus import wordnet
+ import nltk
+ import base64
+ import spacy
+ from spacy.cli import download
+ from streamlit_option_menu import option_menu
+ import torch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Download the necessary NLTK data
+ nltk.download('wordnet')
+ nltk.download('omw-1.4')
+ # download("en_core_web_sm")  # uncomment if the spaCy model is not installed
+
+ # Load the captioning model
+ nlp = spacy.load("en_core_web_sm")
+ model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-115k-12"
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
+ feature_extractor = ViTImageProcessor.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ # GPT-2 only has bos/eos tokens, not decoder_start/pad tokens
+ tokenizer.pad_token = tokenizer.eos_token
+ # Update the model config to match
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.decoder_start_token_id = tokenizer.bos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ # Reuse the configured model above instead of loading the checkpoint a second
+ # time, so the token-id updates actually apply; pass device=0 to run on GPU
+ image_captioner = pipeline('image-to-text', model=model, tokenizer=tokenizer, image_processor=feature_extractor)
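+ # The pipeline returns a list of dicts, e.g.
+ # image_captioner(img) -> [{'generated_text': 'a dog sitting on a bench'}]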
+
+ # Load the summarization model (T5-base)
+ model_sum_name = "google-t5/t5-base"
+ tokenizer_sum = AutoTokenizer.from_pretrained(model_sum_name)
+ model_sum = AutoModelForSeq2SeqLM.from_pretrained(model_sum_name)
+ # Reuse the loaded model and tokenizer instead of downloading the checkpoint again
+ summarize_pipe = pipeline("summarization", model=model_sum, tokenizer=tokenizer_sum)
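+ # summarize_pipe likewise returns a list of dicts:
+ # summarize_pipe(text) -> [{'summary_text': '...'}]
+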
+ if 'captured_images' not in st.session_state:
+     st.session_state.captured_images = []
+
+ def generate_caption(image):
+     # Direct-generation alternative to the pipeline:
+     # pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+     # pixel_values = pixel_values.to(device)
+     # output_ids = model.generate(pixel_values)
+     # caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     # The pipeline returns [{'generated_text': ...}]; extract the caption string
+     caption = image_captioner(image)[0]['generated_text']
+     return caption
+
+ def get_synonyms(word):
+     # Collect every lemma of every WordNet synset for the word
+     synonyms = set()
+     for syn in wordnet.synsets(word):
+         for lemma in syn.lemmas():
+             synonyms.add(lemma.name())
+     return synonyms
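+ # Example: get_synonyms("car") includes 'auto', 'automobile' and 'motorcar'
+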
+ def preprocess_query(query):
+     # Expand the query into surface forms, lemmas, and WordNet synonyms
+     doc = nlp(query)
+     tokens = set()
+     for token in doc:
+         tokens.add(token.text)
+         tokens.add(token.lemma_)
+         tokens.update(get_synonyms(token.text))
+     return tokens
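+ # Example: the query "dogs" expands to {'dogs', 'dog', ...} plus WordNet
+ # synonyms such as 'domestic_dog' and 'Canis_familiaris'
+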
+ def search_captions(query, captions):
+     query_tokens = preprocess_query(query)
+
+     results = []
+     for img_str, caption, capture_time in captions:
+         caption_tokens = preprocess_query(caption)
+         if query_tokens & caption_tokens:
+             results.append((img_str, caption, capture_time))
+
+     return results
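+ # Matching is a simple set intersection: one shared token, lemma, or
+ # synonym is enough for a caption to count as a hit
+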
+ def add_image_to_state(image, caption, capture_time):
+     # Frames arrive as RGB; convert back to BGR so cv2.imencode saves correct colors
+     bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+     img_str = base64.b64encode(cv2.imencode('.jpg', bgr)[1]).decode()
+     if len(st.session_state.captured_images) < 20:  # Limit to 20 images
+         st.session_state.captured_images.append((img_str, caption, capture_time))
+
+ def page_image_captioning():
+     st.title("Image Captioning")
+     # Your image captioning code here
+
+ def page_video_captioning():
+     st.title("Video Captioning")
+     # Your video captioning code here
+
+ def page_webcam_capture():
+     st.title("Live Captioning with Webcam")
+     run = st.checkbox('Run')
+     stop = st.button('Stop')
+     FRAME_WINDOW = st.image([])
+
+     if 'camera' not in st.session_state:
+         st.session_state.camera = cv2.VideoCapture(0)
+
+     if run:
+         while run:
+             ret, frame = st.session_state.camera.read()
+             if not ret:
+                 st.write("Failed to capture image.")
+                 break
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             FRAME_WINDOW.image(frame)
+             pil_image = Image.fromarray(frame)
+             caption = generate_caption(pil_image)
+             capture_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             add_image_to_state(frame, caption, capture_time)
+             st.write(f"Caption: {caption}")
+             # cv2.waitKey needs a HighGUI window, which Streamlit never creates;
+             # sleep instead to throttle capture to ~2 fps
+             time.sleep(0.5)
+
+     if stop and 'camera' in st.session_state:
+         st.session_state.camera.release()
+         del st.session_state.camera
+         st.write("Camera stopped.")
+
+     # Display the collected data
+     if st.session_state.captured_images:
+         df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
+         st.table(df[['Capture Time', 'Caption']])
+     else:
+         st.write("No images captured.")
+
+     st.sidebar.title("Search Captions")
+     query = st.sidebar.text_input("Enter a word to search in captions:")
+     if st.sidebar.button("Search"):
+         results = search_captions(query, st.session_state.captured_images)
+         if results:
+             st.subheader("Search Results:")
+             cols = st.columns(4)
+             for idx, (img_str, caption, capture_time) in enumerate(results):
+                 col = cols[idx % 4]
+                 with col:
+                     # Decode the stored base64 JPEG in memory; no temp file needed
+                     img = Image.open(io.BytesIO(base64.b64decode(img_str)))
+                     st.image(img, caption=f"{caption}\n\n*{capture_time}*", width=150)
+         else:
+             st.write("No matching captions found.")
+
+     if st.sidebar.button("Generate Report"):
+         if 'camera' in st.session_state:
+             st.session_state.camera.release()
+             del st.session_state.camera
+
+         if st.session_state.captured_images:
+             # Display captured images in a 4-column grid
+             st.subheader("Captured Images and Captions:")
+             cols = st.columns(4)
+             for idx, (img_str, caption, capture_time) in enumerate(st.session_state.captured_images):
+                 col = cols[idx % 4]
+                 with col:
+                     img = Image.open(io.BytesIO(base64.b64decode(img_str)))
+                     st.image(img, caption=f"{caption}\n\n*{capture_time}*", width=150)
+
+             # Save captions to Excel and provide a download button
+             df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
+             # Note: Excel stores this as a plain HTML string; it does not render as an image
+             df['Image'] = df['Image'].apply(lambda x: f'<img src="data:image/jpeg;base64,{x}"/>')
+             excel_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
+             df.to_excel(excel_file.name, index=False)
+             st.sidebar.download_button(label="Download Captions as Excel",
+                                        data=open(excel_file.name, 'rb').read(),
+                                        file_name="camera_captions.xlsx",
+                                        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
+
+             # Summarize captions in groups of 10
+             summaries = []
+             for i in range(0, len(st.session_state.captured_images), 10):
+                 batch = st.session_state.captured_images[i:i + 10]
+                 batch_captions = " ".join(item[1] for item in batch)
+                 summary = summarize_pipe(batch_captions)[0]['summary_text']
+                 # Use the capture time of the first image in the batch
+                 summaries.append((batch[0][2], summary))
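+             # With the 20-image cap above, this yields at most two batch summaries
+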
+             # Save summaries to Excel and provide a download button
+             df_summary = pd.DataFrame(summaries, columns=['Capture Time', 'Summary'])
+             summary_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
+             df_summary.to_excel(summary_file.name, index=False)
+             st.sidebar.download_button(label="Download Summary Report",
+                                        data=open(summary_file.name, 'rb').read(),
+                                        file_name="camera_summary_report.xlsx",
+                                        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
+
+ def main():
+     st.session_state.active_page = st.session_state.get("active_page", "Image Captioning")
+
+     # Sidebar for navigation
+     with st.sidebar:
+         selected = option_menu(
+             menu_title="Main Menu",
+             options=["Image Captioning", "Video Captioning", "Webcam Captioning"],
+             icons=["image", "caret-right-square", "camera"],  # Bootstrap icon names are lowercase
+             menu_icon="cast",
+             default_index=0,
+         )
+
+     if selected != st.session_state.active_page:
+         handle_page_switch(selected)
+
+     if selected == "Image Captioning":
+         page_image_captioning()
+     elif selected == "Video Captioning":
+         page_video_captioning()
+     elif selected == "Webcam Captioning":
+         page_webcam_capture()
+
+ def handle_page_switch(selected_page):
+     # Release the webcam when navigating away from the webcam page
+     if st.session_state.active_page == "Webcam Captioning" and "camera" in st.session_state:
+         st.session_state.camera.release()
+         del st.session_state.camera
+
+     st.session_state.active_page = selected_page
+
+ if __name__ == "__main__":
+     main()