chuanenlin committed
Commit 83c81a5 · 1 Parent(s): a9cbf7c
.DS_Store ADDED
Binary file (6.15 kB)
 
SessionState.py DELETED
@@ -1,70 +0,0 @@
-import streamlit.report_thread as ReportThread
-from streamlit.server.server import Server
-
-
-class SessionState():
-    """SessionState: Add per-session state to Streamlit."""
-    def __init__(self, **kwargs):
-        """A new SessionState object.
-
-        Parameters
-        ----------
-        **kwargs : any
-            Default values for the session state.
-
-        Example
-        -------
-        >>> session_state = SessionState(user_name='', favorite_color='black')
-        >>> session_state.user_name = 'Mary'
-        ''
-        >>> session_state.favorite_color
-        'black'
-
-        """
-        for key, val in kwargs.items():
-            setattr(self, key, val)
-
-
-def get(**kwargs):
-    """Gets a SessionState object for the current session.
-
-    Creates a new object if necessary.
-
-    Parameters
-    ----------
-    **kwargs : any
-        Default values you want to add to the session state, if we're creating a
-        new one.
-
-    Example
-    -------
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    ''
-    >>> session_state.user_name = 'Mary'
-    >>> session_state.favorite_color
-    'black'
-
-    Since you set user_name above, next time your script runs this will be the
-    result:
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    'Mary'
-
-    """
-    # Hack to get the session object from Streamlit.
-
-    session_id = ReportThread.get_report_ctx().session_id
-    session_info = Server.get_current()._get_session_info(session_id)
-
-    if session_info is None:
-        raise RuntimeError('Could not get Streamlit session object.')
-
-    this_session = session_info.session
-
-    # Got the session object! Now let's attach some state into it.
-
-    if not hasattr(this_session, '_custom_session_state'):
-        this_session._custom_session_state = SessionState(**kwargs)
-
-    return this_session._custom_session_state
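
Note: this deletion is possible because Streamlit now ships session state natively as st.session_state (added in Streamlit 0.84; requirements.txt below pins streamlit>=1.1.0), which the rewritten whichframe.py uses directly. A minimal sketch of the replacement pattern, for readers comparing the two approaches:

    import streamlit as st

    # Initialize a key once per session; the value survives script reruns,
    # which is what the deleted SessionState hack provided.
    if 'favorite_color' not in st.session_state:
        st.session_state.favorite_color = 'black'

    st.session_state.favorite_color = 'red'  # attribute-style access also works
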
cached_data/example_features.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acb80bcbcb93af49b4bfc874f9823402fd30802aadefa21a4bb10ae13853fee9
+size 695497

cached_data/example_fps.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:531bf47f7c8d488f38892c54649751f669325416158545dadb696ea8875456ef
+size 136

cached_data/example_frame_indices.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb6e12a8f3e0a3a71d30a1c9adcbfa686403a9a3fee8d0dfd38320e1a6840b0a
+size 2840

cached_data/example_frames.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfdc073b8a2236707ed3e75bdac537163dc0f3b1d65fc92834b9c352491e895d
+size 234316928
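
Note: the four entries above are Git LFS pointer files, not the binaries themselves; each commit records only the LFS spec version, the SHA-256 object ID, and the payload size in bytes (example_frames.npy is the large one at ~234 MB). A sketch of how these artifacts are read back, mirroring load_cached_data() in the whichframe.py diff below:

    import numpy as np
    import torch

    # Precomputed data for the example video, committed above via Git LFS.
    video_frames = np.load("cached_data/example_frames.npy", allow_pickle=True)
    video_features = torch.load("cached_data/example_features.pt")
    fps = np.load("cached_data/example_fps.npy")
    frame_indices = np.load("cached_data/example_frame_indices.npy")
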
requirements.txt CHANGED
@@ -1,6 +1,8 @@
+streamlit>=1.1.0
 Pillow
-pytube
+yt-dlp
 opencv-python-headless
 torch
 git+https://github.com/openai/CLIP.git
 humanfriendly
+numpy
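
Note: the dependency changes mirror the code changes below: pytube is replaced by yt-dlp for YouTube fetching, streamlit>=1.1.0 is pinned so the built-in st.session_state is available, and numpy is added for the cached-data files. (whichframe.py below also imports moviepy, which is not added here.) A minimal sketch of the yt-dlp call the new fetch_video() relies on, resolving a direct stream URL without downloading:

    import yt_dlp

    # Same options as fetch_video() in whichframe.py below.
    ydl_opts = {'format': 'bestvideo[height<=360][ext=mp4]', 'quiet': True, 'no_warnings': True}
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info("https://www.youtube.com/watch?v=zTvJJnoWIPk", download=False)
    print(info.get('duration', 0), info['url'])  # duration in seconds, direct stream URL
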
whichframe.py CHANGED
@@ -6,124 +6,328 @@ from PIL import Image
 import clip as openai_clip
 import torch
 import math
-import SessionState
 from humanfriendly import format_timespan
+from moviepy.video.io.VideoFileClip import VideoFileClip
+import numpy as np
+import time
+import os
+import yt_dlp
+import io
+
+EXAMPLE_URL = "https://www.youtube.com/watch?v=zTvJJnoWIPk"
+CACHED_DATA_PATH = "cached_data/"
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = openai_clip.load("ViT-B/32", device=device)
 
 def fetch_video(url):
-    yt = YouTube(url)
-    streams = yt.streams.filter(adaptive=True, subtype="mp4", resolution="360p", only_video=True)
-    length = yt.length
-    if length >= 300:
-        st.error("Please find a YouTube video shorter than 5 minutes. Sorry about this, the server capacity is limited for the time being.")
-        st.stop()
-    video = streams[0]
-    return video, video.url
-
-@st.cache()
-def extract_frames(video):
-    frames = []
-    capture = cv2.VideoCapture(video)
-    fps = capture.get(cv2.CAP_PROP_FPS)
-    current_frame = 0
-    while capture.isOpened():
-        ret, frame = capture.read()
-        if ret == True:
-            frames.append(Image.fromarray(frame[:, :, ::-1]))
-        else:
-            break
-        current_frame += N
-        capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
-    return frames, fps
-
-@st.cache()
-def encode_frames(video_frames):
-    batch_size = 256
-    batches = math.ceil(len(video_frames) / batch_size)
-    video_features = torch.empty([0, 512], dtype=torch.float16).to(device)
-    for i in range(batches):
-        batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
-        batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
-        with torch.no_grad():
-            batch_features = model.encode_image(batch_preprocessed)
-            batch_features /= batch_features.norm(dim=-1, keepdim=True)
-        video_features = torch.cat((video_features, batch_features))
-    return video_features
+    try:
+        ydl_opts = {
+            'format': 'bestvideo[height<=360][ext=mp4]',
+            'quiet': True,
+            'no_warnings': True
+        }
+        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+            info = ydl.extract_info(url, download=False)
+            duration = info.get('duration', 0)
+            if duration >= 300:  # 5 minutes
+                st.error("Please find a YouTube video shorter than 5 minutes.")
+                st.stop()
+            video_url = info['url']
+            return None, video_url
+
+    except Exception as e:
+        st.error(f"Error fetching video: {str(e)}")
+        st.error("Try another YouTube video or check if the URL is correct.")
+        st.stop()
+
+def extract_frames(video, status_text, progress_bar):
+    cap = cv2.VideoCapture(video)
+    frames = []
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    step = max(1, round(fps/2))
+    total_frames = frame_count // step
+    frame_indices = []
+    for i in range(0, frame_count, step):
+        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        ret, frame = cap.read()
+        if ret:
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frames.append(Image.fromarray(frame_rgb))
+            frame_indices.append(i)
+
+        current_frame = len(frames)
+        status_text.text(f'Extracting frames... ({min(current_frame, total_frames)}/{total_frames})')
+        progress = min(current_frame / total_frames, 1.0)
+        progress_bar.progress(progress)
+
+    cap.release()
+    return frames, fps, frame_indices
+
+def encode_frames(video_frames, status_text):
+    batch_size = 256
+    batches = math.ceil(len(video_frames) / batch_size)
+    video_features = torch.empty([0, 512], dtype=torch.float32).to(device)
+
+    for i in range(batches):
+        batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
+        batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
+        with torch.no_grad():
+            batch_features = model.encode_image(batch_preprocessed)
+            batch_features = batch_features.float()
+            batch_features /= batch_features.norm(dim=-1, keepdim=True)
+        video_features = torch.cat((video_features, batch_features))
+        status_text.text(f'Encoding frames... ({(i+1)*batch_size}/{len(video_frames)})')
+
+    return video_features
 
 def img_to_bytes(img):
-    img_byte_arr = io.BytesIO()
-    img.save(img_byte_arr, format='JPEG')
-    img_byte_arr = img_byte_arr.getvalue()
-    return img_byte_arr
-
-def display_results(best_photo_idx):
-    st.markdown("**Top-5 matching results**")
-    result_arr = []
-    for frame_id in best_photo_idx:
-        result = ss.video_frames[frame_id]
-        st.image(result)
-        seconds = round(frame_id.cpu().numpy()[0] * N / ss.fps)
-        result_arr.append(seconds)
-        time = format_timespan(seconds)
-        if ss.input == "file":
-            st.write("Seen at " + str(time) + " into the video.")
-        else:
-            st.markdown("Seen at [" + str(time) + "](" + url + "&t=" + str(seconds) + "s) into the video.")
-    return result_arr
+    img_byte_arr = io.BytesIO()
+    img.save(img_byte_arr, format='JPEG')
+    img_byte_arr = img_byte_arr.getvalue()
+    return img_byte_arr
+
+def get_youtube_timestamp_url(url, frame_idx, frame_indices):
+    frame_count = frame_indices[frame_idx]
+    fps = st.session_state.fps
+    seconds = frame_count / fps
+    seconds_rounded = int(seconds)
+
+    if url == EXAMPLE_URL:
+        video_id = "zTvJJnoWIPk"
+    else:
+        try:
+            from urllib.parse import urlparse, parse_qs
+            parsed_url = urlparse(url)
+            video_id = parse_qs(parsed_url.query)['v'][0]
+        except:
+            return None, None
+
+    return f"https://youtu.be/{video_id}?t={seconds_rounded}", seconds
+
+def display_results(best_photo_idx, video_frames):
+    st.subheader("Top 10 Results")
+    for frame_id in best_photo_idx:
+        result = video_frames[frame_id]
+        st.image(result, width=400)
+
+        timestamp_url, seconds = get_youtube_timestamp_url(st.session_state.url, frame_id, st.session_state.frame_indices)
+        if timestamp_url:
+            st.markdown(f"[▶️ Play video at {format_timespan(int(seconds))}]({timestamp_url})")
+
+def text_search(search_query, video_features, video_frames, display_results_count=10):
+    display_results_count = min(display_results_count, len(video_frames))
+
+    with torch.no_grad():
+        text_tokens = openai_clip.tokenize(search_query).to(device)
+        text_features = model.encode_text(text_tokens)
+        text_features = text_features.float()
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+
+    video_features = video_features.float()
+
+    similarities = (100.0 * video_features @ text_features.T)
+    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+    display_results(best_photo_idx, video_frames)
 
-def text_search(search_query, display_results_count=5):
-    with torch.no_grad():
-        text_features = model.encode_text(openai_clip.tokenize(search_query).to(device))
-        text_features /= text_features.norm(dim=-1, keepdim=True)
-    similarities = (100.0 * ss.video_features @ text_features.T)
-    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
-    result_arr = display_results(best_photo_idx)
-    return result_arr
+def image_search(query_image, video_features, video_frames, display_results_count=10):
+    query_image = preprocess(query_image).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        image_features = model.encode_image(query_image)
+        image_features = image_features.float()
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+
+    video_features = video_features.float()
+
+    similarities = (100.0 * video_features @ image_features.T)
+    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+    display_results(best_photo_idx, video_frames)
+
+def text_and_image_search(search_query, query_image, video_features, video_frames, display_results_count=10):
+    with torch.no_grad():
+        text_tokens = openai_clip.tokenize(search_query).to(device)
+        text_features = model.encode_text(text_tokens)
+        text_features = text_features.float()
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+
+    query_image = preprocess(query_image).unsqueeze(0).to(device)
+    with torch.no_grad():
+        image_features = model.encode_image(query_image)
+        image_features = image_features.float()
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+
+    combined_features = (text_features + image_features) / 2
+
+    video_features = video_features.float()
+    similarities = (100.0 * video_features @ combined_features.T)
+    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+    display_results(best_photo_idx, video_frames)
+
+def load_cached_data(url):
+    if url == EXAMPLE_URL:
+        try:
+            video_frames = np.load(f"{CACHED_DATA_PATH}example_frames.npy", allow_pickle=True)
+            video_features = torch.load(f"{CACHED_DATA_PATH}example_features.pt")
+            fps = np.load(f"{CACHED_DATA_PATH}example_fps.npy")
+            frame_indices = np.load(f"{CACHED_DATA_PATH}example_frame_indices.npy")
+            return video_frames, video_features, fps, frame_indices
+        except:
+            return None, None, None, None
+    return None, None, None, None
+
+def save_cached_data(url, video_frames, video_features, fps, frame_indices):
+    if url == EXAMPLE_URL:
+        os.makedirs(CACHED_DATA_PATH, exist_ok=True)
+        np.save(f"{CACHED_DATA_PATH}example_frames.npy", video_frames)
+        torch.save(video_features, f"{CACHED_DATA_PATH}example_features.pt")
+        np.save(f"{CACHED_DATA_PATH}example_fps.npy", fps)
+        np.save(f"{CACHED_DATA_PATH}example_frame_indices.npy", frame_indices)
 
-st.set_page_config(page_title="Which Frame?", page_icon = "🔍", layout = "centered", initial_sidebar_state = "collapsed")
+def clear_cached_data():
+    if os.path.exists(CACHED_DATA_PATH):
+        try:
+            for file in os.listdir(CACHED_DATA_PATH):
+                file_path = os.path.join(CACHED_DATA_PATH, file)
+                if os.path.isfile(file_path):
+                    os.unlink(file_path)
+            os.rmdir(CACHED_DATA_PATH)
+        except Exception as e:
+            print(f"Error clearing cache: {e}")
+
+st.set_page_config(page_title="Which Frame? 🎞️🔍", page_icon = "🔍", layout = "centered", initial_sidebar_state = "collapsed")
 
 hide_streamlit_style = """
-<style>
-#MainMenu {visibility: hidden;}
-footer {visibility: hidden;}
-* {font-family: Avenir;}
-.css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}
-a:link {text-decoration: none;}
-a:hover {text-decoration: none;}
-.st-ba {font-family: Avenir;}
-</style>
-"""
+<style>
+/* Hide Streamlit elements */
+#MainMenu {visibility: hidden;}
+footer {visibility: hidden;}
+* {
+    font-family: Avenir;
+}
+.block-container {
+    max-width: 800px;
+    padding: 2rem 1rem;
+}
+.stTextInput input {
+    border-radius: 8px;
+    border: 1px solid #E0E0E0;
+    padding: 0.75rem;
+    font-size: 1rem;
+}
+.stRadio [role="radiogroup"] {
+    background: #F8F8F8;
+    padding: 1rem;
+    border-radius: 12px;
+}
+h1 {text-align: center;}
+.css-gma2qf {display: flex; justify-content: center; font-size: 36px; font-weight: bold;}
+a:link {text-decoration: none;}
+a:hover {text-decoration: none;}
+.st-ba {font-family: Avenir;}
+.st-button {text-align: center;}
+</style>
+"""
 st.markdown(hide_streamlit_style, unsafe_allow_html=True)
 
-ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)
+if 'progress' not in st.session_state:
+    st.session_state.progress = 1
+if 'video_frames' not in st.session_state:
+    st.session_state.video_frames = None
+if 'video_features' not in st.session_state:
+    st.session_state.video_features = None
+if 'fps' not in st.session_state:
+    st.session_state.fps = None
+if 'video_name' not in st.session_state:
+    st.session_state.video_name = 'videos/example.mp4'
 
-st.title("Which Frame?")
-st.markdown("Search a video **semantically**. For example: Which frame has a person with sunglasses and earphones?")
-url = st.text_input("Link to a YouTube video (Example: https://www.youtube.com/watch?v=sxaTnm_4YMY)")
+st.title("Which Frame? 🎞️🔍")
+st.markdown("""
+Search a video semantically. For example, which frame has "a person with sunglasses"?
+Search using text, images, or a mix of text + image. WhichFrame uses [CLIP](https://github.com/openai/CLIP) for zero-shot frame classification.
+""")
 
-N = 30
+if 'url' not in st.session_state:
+    st.session_state.url = ''
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model, preprocess = openai_clip.load("ViT-B/32", device=device)
+url = st.text_input("Enter a YouTube URL (e.g., https://www.youtube.com/watch?v=zTvJJnoWIPk)", key="url_input")
+
+if st.button("Process Video"):
+    if not url:
+        st.error("Please enter a YouTube URL first")
+    else:
+        try:
+            cached_frames, cached_features, cached_fps, cached_frame_indices = load_cached_data(url)
+
+            if cached_frames is not None:
+                st.session_state.video_frames = cached_frames
+                st.session_state.video_features = cached_features
+                st.session_state.fps = cached_fps
+                st.session_state.frame_indices = cached_frame_indices
+                st.session_state.url = url
+                st.session_state.progress = 2
+                st.success("Loaded cached video data!")
+            else:
+                with st.spinner('Fetching video...'):
+                    video, video_url = fetch_video(url)
+                    st.session_state.url = url
+
+                progress_bar = st.progress(0)
+                status_text = st.empty()
+
+                # Extract frames
+                st.session_state.video_frames, st.session_state.fps, st.session_state.frame_indices = extract_frames(video_url, status_text, progress_bar)
+
+                # Encode frames
+                st.session_state.video_features = encode_frames(st.session_state.video_frames, status_text)
+
+                save_cached_data(url, st.session_state.video_frames, st.session_state.video_features, st.session_state.fps, st.session_state.frame_indices)
+                status_text.text('Finalizing...')
+                st.session_state.progress = 2
+                progress_bar.progress(100)
+                status_text.empty()
+                progress_bar.empty()
+                st.success("Video processed successfully!")
+
+        except Exception as e:
+            st.error(f"Error processing video: {str(e)}")
+
+if st.session_state.progress == 2:
+    search_type = st.radio("Search Method", ["Text Search", "Image Search", "Text + Image Search"], index=0)
+
+    if search_type == "Text Search":  # Text Search
+        text_query = st.text_input("Type a search query (e.g., 'red car' or 'person with sunglasses')")
+        if st.button("Search"):
+            if not text_query:
+                st.error("Please enter a search query first")
+            else:
+                text_search(text_query, st.session_state.video_features, st.session_state.video_frames)
+    elif search_type == "Image Search":  # Image Search
+        uploaded_file = st.file_uploader("Upload a query image", type=['png', 'jpg', 'jpeg'])
+        if uploaded_file is not None:
+            query_image = Image.open(uploaded_file).convert('RGB')
+            st.image(query_image, caption="Query Image", width=200)
+        if st.button("Search"):
+            if uploaded_file is None:
+                st.error("Please upload an image first")
+            else:
+                image_search(query_image, st.session_state.video_features, st.session_state.video_frames)
+    else:  # Text + Image Search
+        text_query = st.text_input("Type a search query")
+        uploaded_file = st.file_uploader("Upload a query image", type=['png', 'jpg', 'jpeg'])
+        if uploaded_file is not None:
+            query_image = Image.open(uploaded_file).convert('RGB')
+            st.image(query_image, caption="Query Image", width=200)
+
+        if st.button("Search"):
+            if not text_query or uploaded_file is None:
+                st.error("Please provide both text query and image")
+            else:
+                text_and_image_search(text_query, query_image, st.session_state.video_features, st.session_state.video_frames)
 
-if st.button("Process video (this may take a while)"):
-    ss.progress = 1
-    ss.video_start_time = 0
-    if url:
-        ss.input = "link"
-        ss.video, ss.video_name = fetch_video(url)
-        ss.id = extract.video_id(url)
-        ss.url = "https://www.youtube.com/watch?v=" + ss.id
-    else:
-        st.error("Please upload a video or link to a valid YouTube video")
-        st.stop()
-    ss.video_frames, ss.fps = extract_frames(ss.video_name)
-    ss.video_features = encode_frames(ss.video_frames)
-    st.video(ss.url)
-    ss.progress = 2
-
-if ss.progress == 2:
-    ss.text_query = st.text_input("Enter search query (Example: a person with sunglasses and earphones)")
-
-    if st.button("Submit"):
-        if ss.text_query is not None:
-            text_search(ss.text_query)
+st.markdown("---")
+st.markdown(
+    "By [David Chuan-En Lin](https://chuanenlin.com/). "
+    "Play with the code at [https://github.com/chuanenlin/whichframe](https://github.com/chuanenlin/whichframe)."
+)
+ )