chuanenlin committed
Commit 83c81a5 · 1 Parent(s): a9cbf7c
Revamp
Browse files
- .DS_Store +0 -0
- SessionState.py +0 -70
- cached_data/example_features.pt +3 -0
- cached_data/example_fps.npy +3 -0
- cached_data/example_frame_indices.npy +3 -0
- cached_data/example_frames.npy +3 -0
- requirements.txt +4 -2
- whichframe.py +309 -105
.DS_Store
ADDED
Binary file (6.15 kB)
SessionState.py
DELETED
@@ -1,70 +0,0 @@
-import streamlit.report_thread as ReportThread
-from streamlit.server.server import Server
-
-
-class SessionState():
-    """SessionState: Add per-session state to Streamlit."""
-    def __init__(self, **kwargs):
-        """A new SessionState object.
-
-        Parameters
-        ----------
-        **kwargs : any
-            Default values for the session state.
-
-        Example
-        -------
-        >>> session_state = SessionState(user_name='', favorite_color='black')
-        >>> session_state.user_name = 'Mary'
-        ''
-        >>> session_state.favorite_color
-        'black'
-
-        """
-        for key, val in kwargs.items():
-            setattr(self, key, val)
-
-
-def get(**kwargs):
-    """Gets a SessionState object for the current session.
-
-    Creates a new object if necessary.
-
-    Parameters
-    ----------
-    **kwargs : any
-        Default values you want to add to the session state, if we're creating a
-        new one.
-
-    Example
-    -------
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    ''
-    >>> session_state.user_name = 'Mary'
-    >>> session_state.favorite_color
-    'black'
-
-    Since you set user_name above, next time your script runs this will be the
-    result:
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    'Mary'
-
-    """
-    # Hack to get the session object from Streamlit.
-
-    session_id = ReportThread.get_report_ctx().session_id
-    session_info = Server.get_current()._get_session_info(session_id)
-
-    if session_info is None:
-        raise RuntimeError('Could not get Streamlit session object.')
-
-    this_session = session_info.session
-
-    # Got the session object! Now let's attach some state into it.
-
-    if not hasattr(this_session, '_custom_session_state'):
-        this_session._custom_session_state = SessionState(**kwargs)
-
-    return this_session._custom_session_state
cached_data/example_features.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acb80bcbcb93af49b4bfc874f9823402fd30802aadefa21a4bb10ae13853fee9
+size 695497
cached_data/example_fps.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:531bf47f7c8d488f38892c54649751f669325416158545dadb696ea8875456ef
+size 136
cached_data/example_frame_indices.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb6e12a8f3e0a3a71d30a1c9adcbfa686403a9a3fee8d0dfd38320e1a6840b0a
+size 2840
cached_data/example_frames.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfdc073b8a2236707ed3e75bdac537163dc0f3b1d65fc92834b9c352491e895d
+size 234316928
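The four files added above are Git LFS pointers (version/oid/size), not the binaries themselves; the actual cached data for the example video (pre-extracted frames, CLIP features, fps, frame indices) lives in LFS storage. A rough sketch of how load_cached_data() in whichframe.py, shown further down, reads them back once the LFS objects are present:

import numpy as np
import torch

CACHED_DATA_PATH = "cached_data/"

# Pre-extracted frames and pre-computed CLIP features for the example video,
# so the demo URL does not have to be re-processed on every run.
video_frames = np.load(f"{CACHED_DATA_PATH}example_frames.npy", allow_pickle=True)
video_features = torch.load(f"{CACHED_DATA_PATH}example_features.pt")
fps = np.load(f"{CACHED_DATA_PATH}example_fps.npy")
frame_indices = np.load(f"{CACHED_DATA_PATH}example_frame_indices.npy")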
requirements.txt
CHANGED
@@ -1,6 +1,8 @@
+streamlit>=1.1.0
 Pillow
-
+yt-dlp
 opencv-python-headless
 torch
 git+https://github.com/openai/CLIP.git
-humanfriendly
+humanfriendly
+numpy
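yt-dlp is the notable new dependency: fetch_video() in whichframe.py below uses it to resolve a direct stream URL for a YouTube video without downloading the file, and that URL is then read frame by frame with OpenCV. A minimal sketch of that usage, not part of the commit:

import yt_dlp

ydl_opts = {"format": "bestvideo[height<=360][ext=mp4]", "quiet": True, "no_warnings": True}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
    # download=False returns metadata only, including a resolvable stream URL
    info = ydl.extract_info("https://www.youtube.com/watch?v=zTvJJnoWIPk", download=False)

stream_url = info["url"]  # passed to cv2.VideoCapture() in extract_frames()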
whichframe.py
CHANGED
@@ -6,124 +6,328 @@ from PIL import Image
import clip as openai_clip
import torch
import math
-import SessionState
from humanfriendly import format_timespan

def fetch_video(url):
-    ...

def img_to_bytes(img):
-    ...

-def ...
-        time = format_timespan(seconds)
-        if ss.input == "file":
-            st.write("Seen at " + str(time) + " into the video.")
        else:
-            ...

-def ...
-    ...

hide_streamlit_style = """
-    ...
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

-st.title("Which Frame?")
-st.markdown("
-    ...
-        ss.video, ss.video_name = fetch_video(url)
-        ss.id = extract.video_id(url)
-        ss.url = "https://www.youtube.com/watch?v=" + ss.id
-    else:
-        st.error("Please upload a video or link to a valid YouTube video")
-        st.stop()
-    ss.video_frames, ss.fps = extract_frames(ss.video_name)
-    ss.video_features = encode_frames(ss.video_frames)
-    st.video(ss.url)
-    ss.progress = 2
-
-    if ss.progress == 2:
-        ss.text_query = st.text_input("Enter search query (Example: a person with sunglasses and earphones)")
-
-        if st.button("Submit"):
-            if ss.text_query is not None:
-                text_search(ss.text_query)
import clip as openai_clip
import torch
import math
from humanfriendly import format_timespan
+from moviepy.video.io.VideoFileClip import VideoFileClip
+import numpy as np
+import time
+import os
+import yt_dlp
+import io
+
+EXAMPLE_URL = "https://www.youtube.com/watch?v=zTvJJnoWIPk"
+CACHED_DATA_PATH = "cached_data/"
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = openai_clip.load("ViT-B/32", device=device)

def fetch_video(url):
+    try:
+        ydl_opts = {
+            'format': 'bestvideo[height<=360][ext=mp4]',
+            'quiet': True,
+            'no_warnings': True
+        }
+        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+            info = ydl.extract_info(url, download=False)
+            duration = info.get('duration', 0)
+            if duration >= 300: # 5 minutes
+                st.error("Please find a YouTube video shorter than 5 minutes.")
+                st.stop()
+            video_url = info['url']
+            return None, video_url
+
+    except Exception as e:
+        st.error(f"Error fetching video: {str(e)}")
+        st.error("Try another YouTube video or check if the URL is correct.")
+        st.stop()
+
+def extract_frames(video, status_text, progress_bar):
+    cap = cv2.VideoCapture(video)
+    frames = []
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    step = max(1, round(fps/2))
+    total_frames = frame_count // step
+    frame_indices = []
+    for i in range(0, frame_count, step):
+        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        ret, frame = cap.read()
+        if ret:
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frames.append(Image.fromarray(frame_rgb))
+            frame_indices.append(i)
+
+        current_frame = len(frames)
+        status_text.text(f'Extracting frames... ({min(current_frame, total_frames)}/{total_frames})')
+        progress = min(current_frame / total_frames, 1.0)
+        progress_bar.progress(progress)
+
+    cap.release()
+    return frames, fps, frame_indices
+
+def encode_frames(video_frames, status_text):
+    batch_size = 256
+    batches = math.ceil(len(video_frames) / batch_size)
+    video_features = torch.empty([0, 512], dtype=torch.float32).to(device)
+
+    for i in range(batches):
+        batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
+        batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
+        with torch.no_grad():
+            batch_features = model.encode_image(batch_preprocessed)
+            batch_features = batch_features.float()
+            batch_features /= batch_features.norm(dim=-1, keepdim=True)
+        video_features = torch.cat((video_features, batch_features))
+        status_text.text(f'Encoding frames... ({(i+1)*batch_size}/{len(video_frames)})')
+
+    return video_features

def img_to_bytes(img):
+    img_byte_arr = io.BytesIO()
+    img.save(img_byte_arr, format='JPEG')
+    img_byte_arr = img_byte_arr.getvalue()
+    return img_byte_arr
+
+def get_youtube_timestamp_url(url, frame_idx, frame_indices):
+    frame_count = frame_indices[frame_idx]
+    fps = st.session_state.fps
+    seconds = frame_count / fps
+    seconds_rounded = int(seconds)
+
+    if url == EXAMPLE_URL:
+        video_id = "zTvJJnoWIPk"
    else:
+        try:
+            from urllib.parse import urlparse, parse_qs
+            parsed_url = urlparse(url)
+            video_id = parse_qs(parsed_url.query)['v'][0]
+        except:
+            return None, None
+
+    return f"https://youtu.be/{video_id}?t={seconds_rounded}", seconds
+
+def display_results(best_photo_idx, video_frames):
+    st.subheader("Top 10 Results")
+    for frame_id in best_photo_idx:
+        result = video_frames[frame_id]
+        st.image(result, width=400)
+
+        timestamp_url, seconds = get_youtube_timestamp_url(st.session_state.url, frame_id, st.session_state.frame_indices)
+        if timestamp_url:
+            st.markdown(f"[▶️ Play video at {format_timespan(int(seconds))}]({timestamp_url})")
+
+def text_search(search_query, video_features, video_frames, display_results_count=10):
+    display_results_count = min(display_results_count, len(video_frames))
+
+    with torch.no_grad():
+        text_tokens = openai_clip.tokenize(search_query).to(device)
+        text_features = model.encode_text(text_tokens)
+        text_features = text_features.float()
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+
+    video_features = video_features.float()
+
+    similarities = (100.0 * video_features @ text_features.T)
+    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+    display_results(best_photo_idx, video_frames)

+def image_search(query_image, video_features, video_frames, display_results_count=10):
+    query_image = preprocess(query_image).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        image_features = model.encode_image(query_image)
+        image_features = image_features.float()
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+
+    video_features = video_features.float()
+
+    similarities = (100.0 * video_features @ image_features.T)
+    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+    display_results(best_photo_idx, video_frames)
+
+def text_and_image_search(search_query, query_image, video_features, video_frames, display_results_count=10):
+    with torch.no_grad():
+        text_tokens = openai_clip.tokenize(search_query).to(device)
+        text_features = model.encode_text(text_tokens)
+        text_features = text_features.float()
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+
+    query_image = preprocess(query_image).unsqueeze(0).to(device)
+    with torch.no_grad():
+        image_features = model.encode_image(query_image)
+        image_features = image_features.float()
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+
+    combined_features = (text_features + image_features) / 2
+
+    video_features = video_features.float()
+    similarities = (100.0 * video_features @ combined_features.T)
+    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+    display_results(best_photo_idx, video_frames)
+
+def load_cached_data(url):
+    if url == EXAMPLE_URL:
+        try:
+            video_frames = np.load(f"{CACHED_DATA_PATH}example_frames.npy", allow_pickle=True)
+            video_features = torch.load(f"{CACHED_DATA_PATH}example_features.pt")
+            fps = np.load(f"{CACHED_DATA_PATH}example_fps.npy")
+            frame_indices = np.load(f"{CACHED_DATA_PATH}example_frame_indices.npy")
+            return video_frames, video_features, fps, frame_indices
+        except:
+            return None, None, None, None
+    return None, None, None, None
+
+def save_cached_data(url, video_frames, video_features, fps, frame_indices):
+    if url == EXAMPLE_URL:
+        os.makedirs(CACHED_DATA_PATH, exist_ok=True)
+        np.save(f"{CACHED_DATA_PATH}example_frames.npy", video_frames)
+        torch.save(video_features, f"{CACHED_DATA_PATH}example_features.pt")
+        np.save(f"{CACHED_DATA_PATH}example_fps.npy", fps)
+        np.save(f"{CACHED_DATA_PATH}example_frame_indices.npy", frame_indices)

+def clear_cached_data():
+    if os.path.exists(CACHED_DATA_PATH):
+        try:
+            for file in os.listdir(CACHED_DATA_PATH):
+                file_path = os.path.join(CACHED_DATA_PATH, file)
+                if os.path.isfile(file_path):
+                    os.unlink(file_path)
+            os.rmdir(CACHED_DATA_PATH)
+        except Exception as e:
+            print(f"Error clearing cache: {e}")
+
+st.set_page_config(page_title="Which Frame? 🎞️🔍", page_icon = "🔍", layout = "centered", initial_sidebar_state = "collapsed")

hide_streamlit_style = """
+<style>
+/* Hide Streamlit elements */
+#MainMenu {visibility: hidden;}
+footer {visibility: hidden;}
+* {
+    font-family: Avenir;
+}
+.block-container {
+    max-width: 800px;
+    padding: 2rem 1rem;
+}
+.stTextInput input {
+    border-radius: 8px;
+    border: 1px solid #E0E0E0;
+    padding: 0.75rem;
+    font-size: 1rem;
+}
+.stRadio [role="radiogroup"] {
+    background: #F8F8F8;
+    padding: 1rem;
+    border-radius: 12px;
+}
+h1 {text-align: center;}
+.css-gma2qf {display: flex; justify-content: center; font-size: 36px; font-weight: bold;}
+a:link {text-decoration: none;}
+a:hover {text-decoration: none;}
+.st-ba {font-family: Avenir;}
+.st-button {text-align: center;}
+</style>
+"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

+if 'progress' not in st.session_state:
+    st.session_state.progress = 1
+if 'video_frames' not in st.session_state:
+    st.session_state.video_frames = None
+if 'video_features' not in st.session_state:
+    st.session_state.video_features = None
+if 'fps' not in st.session_state:
+    st.session_state.fps = None
+if 'video_name' not in st.session_state:
+    st.session_state.video_name = 'videos/example.mp4'

+st.title("Which Frame? 🎞️🔍")
+st.markdown("""
+Search a video semantically. For example, which frame has "a person with sunglasses"?
+Search using text, images, or a mix of text + image. WhichFrame uses [CLIP](https://github.com/openai/CLIP) for zero-shot frame classification.
+""")

+if 'url' not in st.session_state:
+    st.session_state.url = ''

+url = st.text_input("Enter a YouTube URL (e.g., https://www.youtube.com/watch?v=zTvJJnoWIPk)", key="url_input")
+
+if st.button("Process Video"):
+    if not url:
+        st.error("Please enter a YouTube URL first")
+    else:
+        try:
+            cached_frames, cached_features, cached_fps, cached_frame_indices = load_cached_data(url)
+
+            if cached_frames is not None:
+                st.session_state.video_frames = cached_frames
+                st.session_state.video_features = cached_features
+                st.session_state.fps = cached_fps
+                st.session_state.frame_indices = cached_frame_indices
+                st.session_state.url = url
+                st.session_state.progress = 2
+                st.success("Loaded cached video data!")
+            else:
+                with st.spinner('Fetching video...'):
+                    video, video_url = fetch_video(url)
+                    st.session_state.url = url
+
+                progress_bar = st.progress(0)
+                status_text = st.empty()
+
+                # Extract frames
+                st.session_state.video_frames, st.session_state.fps, st.session_state.frame_indices = extract_frames(video_url, status_text, progress_bar)
+
+                # Encode frames
+                st.session_state.video_features = encode_frames(st.session_state.video_frames, status_text)
+
+                save_cached_data(url, st.session_state.video_frames, st.session_state.video_features, st.session_state.fps, st.session_state.frame_indices)
+                status_text.text('Finalizing...')
+                st.session_state.progress = 2
+                progress_bar.progress(100)
+                status_text.empty()
+                progress_bar.empty()
+                st.success("Video processed successfully!")
+
+        except Exception as e:
+            st.error(f"Error processing video: {str(e)}")
+
+if st.session_state.progress == 2:
+    search_type = st.radio("Search Method", ["Text Search", "Image Search", "Text + Image Search"], index=0)
+
+    if search_type == "Text Search": # Text Search
+        text_query = st.text_input("Type a search query (e.g., 'red car' or 'person with sunglasses')")
+        if st.button("Search"):
+            if not text_query:
+                st.error("Please enter a search query first")
+            else:
+                text_search(text_query, st.session_state.video_features, st.session_state.video_frames)
+    elif search_type == "Image Search": # Image Search
+        uploaded_file = st.file_uploader("Upload a query image", type=['png', 'jpg', 'jpeg'])
+        if uploaded_file is not None:
+            query_image = Image.open(uploaded_file).convert('RGB')
+            st.image(query_image, caption="Query Image", width=200)
+        if st.button("Search"):
+            if uploaded_file is None:
+                st.error("Please upload an image first")
+            else:
+                image_search(query_image, st.session_state.video_features, st.session_state.video_frames)
+    else: # Text + Image Search
+        text_query = st.text_input("Type a search query")
+        uploaded_file = st.file_uploader("Upload a query image", type=['png', 'jpg', 'jpeg'])
+        if uploaded_file is not None:
+            query_image = Image.open(uploaded_file).convert('RGB')
+            st.image(query_image, caption="Query Image", width=200)
+
+        if st.button("Search"):
+            if not text_query or uploaded_file is None:
+                st.error("Please provide both text query and image")
+            else:
+                text_and_image_search(text_query, query_image, st.session_state.video_features, st.session_state.video_frames)

+st.markdown("---")
+st.markdown(
+    "By [David Chuan-En Lin](https://chuanenlin.com/). "
+    "Play with the code at [https://github.com/chuanenlin/whichframe](https://github.com/chuanenlin/whichframe)."
+)
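At the core of the new text_search(), image_search(), and text_and_image_search() above is the same ranking step: L2-normalize the CLIP features, take the matrix product (a cosine similarity scaled by 100), and pick the best frames with topk. A self-contained sketch of that step, for illustration only and not part of the commit (the stand-in frames are hypothetical):

import torch
import clip as openai_clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = openai_clip.load("ViT-B/32", device=device)

# Stand-in frames; in the app these come from extract_frames().
frames = [Image.new("RGB", (640, 360), color=(i * 30, 0, 0)) for i in range(8)]

with torch.no_grad():
    frame_batch = torch.stack([preprocess(f) for f in frames]).to(device)
    frame_features = model.encode_image(frame_batch).float()
    text_features = model.encode_text(openai_clip.tokenize("a person with sunglasses").to(device)).float()

# Normalize so the dot product below is a cosine similarity.
frame_features /= frame_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

similarities = 100.0 * frame_features @ text_features.T  # shape (num_frames, 1)
values, best_idx = similarities.topk(3, dim=0)            # indices of the 3 best-matching frames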