NourFakih committed on
Commit
ac70fb3
·
verified ·
1 Parent(s): 2e2a748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -60
app.py CHANGED
@@ -1,22 +1,22 @@
1
  import streamlit as st
2
- import cv2
3
  from PIL import Image
4
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
5
  from datetime import datetime
6
  import pandas as pd
7
  import tempfile
8
- from nltk.corpus import wordnet
9
  import nltk
10
- #import base64
11
  import spacy
12
  from spacy.cli import download
13
  from streamlit_option_menu import option_menu
14
- import torch
 
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
16
  # Download necessary NLTK and spaCy data
17
  nltk.download('wordnet')
18
  nltk.download('omw-1.4')
19
-
20
  download("en_core_web_sm")
21
 
22
  # Load the models
@@ -25,13 +25,11 @@ model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-115k-12"
25
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
26
  feature_extractor = ViTImageProcessor.from_pretrained(model_name)
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
- # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
29
  tokenizer.pad_token = tokenizer.eos_token
30
- # update the model config
31
  model.config.eos_token_id = tokenizer.eos_token_id
32
  model.config.decoder_start_token_id = tokenizer.bos_token_id
33
  model.config.pad_token_id = tokenizer.pad_token_id
34
- image_captioner = pipeline('image-to-text', model=model_name)#, device=0)
35
 
36
  model_sum_name = "google-t5/t5-base"
37
  tokenizer_sum = AutoTokenizer.from_pretrained("google-t5/t5-base")
@@ -42,12 +40,8 @@ if 'captured_images' not in st.session_state:
42
  st.session_state.captured_images = []
43
 
44
  def generate_caption(image):
45
- # pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
46
- # pixel_values = pixel_values.to(device)
47
- # output_ids = model.generate(pixel_values)
48
- # caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
49
  caption = image_captioner(image)
50
- return caption
51
 
52
  def get_synonyms(word):
53
  synonyms = set()
@@ -67,59 +61,41 @@ def preprocess_query(query):
67
 
68
  def search_captions(query, captions):
69
  query_tokens = preprocess_query(query)
70
-
71
  results = []
72
  for img_str, caption, capture_time in captions:
73
  caption_tokens = preprocess_query(caption)
74
  if query_tokens & caption_tokens:
75
  results.append((img_str, caption, capture_time))
76
-
77
  return results
78
 
79
  def add_image_to_state(image, caption, capture_time):
80
  img_str = base64.b64encode(cv2.imencode('.jpg', image)[1]).decode()
81
- if len(st.session_state.captured_images) < 20: # Limit to 20 images
82
  st.session_state.captured_images.append((img_str, caption, capture_time))
83
 
84
  def page_image_captioning():
85
  st.title("Image Captioning")
86
- # Your image captioning code here
87
 
88
  def page_video_captioning():
89
  st.title("Video Captioning")
90
- # Your video captioning code here
91
 
92
  def page_webcam_capture():
93
  st.title("Live Captioning with Webcam")
94
- run = st.checkbox('Run')
95
- stop = st.button('Stop')
96
- FRAME_WINDOW = st.image([])
97
-
98
- if 'camera' not in st.session_state:
99
- st.session_state.camera = cv2.VideoCapture(0)
100
-
101
- if run:
102
- while run:
103
- ret, frame = st.session_state.camera.read()
104
- if not ret:
105
- st.write("Failed to capture image.")
106
- break
107
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
- FRAME_WINDOW.image(frame)
109
- pil_image = Image.fromarray(frame)
110
- caption = generate_caption(pil_image)
111
- capture_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
112
- add_image_to_state(frame, caption, capture_time)
113
- st.write(f"Caption: {caption}")
114
- if cv2.waitKey(500) & 0xFF == ord('q'):
115
- break
116
-
117
- if stop and 'camera' in st.session_state:
118
- st.session_state.camera.release()
119
- del st.session_state.camera
120
- st.write("Camera stopped.")
121
 
122
- # Display the collected data
 
 
 
 
 
 
 
 
 
 
 
123
  if st.session_state.captured_images:
124
  df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
125
  st.table(df[['Capture Time', 'Caption']])
@@ -143,12 +119,7 @@ def page_webcam_capture():
143
  st.write("No matching captions found.")
144
 
145
  if st.sidebar.button("Generate Report"):
146
- if 'camera' in st.session_state:
147
- st.session_state.camera.release()
148
- del st.session_state.camera
149
-
150
  if st.session_state.captured_images:
151
- # Display captured images in a 4-column grid
152
  st.subheader("Captured Images and Captions:")
153
  cols = st.columns(4)
154
  for idx, (img_str, caption, capture_time) in enumerate(st.session_state.captured_images):
@@ -158,7 +129,6 @@ def page_webcam_capture():
158
  img = Image.open(tempfile.NamedTemporaryFile(delete=False, suffix='.jpg', mode='wb').write(img_data))
159
  st.image(img, caption=f"{caption}\n\n*{capture_time}*", width=150)
160
 
161
- # Save captions to Excel and provide a download button
162
  df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
163
  df['Image'] = df['Image'].apply(lambda x: f'<img src="data:image/jpeg;base64,{x}"/>')
164
  excel_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
@@ -168,14 +138,12 @@ def page_webcam_capture():
168
  file_name="camera_captions.xlsx",
169
  mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
170
 
171
- # Summarize captions in groups of 10
172
  summaries = []
173
  for i in range(0, len(st.session_state.captured_images), 10):
174
  batch_captions = " ".join([st.session_state.captured_images[j][1] for j in range(i, min(i+10, len(st.session_state.captured_images)))] )
175
  summary = summarize_pipe(batch_captions)[0]['summary_text']
176
- summaries.append((st.session_state.captured_images[i][2], summary)) # Use the capture time of the first image in the batch
177
 
178
- # Save summaries to Excel and provide a download button
179
  df_summary = pd.DataFrame(summaries, columns=['Capture Time', 'Summary'])
180
  summary_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
181
  df_summary.to_excel(summary_file.name, index=False)
@@ -187,7 +155,6 @@ def page_webcam_capture():
187
  def main():
188
  st.session_state.active_page = st.session_state.get("active_page", "Image Captioning")
189
 
190
- # Sidebar for navigation
191
  with st.sidebar:
192
  selected = option_menu(
193
  menu_title="Main Menu",
@@ -208,10 +175,6 @@ def main():
208
  page_webcam_capture()
209
 
210
  def handle_page_switch(selected_page):
211
- if st.session_state.active_page == "Webcam Captioning" and "camera" in st.session_state:
212
- st.session_state.camera.release()
213
- del st.session_state.camera
214
-
215
  st.session_state.active_page = selected_page
216
 
217
  if __name__ == "__main__":
 
1
  import streamlit as st
 
2
  from PIL import Image
3
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
4
  from datetime import datetime
5
  import pandas as pd
6
  import tempfile
7
+ import base64
8
  import nltk
 
9
  import spacy
10
  from spacy.cli import download
11
  from streamlit_option_menu import option_menu
12
+ import torch
13
+
14
+ # Set device
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
  # Download necessary NLTK and spaCy data
18
  nltk.download('wordnet')
19
  nltk.download('omw-1.4')
 
20
  download("en_core_web_sm")
21
 
22
  # Load the models
 
25
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
26
  feature_extractor = ViTImageProcessor.from_pretrained(model_name)
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
28
  tokenizer.pad_token = tokenizer.eos_token
 
29
  model.config.eos_token_id = tokenizer.eos_token_id
30
  model.config.decoder_start_token_id = tokenizer.bos_token_id
31
  model.config.pad_token_id = tokenizer.pad_token_id
32
+ image_captioner = pipeline('image-to-text', model=model_name)
33
 
34
  model_sum_name = "google-t5/t5-base"
35
  tokenizer_sum = AutoTokenizer.from_pretrained("google-t5/t5-base")
 
40
  st.session_state.captured_images = []
41
 
42
  def generate_caption(image):
 
 
 
 
43
  caption = image_captioner(image)
44
+ return caption[0]['generated_text']
45
 
46
  def get_synonyms(word):
47
  synonyms = set()
 
61
 
62
  def search_captions(query, captions):
63
  query_tokens = preprocess_query(query)
 
64
  results = []
65
  for img_str, caption, capture_time in captions:
66
  caption_tokens = preprocess_query(caption)
67
  if query_tokens & caption_tokens:
68
  results.append((img_str, caption, capture_time))
 
69
  return results
70
 
71
  def add_image_to_state(image, caption, capture_time):
72
  img_str = base64.b64encode(cv2.imencode('.jpg', image)[1]).decode()
73
+ if len(st.session_state.captured_images) < 20:
74
  st.session_state.captured_images.append((img_str, caption, capture_time))
75
 
76
  def page_image_captioning():
77
  st.title("Image Captioning")
78
+ st.write("Your image captioning code here")
79
 
80
  def page_video_captioning():
81
  st.title("Video Captioning")
82
+ st.write("Your video captioning code here")
83
 
84
  def page_webcam_capture():
85
  st.title("Live Captioning with Webcam")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ img_file = st.camera_input("Capture an image")
88
+
89
+ if img_file:
90
+ img = Image.open(img_file)
91
+ img_array = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
92
+ caption = generate_caption(img)
93
+ capture_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
94
+ add_image_to_state(img_array, caption, capture_time)
95
+ st.image(img, caption=f"Caption: {caption}")
96
+
97
+ if st.button('Stop'):
98
+ st.write("Camera stopped.")
99
  if st.session_state.captured_images:
100
  df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
101
  st.table(df[['Capture Time', 'Caption']])
 
119
  st.write("No matching captions found.")
120
 
121
  if st.sidebar.button("Generate Report"):
 
 
 
 
122
  if st.session_state.captured_images:
 
123
  st.subheader("Captured Images and Captions:")
124
  cols = st.columns(4)
125
  for idx, (img_str, caption, capture_time) in enumerate(st.session_state.captured_images):
 
129
  img = Image.open(tempfile.NamedTemporaryFile(delete=False, suffix='.jpg', mode='wb').write(img_data))
130
  st.image(img, caption=f"{caption}\n\n*{capture_time}*", width=150)
131
 
 
132
  df = pd.DataFrame(st.session_state.captured_images, columns=['Image', 'Caption', 'Capture Time'])
133
  df['Image'] = df['Image'].apply(lambda x: f'<img src="data:image/jpeg;base64,{x}"/>')
134
  excel_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
 
138
  file_name="camera_captions.xlsx",
139
  mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
140
 
 
141
  summaries = []
142
  for i in range(0, len(st.session_state.captured_images), 10):
143
  batch_captions = " ".join([st.session_state.captured_images[j][1] for j in range(i, min(i+10, len(st.session_state.captured_images)))] )
144
  summary = summarize_pipe(batch_captions)[0]['summary_text']
145
+ summaries.append((st.session_state.captured_images[i][2], summary))
146
 
 
147
  df_summary = pd.DataFrame(summaries, columns=['Capture Time', 'Summary'])
148
  summary_file = tempfile.NamedTemporaryFile(delete=False, suffix='.xlsx')
149
  df_summary.to_excel(summary_file.name, index=False)
 
155
  def main():
156
  st.session_state.active_page = st.session_state.get("active_page", "Image Captioning")
157
 
 
158
  with st.sidebar:
159
  selected = option_menu(
160
  menu_title="Main Menu",
 
175
  page_webcam_capture()
176
 
177
  def handle_page_switch(selected_page):
 
 
 
 
178
  st.session_state.active_page = selected_page
179
 
180
  if __name__ == "__main__":