File size: 8,792 Bytes
cde9023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb502db
cde9023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb502db
cde9023
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import streamlit as st
from typing import Union, Optional
import pandas as pd
import plotly.express as px
import faiss
import numpy as np
from transformers import AutoModel, AutoProcessor
import torch
from datetime import datetime


class SelectedIndex:
    def __init__(self, idx) -> None:
        self.idx = int(idx)
        self.timestamp = datetime.now()

    def __eq__(self, value: Union["SelectedIndex", int]) -> bool:
        if isinstance(value, SelectedIndex):
            return self.idx == value.idx
        return self.idx == int(value)

    def __ne__(self, value: Union["SelectedIndex", int]) -> bool:
        return not self.__eq__(value)

    def is_valid(self) -> bool:
        return self.idx >= 0


@st.cache_data
def load_data(path: str):
    df = pd.read_parquet(path)
    embs = np.stack(df["embedding"].tolist()).astype("float32")
    faiss.normalize_L2(embs)
    D = embs.shape[1]
    index = faiss.IndexFlatIP(D)
    index.add(embs)
    return df, index, embs


def load_model() -> tuple[AutoModel, AutoProcessor]:
    if "preprocessor" not in st.session_state:
        st.session_state.preprocessor = AutoProcessor.from_pretrained(
            "nvidia/Cosmos-Embed1-224p", trust_remote_code=True, token=True,
        )
    if "model" not in st.session_state:
        model = AutoModel.from_pretrained(
            "nvidia/Cosmos-Embed1-224p", trust_remote_code=True, token=True,
        )
        model.eval()
        st.session_state.model = model
    return st.session_state.model, st.session_state.preprocessor


def preview_video(df, idx, slot, height=420, margin_top=30, autoplay=True, title=None) -> None:
    if title:
        slot.markdown(f"### {title}")
    start = int(df.loc[idx, "span_start"])
    end = int(df.loc[idx, "span_end"])
    youtube_id = df.loc[idx, "youtube_id"]
    url = f"https://www.youtube.com/embed/{youtube_id}?start={start}&end={end}"
    sep = "?" if "?" not in url else "&"
    params = f"{sep}mute=1&rel=0"
    if autoplay:
        params += "&autoplay=1"
    slot.markdown(
        f'''
        <div style="margin-top:{margin_top}px">
          <iframe width="100%" height="{height}"
                  src="{url}{params}"
                  frameborder="0"
                  allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture"
                  allow="autoplay; fullscreen" allowfullscreen>
          </iframe>
        </div>
        ''',
        unsafe_allow_html=True
    )


def get_nearest_ids(vec, k=5, ignore_self=True) -> list:
    q = vec.reshape(1, -1).astype("float32")
    faiss.normalize_L2(q)
    topk = k + 1 if ignore_self else k
    _, I = index.search(q, topk)
    ids = I[0]
    return ids[1:].tolist() if ignore_self else ids.tolist()


def get_most_recent_selection() -> tuple[Optional[int], str]:
    if st.session_state.text_selection.is_valid() and st.session_state.click_selection.is_valid():
        if st.session_state.text_selection.timestamp > st.session_state.click_selection.timestamp:
            return st.session_state.text_selection.idx, "text"
        return st.session_state.click_selection.idx, "click"
    if st.session_state.text_selection.is_valid():
        return st.session_state.text_selection.idx, "text"
    if st.session_state.click_selection.is_valid():
        return st.session_state.click_selection.idx, "click"
    return None, ""


def reset_state() -> None:
    if "text_selection" not in st.session_state:
        st.session_state.text_selection = SelectedIndex(-1)
    if "click_selection" not in st.session_state:
        st.session_state.click_selection = SelectedIndex(-1)
    if "text_query" not in st.session_state:
        st.session_state.text_query = ""

# ─── App setup ────────────────────────────────────────────────────

st.set_page_config(layout="wide")
reset_state()
model, preprocessor = load_model()
file_map = {"kinetics700 (val)": "src/kinetics700_val.parquet", "opendv (val)": "src/opendrive_val.parquet"}
st.title("πŸ” Search with Cosmos-Embed1")

col1, col2 = st.columns([2,2])
with col1:
    dataset = st.selectbox("Select dataset", list(file_map.keys()), on_change=reset_state)
df, index, embs = load_data(file_map[dataset])

# initialize session state
if "text_selection" not in st.session_state:
    st.session_state.text_selection = SelectedIndex(-1)
if "click_selection" not in st.session_state:
    st.session_state.click_selection = SelectedIndex(-1)
if "text_query" not in st.session_state:
    st.session_state.text_query = ""

# ─── Layout ────────────────────────────────────────────────────────

# LEFT: scatter
with col1:
    fig = px.scatter(
        df, x="x", y="y",
        hover_name="tar_key", hover_data=["cluster_id"],
        color="cluster_id", color_continuous_scale="Turbo",
        title="t-SNE projection (click to select)"
    )
    fig.update_layout(
        dragmode="zoom",
        margin=dict(l=5, r=5, t=40, b=5),
        xaxis_title=None, yaxis_title=None,
        coloraxis_colorbar=dict(title="")
    )
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False,
                     showline=True, linecolor="black", mirror=True)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,
                     showline=True, linecolor="black", mirror=True)
    fig.update_layout(annotations=[dict(
        text="k-means cluster", xref="paper", yref="paper",
        x=1.02, y=0.5, textangle=90, showarrow=False
    )])

    most_recent_idx, most_recent_method = get_most_recent_selection()
    if most_recent_idx is not None and most_recent_method == "text":
        x0, y0 = df.iloc[most_recent_idx][["x", "y"]]
        span = 6.0
        fig.update_layout(
            xaxis_range=[x0 - span, x0 + span],
            yaxis_range=[y0 - span, y0 + span],
            transition={"duration": 1},
        )

    click_event = st.plotly_chart(
        fig, use_container_width=True,
        on_select="rerun", selection_mode="points"
    )

# RIGHT: text input & preview
with col2:
    if click_event and click_event.get("selection", {}).get("point_indices"):
        curr_click = click_event["selection"]["point_indices"][0]
        if curr_click != st.session_state.click_selection:
            # new click so update the previous selection and wipe any text query
            st.session_state.click_selection = SelectedIndex(curr_click)
            st.session_state.text_query = ""

    # text input (will pick up cleared or existing text)
    text_query = st.text_input(
        "Search via text",
        key="text_query",
        help="Type a query and press Enter"
    )

    # if user typed text (and pressed Enter), override selection
    if text_query:
        with torch.no_grad():
            model_input = preprocessor(text=[text_query])
            emb_out = model.get_text_embeddings(**model_input).text_proj.cpu().numpy()
        idx_text, = get_nearest_ids(emb_out, k=1, ignore_self=False)
        if st.session_state.text_selection != idx_text:
            # new text so update the previous selection and wipe any text query
            st.session_state.text_selection = SelectedIndex(idx_text)
            st.rerun()

    # main preview
    preview_slot = st.empty()
    most_recent, most_recent_modality = get_most_recent_selection()
    if most_recent is not None:
        preview_video(df, most_recent, preview_slot)
    else:
        preview_slot.write("⏳ Waiting for selection…")

# BOTTOM: 5 nearest neighbors
st.markdown("### 🎬 5 Closest Videos")
if most_recent is not None:
    ignore_self = most_recent_modality == "click"
    nn_ids = get_nearest_ids(embs[most_recent], k=5, ignore_self=ignore_self)
    cols = st.columns(5)
    for c, nid in zip(cols, nn_ids):
        preview_video(df, nid, c, height=180, margin_top=5, autoplay=False)
else:
    st.write("Use a click or a text query above to list neighbors.")