# Spaces: Running on CPU Upgrade
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
from typing import Optional, Union

import faiss
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from transformers import AutoModel, AutoProcessor
class SelectedIndex:
    """A selected row index stamped with the moment it was created.

    A negative ``idx`` means "nothing selected" (see ``is_valid``).
    Instances compare equal to another ``SelectedIndex`` holding the same
    ``idx`` or to any value that ``int()`` converts to the same number.
    The timestamp never takes part in comparisons.
    """

    # Equality is value-based and instances are mutable, so they stay
    # unhashable (this is what defining __eq__ alone already implied).
    __hash__ = None

    def __init__(self, idx) -> None:
        """Store ``idx`` as an ``int`` and record the creation time."""
        self.idx = int(idx)
        # Naive local time as returned by datetime.now().
        self.timestamp = datetime.now()

    def __repr__(self) -> str:
        return f"{type(self).__name__}(idx={self.idx}, timestamp={self.timestamp!r})"

    def __eq__(self, value: object) -> bool:
        """Compare by index against a ``SelectedIndex`` or an int-like value.

        Returns ``NotImplemented`` for operands that ``int()`` cannot convert
        (e.g. ``None``), so such comparisons evaluate to ``False`` instead of
        raising.
        """
        if isinstance(value, SelectedIndex):
            return self.idx == value.idx
        try:
            return self.idx == int(value)
        except (TypeError, ValueError, OverflowError):
            return NotImplemented

    def __ne__(self, value: object) -> bool:
        """Inverse of ``__eq__``, passing ``NotImplemented`` through untouched."""
        outcome = self.__eq__(value)
        if outcome is NotImplemented:
            # Negating NotImplemented is an error; let Python fall back instead.
            return outcome
        return not outcome

    def is_valid(self) -> bool:
        """Return ``True`` when the index is non-negative."""
        return self.idx >= 0
@st.cache_resource
def load_data(path: str):
    """Load a parquet dataset and build an inner-product FAISS index over it.

    Streamlit re-executes the script on every interaction, so the result is
    memoised per ``path`` with ``st.cache_resource``; the parquet file is read
    and the index built only once per dataset.

    Args:
        path: Location of a parquet file that has an ``embedding`` column.

    Returns:
        A ``(df, index, embs)`` tuple: the loaded DataFrame, a
        ``faiss.IndexFlatIP`` containing every embedding, and the float32
        embedding matrix after in-place L2 normalisation. The objects are
        shared through the cache and should be treated as read-only.
    """
    df = pd.read_parquet(path)
    embs = np.stack(df["embedding"].tolist()).astype("float32")
    # Unit-length rows make the inner product equal to cosine similarity.
    faiss.normalize_L2(embs)
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embs)
    return df, index, embs
def load_model() -> tuple[AutoModel, AutoProcessor]:
    """Return the embedding model and its preprocessor, loading them on demand.

    Both objects are kept in ``st.session_state`` under the ``model`` and
    ``preprocessor`` keys, so ``from_pretrained`` runs only when the
    corresponding key is absent. The model is switched to eval mode before it
    is stored.

    Returns:
        A ``(model, preprocessor)`` tuple read back from the session state.
    """
    state = st.session_state
    repo_id = "nvidia/Cosmos-Embed1-224p"

    if "preprocessor" not in state:
        state.preprocessor = AutoProcessor.from_pretrained(
            repo_id, trust_remote_code=True, token=True,
        )

    if "model" not in state:
        network = AutoModel.from_pretrained(
            repo_id, trust_remote_code=True, token=True,
        )
        network.eval()
        state.model = network

    return state.model, state.preprocessor
def preview_video(df, idx, slot, height=420, margin_top=30, autoplay=True, title=None) -> None:
    """Render an embedded YouTube clip for one dataset row into ``slot``.

    Args:
        df: DataFrame with ``span_start``, ``span_end`` and ``youtube_id``
            columns.
        idx: Positional row number (as produced by FAISS searches and Plotly
            point selections), not an index label.
        slot: Object exposing ``markdown(body, unsafe_allow_html=...)``, such
            as a Streamlit column or ``st.empty()`` placeholder.
        height: Height of the iframe in pixels.
        margin_top: Top margin of the wrapping ``<div>`` in pixels.
        autoplay: Append ``autoplay=1`` to the embed URL when true.
        title: Optional heading written to ``slot`` before the player.
    """
    if title:
        slot.markdown(f"### {title}")

    # Positional lookup keeps this in step with the embedding matrix and the
    # FAISS index, even if the DataFrame index is not a plain 0..n-1 range.
    row = df.iloc[idx]
    start = int(row["span_start"])
    end = int(row["span_end"])
    youtube_id = row["youtube_id"]

    query = f"start={start}&end={end}&mute=1&rel=0"
    if autoplay:
        query += "&autoplay=1"
    src = f"https://www.youtube.com/embed/{youtube_id}?{query}"

    # An HTML attribute may appear only once per element, so every permission
    # lives in a single ``allow`` list (a second ``allow`` would be ignored).
    markup = "\n".join(
        (
            f'<div style="margin-top:{margin_top}px">',
            f'<iframe width="100%" height="{height}"',
            f'src="{src}"',
            'frameborder="0"',
            'allow="accelerometer; autoplay; encrypted-media; fullscreen; gyroscope; picture-in-picture"',
            "allowfullscreen>",
            "</iframe>",
            "</div>",
        )
    )
    slot.markdown(markup, unsafe_allow_html=True)
def get_nearest_ids(vec, k=5, ignore_self=True) -> list:
    """Return the row ids of the vectors closest to ``vec`` in the global index.

    Args:
        vec: Query embedding; it is flattened to a single float32 row and
            L2-normalised (on a copy) before searching.
        k: Number of neighbours wanted.
        ignore_self: When true, one extra hit is requested and the top hit is
            dropped, on the assumption that the query is itself a member of
            the index and therefore ranks first.

    Returns:
        Up to ``k`` integer ids ordered from most to least similar. Fewer are
        returned when the index holds fewer vectors than requested.
    """
    query = vec.reshape(1, -1).astype("float32")
    faiss.normalize_L2(query)
    topk = k + 1 if ignore_self else k
    _, labels = index.search(query, topk)
    hits = labels[0]
    if ignore_self:
        hits = hits[1:]
    # FAISS pads missing results with -1; those are placeholders, not row ids.
    return [int(label) for label in hits if label >= 0]
def get_most_recent_selection() -> tuple[Optional[int], str]:
    """Pick whichever of the text and click selections was made last.

    Reads ``text_selection`` and ``click_selection`` from
    ``st.session_state``. Only selections whose ``is_valid()`` is true are
    considered; when both are valid the text selection wins only if its
    timestamp is strictly newer than the click's.

    Returns:
        ``(idx, "text")`` or ``(idx, "click")`` for the winning selection, or
        ``(None, "")`` when neither selection is valid.
    """
    text_sel = st.session_state.text_selection
    click_sel = st.session_state.click_selection

    # Click is listed first so that max() keeps it when timestamps tie.
    candidates = [
        (selection, label)
        for selection, label in ((click_sel, "click"), (text_sel, "text"))
        if selection.is_valid()
    ]
    if not candidates:
        return None, ""

    winner, label = max(candidates, key=lambda pair: pair[0].timestamp)
    return winner.idx, label
def reset_state() -> None:
    """Make sure the selection and query keys exist in ``st.session_state``.

    Missing keys are created with their "nothing selected" defaults:
    ``text_selection`` and ``click_selection`` become ``SelectedIndex(-1)``
    and ``text_query`` becomes an empty string.

    NOTE: despite the name, keys that are already present are left untouched.
    """
    defaults = (
        ("text_selection", lambda: SelectedIndex(-1)),
        ("click_selection", lambda: SelectedIndex(-1)),
        ("text_query", lambda: ""),
    )
    for key, make_default in defaults:
        if key not in st.session_state:
            st.session_state[key] = make_default()
# ─── App setup ────────────────────────────────────────────────────
st.set_page_config(layout="wide")
# Creates the selection/query keys on the first run; later runs keep them.
reset_state()
model, preprocessor = load_model()
file_map = {
    "kinetics700 (val)": "src/kinetics700_val.parquet",
    "opendv (val)": "src/opendrive_val.parquet",
}


def _clear_selections() -> None:
    """Forget selections and the query that refer to the previous dataset."""
    # Row indices are only meaningful for the dataset they were picked in, so
    # they must be discarded (not merely initialised) when the dataset changes.
    st.session_state.text_selection = SelectedIndex(-1)
    st.session_state.click_selection = SelectedIndex(-1)
    st.session_state.text_query = ""


st.title("🔍 Search with Cosmos-Embed1")
col1, col2 = st.columns([2, 2])
with col1:
    dataset = st.selectbox(
        "Select dataset", list(file_map.keys()), on_change=_clear_selections
    )
df, index, embs = load_data(file_map[dataset])

# ─── Layout ────────────────────────────────────────────────────────
# LEFT: scatter
with col1:
    fig = px.scatter(
        df, x="x", y="y",
        hover_name="tar_key", hover_data=["cluster_id"],
        color="cluster_id", color_continuous_scale="Turbo",
        title="t-SNE projection (click to select)",
    )
    fig.update_layout(
        dragmode="zoom",
        margin=dict(l=5, r=5, t=40, b=5),
        xaxis_title=None, yaxis_title=None,
        coloraxis_colorbar=dict(title=""),
    )
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False,
                     showline=True, linecolor="black", mirror=True)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,
                     showline=True, linecolor="black", mirror=True)
    fig.update_layout(annotations=[dict(
        text="k-means cluster", xref="paper", yref="paper",
        x=1.02, y=0.5, textangle=90, showarrow=False,
    )])

    # After a text search, zoom the scatter onto the matched point.
    most_recent_idx, most_recent_method = get_most_recent_selection()
    if most_recent_idx is not None and most_recent_method == "text":
        x0, y0 = df.iloc[most_recent_idx][["x", "y"]]
        span = 6.0
        fig.update_layout(
            xaxis_range=[x0 - span, x0 + span],
            yaxis_range=[y0 - span, y0 + span],
            transition={"duration": 1},
        )

    click_event = st.plotly_chart(
        fig, use_container_width=True,
        on_select="rerun", selection_mode="points",
    )

# RIGHT: text input & preview
with col2:
    if click_event and click_event.get("selection", {}).get("point_indices"):
        curr_click = click_event["selection"]["point_indices"][0]
        if curr_click != st.session_state.click_selection:
            # New click: record it and wipe any text query. The query key may
            # be assigned here because its widget is created further down.
            st.session_state.click_selection = SelectedIndex(curr_click)
            st.session_state.text_query = ""

    # text input (will pick up cleared or existing text)
    text_query = st.text_input(
        "Search via text",
        key="text_query",
        help="Type a query and press Enter",
    )

    # if user typed text (and pressed Enter), override selection
    if text_query:
        with torch.no_grad():
            model_input = preprocessor(text=[text_query])
            emb_out = model.get_text_embeddings(**model_input).text_proj.cpu().numpy()
        idx_text, = get_nearest_ids(emb_out, k=1, ignore_self=False)
        if st.session_state.text_selection != idx_text:
            # New best match: record it and rerun so the scatter, which was
            # already drawn above, can zoom onto the new point.
            st.session_state.text_selection = SelectedIndex(idx_text)
            st.rerun()

    # main preview
    preview_slot = st.empty()
    most_recent, most_recent_modality = get_most_recent_selection()
    if most_recent is not None:
        preview_video(df, most_recent, preview_slot)
    else:
        preview_slot.write("⏳ Waiting for selection…")

# BOTTOM: 5 nearest neighbors
st.markdown("### 🎬 5 Closest Videos")
if most_recent is not None:
    # The top hit is dropped only for click selections; for text selections
    # the first of the five results is kept.
    ignore_self = most_recent_modality == "click"
    nn_ids = get_nearest_ids(embs[most_recent], k=5, ignore_self=ignore_self)
    cols = st.columns(5)
    for c, nid in zip(cols, nn_ids):
        preview_video(df, nid, c, height=180, margin_top=5, autoplay=False)
else:
    st.write("Use a click or a text query above to list neighbors.")