Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from sklearn.metrics.pairwise import cosine_similarity | |
import os | |
from datetime import datetime | |
from datasets import load_dataset | |
def _init_session_state():
    """Seed ``st.session_state`` with this app's default values.

    Only keys that are not already present are written, so anything stored
    earlier in the session is left untouched when the script runs again.
    """
    # NOTE(review): of these keys, only 'current_page' and 'dataset_info' are
    # read elsewhere in this file; the others look unused.
    defaults = {
        'search_history': [],
        'search_columns': [],
        'dataset_loaded': False,
        'current_page': 0,
        'data_cache': None,
        'dataset_info': None,
    }
    for name, initial in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = initial


_init_session_state()

ROWS_PER_PAGE = 100  # Number of dataset rows fetched per page
@st.cache_resource
def get_model():
    """Return the shared sentence-embedding model.

    Wrapped in ``st.cache_resource`` so the SentenceTransformer is constructed
    once and reused, instead of being rebuilt every time Streamlit re-executes
    the script.

    Returns:
        The ``all-MiniLM-L6-v2`` SentenceTransformer instance.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
# Bounded so browsing many pages cannot grow the cache without limit.
@st.cache_data(max_entries=64, show_spinner=False)
def _fetch_dataset_page(dataset_id, token, page, rows_per_page):
    """Load rows ``[page * rows_per_page, (page + 1) * rows_per_page)`` of the train split.

    Exceptions propagate, so a failed fetch produces no value to cache and is
    retried on the next call.
    """
    start_idx = page * rows_per_page
    end_idx = start_idx + rows_per_page
    dataset = load_dataset(
        dataset_id,
        token=token,
        streaming=False,
        split=f'train[{start_idx}:{end_idx}]',
    )
    return pd.DataFrame(dataset)


def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the dataset's train split as a DataFrame (cached).

    Args:
        dataset_id: Hugging Face dataset repository id.
        token: Hugging Face access token.
        page: 0-based page index.
        rows_per_page: Number of rows per page.

    Returns:
        The page as a DataFrame. On any error, an error message is shown in the
        UI and an empty DataFrame is returned; that failure is not cached.
    """
    try:
        return _fetch_dataset_page(dataset_id, token, page, rows_per_page)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()
@st.cache_resource(show_spinner=False)
def _fetch_dataset_info(dataset_id, token):
    """Return the ``info`` object of the dataset's train split.

    Exceptions propagate, so a failed lookup is never stored in the cache.
    """
    dataset = load_dataset(
        dataset_id,
        token=token,
        streaming=True,
    )
    return dataset['train'].info


def get_dataset_info(dataset_id, token):
    """Load dataset information for the train split (cached).

    Args:
        dataset_id: Hugging Face dataset repository id.
        token: Hugging Face access token.

    Returns:
        The train split's ``info`` object, or ``None`` if it could not be
        loaded (an error message is shown in the UI in that case).
    """
    try:
        return _fetch_dataset_info(dataset_id, token)
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None
class FastDatasetSearcher:
    """Browse a Hugging Face dataset one page at a time and rank rows by a query.

    Attributes:
        dataset_id: Hugging Face dataset repository id.
        token: Token read from the ``DATASET_KEY`` environment variable.
        text_model: Embedding model returned by ``get_model()``.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Configure the searcher; halts the Streamlit run if no token is set.

        Args:
            dataset_id: Hugging Face dataset repository id to browse.
        """
        self.dataset_id = dataset_id
        # Check the token before loading the model so a misconfigured app
        # stops without paying for the model load.
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.text_model = get_model()
        # Fetch dataset metadata once per session. A failed fetch stores None,
        # which leaves the slot empty so a later run tries again.
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Return 0-based page ``page`` (``ROWS_PER_PAGE`` rows) as a DataFrame.

        An empty DataFrame is returned when the page cannot be loaded.
        """
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Score every row of ``df`` against ``query`` and sort by relevance.

        Each row's searchable columns (those whose first value is not a numpy
        array or bytes) are stringified and joined with spaces. The row score is
        ``0.5 * semantic + 0.5 * keyword``:

        - *semantic* is the cosine similarity of the query and row embeddings.
        - *keyword* is the number of case-insensitive occurrences of ``query``
          divided by the row's whitespace-separated word count.

        Rows with no text score 0.0.

        Args:
            query: Search string.
            df: Page of data to search.

        Returns:
            A copy of ``df`` with a float ``score`` column, sorted by score
            descending (ties keep their page order). ``df`` itself is returned
            when it is empty, or when scoring fails (the error is shown in the UI).
        """
        if df.empty:
            return df

        # Skip array/binary columns. Only the first row is inspected, so each
        # column is assumed to hold one kind of value throughout.
        searchable_cols = [
            col for col in df.columns
            if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
        ]

        # One text blob per row: None cells are skipped, everything else
        # (including lists and dicts) is stringified.
        texts = [
            ' '.join(str(row[col]) for col in searchable_cols if row[col] is not None)
            for _, row in df.iterrows()
        ]

        scores = np.zeros(len(texts))
        scored_idx = [i for i, text in enumerate(texts) if text.strip()]
        try:
            if scored_idx:
                query_emb = np.asarray(
                    self.text_model.encode([query], show_progress_bar=False)[0],
                    dtype=float,
                )
                # Encode all non-empty rows in one batch rather than one model
                # call per row.
                row_embs = np.asarray(
                    self.text_model.encode(
                        [texts[i] for i in scored_idx], show_progress_bar=False
                    ),
                    dtype=float,
                )
                norms = np.linalg.norm(row_embs, axis=1) * np.linalg.norm(query_emb)
                # Zero-length vectors get similarity 0 instead of NaN.
                semantic = np.divide(
                    row_embs @ query_emb,
                    norms,
                    out=np.zeros(len(scored_idx)),
                    where=norms != 0,
                )
                query_lower = query.lower()
                for i, sem in zip(scored_idx, semantic):
                    text = texts[i]
                    keyword = text.lower().count(query_lower) / max(len(text.split()), 1)
                    scores[i] = 0.5 * sem + 0.5 * keyword
        except Exception as e:
            # Report model/encoding failures instead of crashing the page.
            st.error(f"Search failed: {str(e)}")
            return df

        results_df = df.copy()
        results_df['score'] = scores
        return results_df.sort_values('score', ascending=False, kind='stable')
def render_result(result):
    """Render one search hit inside the current Streamlit container.

    Args:
        result: One result row, a pandas Series as yielded by
            ``DataFrame.iterrows`` (a plain dict also works). It may carry a
            ``score`` entry; when absent, a score of 0 is shown.
    """
    # Series.pop() accepts no default argument, so read the score with .get()
    # and skip the 'score' key when listing fields below. This also avoids
    # mutating the caller's object.
    score = result.get('score', 0)

    # Display video if available
    if 'youtube_id' in result:
        start = result.get('start_time', 0)
        try:
            start = int(float(start))
        except (TypeError, ValueError):
            start = 0  # missing / NaN / non-numeric offsets play from the start
        st.video(
            f"https://youtube.com/watch?v={result['youtube_id']}&t={start}",
            start_time=start,
        )

    # Display other scalar fields next to the relevance score.
    cols = st.columns([2, 1])
    with cols[0]:
        for key, value in result.items():
            if key == 'score':
                continue
            # np.number covers numpy scalars (e.g. int64), which are not
            # subclasses of the built-in int.
            if isinstance(value, (str, int, float, np.number)):
                st.write(f"**{key}:** {value}")
    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")
def main():
    """Build the page: dataset summary, page picker, in-page search, raw data, navigation.

    The selected page is mirrored into ``st.session_state['current_page']`` so
    the Previous/Next buttons always step from the page actually displayed.
    """
    st.title("🎥 Fast Video Dataset Search")

    # Initialize search class (stops the app if no token is configured).
    searcher = FastDatasetSearcher()

    # Work out the last valid 0-based page index. Split metadata may be
    # missing (no splits, or no train entry), so each step is guarded rather
    # than indexed blindly; an unknown size leaves the page picker unbounded.
    last_page = None
    info = st.session_state['dataset_info']
    if info:
        st.sidebar.write("### Dataset Info")
        splits = getattr(info, 'splits', None) or {}
        train_split = splits.get('train')
        num_examples = getattr(train_split, 'num_examples', None)
        if num_examples is not None:
            st.sidebar.write(f"Total examples: {num_examples:,}")
            # (n - 1) // size is the index of the final page, so a size that is
            # an exact multiple of ROWS_PER_PAGE does not add an empty page.
            last_page = max((num_examples - 1) // ROWS_PER_PAGE, 0)

    stored_page = st.session_state['current_page']
    if last_page is not None:
        stored_page = min(stored_page, last_page)
    current_page = st.number_input(
        "Page", min_value=0, max_value=last_page, value=stored_page
    )
    # Keep session state in sync with manual edits so navigation buttons step
    # from the page on screen rather than from a stale stored value.
    st.session_state['current_page'] = current_page

    # Load current page
    with st.spinner(f"Loading page {current_page}..."):
        df = searcher.load_page(current_page)
    if df.empty:
        st.warning("No data available for this page.")
        return

    # Search interface
    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input("Search in current page:",
                              help="Searches within currently loaded data")
    with col2:
        max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)
        # Display results
        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
            with st.expander(f"Result {i}", expanded=i == 1):
                render_result(result)

    # Show raw data
    with st.expander("Show Raw Data"):
        st.dataframe(df)

    # Navigation buttons
    cols = st.columns(2)
    with cols[0]:
        if st.button("⬅️ Previous Page") and current_page > 0:
            st.session_state['current_page'] = current_page - 1
            st.rerun()
    with cols[1]:
        # Never step past the last page: number_input rejects a value above
        # its max_value.
        at_last_page = last_page is not None and current_page >= last_page
        if st.button("Next Page ➡️") and not at_last_page:
            st.session_state['current_page'] = current_page + 1
            st.rerun()
if __name__ == "__main__":
    # Build the page only when this file is executed as the main script
    # (e.g. via `streamlit run`), not when it is imported.
    main()