fferroni committed on
Commit
cde9023
·
1 Parent(s): eb502db

initial commit

Browse files
.streamlit/config.toml ADDED
File without changes
Dockerfile CHANGED
@@ -1,3 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  FROM python:3.9-slim
2
 
3
  WORKDIR /app
@@ -9,6 +24,15 @@ RUN apt-get update && apt-get install -y \
9
  git \
10
  && rm -rf /var/lib/apt/lists/*
11
 
 
 
 
 
 
 
 
 
 
12
  COPY requirements.txt ./
13
  COPY src/ ./src/
14
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
  FROM python:3.9-slim
17
 
18
  WORKDIR /app
 
24
  git \
25
  && rm -rf /var/lib/apt/lists/*
26
 
27
+ RUN mkdir -p /tmp/huggingface \
28
+ && chown -R 1000:1000 /tmp/huggingface
29
+
30
+ ENV HF_HOME=/tmp/huggingface \
31
+ HF_HUB_CACHE=/tmp/huggingface \
32
+ TRANSFORMERS_CACHE=/tmp/huggingface \
33
+ XDG_CACHE_HOME=/tmp \
34
+ TMPDIR=/tmp
35
+
36
  COPY requirements.txt ./
37
  COPY src/ ./src/
38
 
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
- title: Streamlit Template Space
3
  emoji: 🚀
4
  colorFrom: red
5
  colorTo: red
6
  sdk: docker
7
  app_port: 8501
8
  tags:
9
- - streamlit
10
  pinned: false
11
- short_description: Streamlit template space
12
  ---
13
 
14
  # Welcome to Streamlit!
 
1
  ---
2
+ title: Cosmos Embed1
3
  emoji: 🚀
4
  colorFrom: red
5
  colorTo: red
6
  sdk: docker
7
  app_port: 8501
8
  tags:
9
+ - streamlit
10
  pinned: false
11
+ short_description: Cosmos-Embed1 demo app
12
  ---
13
 
14
  # Welcome to Streamlit!
requirements.txt CHANGED
@@ -1,3 +1,23 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ altair==5.5.0
17
+ streamlit==1.45.0
18
+ pandas==2.2.3
19
+ plotly==6.0.1
20
+ faiss-cpu==1.11.0
21
+ transformers==4.44.2
22
+ einops==0.8.1
23
+ torchvision==0.21.0
src/kinetics700_val.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c7495ac043e0565b636b2ac964bf630fa0b0ed6c7569cbd6c92f156ea5899bb
3
+ size 15595889
src/opendrive_val.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d0610bbbba386744efed46272c39d140ba902e1ff47acddb7ec54c44ff7d444
3
+ size 10482753
src/streamlit_app.py CHANGED
@@ -1,40 +1,229 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
  import streamlit as st
17
+ from typing import Union, Optional
18
+ import pandas as pd
19
+ import plotly.express as px
20
+ import faiss
21
+ import numpy as np
22
+ from transformers import AutoModel, AutoProcessor
23
+ import torch
24
+ from datetime import datetime
25
+
26
+
27
class SelectedIndex:
    """A row index tagged with the wall-clock time at which it was created.

    Equality looks only at ``idx`` (the timestamp is ignored), and an instance
    may be compared directly against anything ``int()`` can convert.  ``!=`` is
    supplied by Python's default ``__ne__``, which inverts ``__eq__``.

    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    def __init__(self, idx) -> None:
        self.idx = int(idx)
        # Creation time; callers compare timestamps to find the newer selection.
        self.timestamp = datetime.now()

    def __repr__(self) -> str:
        return f"{type(self).__name__}(idx={self.idx}, timestamp={self.timestamp!r})"

    def __eq__(self, value: Union["SelectedIndex", int]) -> bool:
        """Return True when ``value`` designates the same index.

        Values that cannot be converted with ``int()`` yield ``NotImplemented``
        so that ``==`` evaluates to False instead of raising.
        """
        if isinstance(value, SelectedIndex):
            return self.idx == value.idx
        try:
            return self.idx == int(value)
        except (TypeError, ValueError, OverflowError):
            return NotImplemented

    def is_valid(self) -> bool:
        """Return True when the stored index is non-negative."""
        return self.idx >= 0
42
+
43
+
44
@st.cache_data
def load_data(path: str):
    """Load a parquet table and build an inner-product FAISS index over it.

    Args:
        path: Location of a parquet file that has an ``embedding`` column.

    Returns:
        A ``(df, index, embs)`` tuple: the DataFrame as read, a
        ``faiss.IndexFlatIP`` holding every embedding, and the float32
        embedding matrix after L2 normalisation.
    """
    # NOTE(review): st.cache_data stores a serialized copy of the return value;
    # confirm whether st.cache_resource would suit the FAISS index better.
    frame = pd.read_parquet(path)
    vectors = np.stack(frame["embedding"].tolist()).astype("float32")
    # Normalised in place, so inner-product search ranks by cosine similarity.
    faiss.normalize_L2(vectors)
    searcher = faiss.IndexFlatIP(vectors.shape[1])
    searcher.add(vectors)
    return frame, searcher, vectors
53
+
54
+
55
@st.cache_resource
def _load_shared_model() -> tuple[AutoModel, AutoProcessor]:
    """Load the Cosmos-Embed1 model and processor once per server process."""
    model_id = "nvidia/Cosmos-Embed1-224p"
    preprocessor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, token=True,
    )
    model = AutoModel.from_pretrained(
        model_id, trust_remote_code=True, token=True,
    )
    model.eval()
    return model, preprocessor


def load_model() -> tuple[AutoModel, AutoProcessor]:
    """Return the ``(model, preprocessor)`` pair used for text embedding.

    The weights are loaded through ``st.cache_resource`` so that all browser
    sessions share one copy, rather than each session loading its own into
    ``st.session_state``.  The session-state entries ``model`` and
    ``preprocessor`` are still populated (with references to the shared
    objects) when missing, and the values returned are read from there.
    """
    if "preprocessor" not in st.session_state or "model" not in st.session_state:
        shared_model, shared_preprocessor = _load_shared_model()
        if "preprocessor" not in st.session_state:
            st.session_state.preprocessor = shared_preprocessor
        if "model" not in st.session_state:
            st.session_state.model = shared_model
    return st.session_state.model, st.session_state.preprocessor
67
+
68
+
69
def preview_video(df, idx, slot, height=420, margin_top=30, autoplay=True, title=None) -> None:
    """Render an embedded, muted YouTube player for one row of ``df`` into ``slot``.

    Args:
        df: Table providing ``span_start``, ``span_end`` and ``youtube_id``.
        idx: Row label, looked up with ``df.loc``.
        slot: Any object exposing ``markdown(body, unsafe_allow_html=...)``.
        height: Height of the iframe in pixels.
        margin_top: Top margin of the wrapping ``<div>`` in pixels.
        autoplay: Append ``autoplay=1`` to the embed URL when True.
        title: Optional heading written to ``slot`` before the player.
    """
    if title:
        slot.markdown(f"### {title}")

    start = int(df.loc[idx, "span_start"])
    end = int(df.loc[idx, "span_end"])
    youtube_id = df.loc[idx, "youtube_id"]

    # The base URL always carries a query string, so extra parameters are
    # appended with "&".
    src = f"https://www.youtube.com/embed/{youtube_id}?start={start}&end={end}&mute=1&rel=0"
    if autoplay:
        src += "&autoplay=1"

    # A single ``allow`` attribute: HTML ignores a repeated attribute, so the
    # two original lists are merged into one.
    permissions = (
        "accelerometer; autoplay; encrypted-media; gyroscope; "
        "picture-in-picture; fullscreen"
    )
    html = "\n".join(
        (
            f'<div style="margin-top:{margin_top}px">',
            f'<iframe width="100%" height="{height}"',
            f'src="{src}"',
            'frameborder="0"',
            f'allow="{permissions}"',
            "allowfullscreen>",
            "</iframe>",
            "</div>",
        )
    )
    slot.markdown(html, unsafe_allow_html=True)
93
+
94
+
95
def get_nearest_ids(vec, k=5, ignore_self=True) -> list:
    """Return the ids of the ``k`` nearest entries of the module-level ``index``.

    Args:
        vec: Query vector; it is reshaped to one row and cast to float32.
        k: Number of ids to return.
        ignore_self: When True, ``k + 1`` hits are requested and the first one
            is dropped (presumably the query's own entry; verify for
            duplicate embeddings).

    Returns:
        A plain list of ids in the order returned by the search.
    """
    query = vec.reshape(1, -1).astype("float32")
    # In-place normalisation of the float32 copy made above.
    faiss.normalize_L2(query)
    if ignore_self:
        _, hits = index.search(query, k + 1)
        return hits[0][1:].tolist()
    _, hits = index.search(query, k)
    return hits[0].tolist()
102
+
103
+
104
def get_most_recent_selection() -> tuple[Optional[int], str]:
    """Report which valid selection in the session state is the newest.

    Returns:
        ``(idx, "text")`` or ``(idx, "click")`` for the winning selection, or
        ``(None, "")`` when neither selection is valid.  If both are valid the
        text selection wins only when its timestamp is strictly later.
    """
    text_sel = st.session_state.text_selection
    click_sel = st.session_state.click_selection

    text_ok = text_sel.is_valid()
    click_ok = click_sel.is_valid()

    if text_ok and (not click_ok or text_sel.timestamp > click_sel.timestamp):
        return text_sel.idx, "text"
    if click_ok:
        return click_sel.idx, "click"
    return None, ""
114
+
115
+
116
def reset_state() -> None:
    """Create the session-state entries the app relies on, if absent.

    ``text_selection`` and ``click_selection`` default to ``SelectedIndex(-1)``
    and ``text_query`` to an empty string.  Keys that already exist are left
    untouched.
    """
    for selection_key in ("text_selection", "click_selection"):
        if selection_key not in st.session_state:
            st.session_state[selection_key] = SelectedIndex(-1)
    if "text_query" not in st.session_state:
        st.session_state["text_query"] = ""
123
+
124
# ─── App setup ────────────────────────────────────────────────────


def _clear_selections() -> None:
    """Drop the current click/text selection and the typed query.

    Used when the dataset changes: ``reset_state`` only fills in missing keys,
    so an index chosen in the previous dataset would otherwise be carried over
    to the new one.
    """
    st.session_state.text_selection = SelectedIndex(-1)
    st.session_state.click_selection = SelectedIndex(-1)
    st.session_state.text_query = ""


st.set_page_config(layout="wide")
# Ensures text_selection, click_selection and text_query exist on every run.
reset_state()
model, preprocessor = load_model()
file_map = {
    "kinetics700 (val)": "src/kinetics700_val.parquet",
    "opendv (val)": "src/opendrive_val.parquet",
}
st.title("🔍 Search with Cosmos-Embed1")

col1, col2 = st.columns([2, 2])
with col1:
    dataset = st.selectbox(
        "Select dataset", list(file_map.keys()), on_change=_clear_selections
    )
df, index, embs = load_data(file_map[dataset])
144
+
145
# ─── Layout ────────────────────────────────────────────────────────

# LEFT: scatter
with col1:
    fig = px.scatter(
        df,
        x="x",
        y="y",
        hover_name="tar_key",
        hover_data=["cluster_id"],
        color="cluster_id",
        color_continuous_scale="Turbo",
        title="t-SNE projection (click to select)",
    )
    fig.update_layout(
        dragmode="zoom",
        margin={"l": 5, "r": 5, "t": 40, "b": 5},
        xaxis_title=None,
        yaxis_title=None,
        coloraxis_colorbar={"title": ""},
    )

    # Both axes get the same look: no ticks, grid or zero line, only a
    # mirrored black frame.
    _axis_style = {
        "showticklabels": False,
        "showgrid": False,
        "zeroline": False,
        "showline": True,
        "linecolor": "black",
        "mirror": True,
    }
    fig.update_xaxes(**_axis_style)
    fig.update_yaxes(**_axis_style)

    # Vertical caption placed just right of the plot area (paper coordinates).
    fig.update_layout(
        annotations=[
            {
                "text": "k-means cluster",
                "xref": "paper",
                "yref": "paper",
                "x": 1.02,
                "y": 0.5,
                "textangle": 90,
                "showarrow": False,
            }
        ]
    )

    # When the newest selection came from a text query, zoom onto that point.
    most_recent_idx, most_recent_method = get_most_recent_selection()
    if most_recent_method == "text" and most_recent_idx is not None:
        x0, y0 = df.iloc[most_recent_idx][["x", "y"]]
        span = 6.0
        fig.update_layout(
            xaxis_range=[x0 - span, x0 + span],
            yaxis_range=[y0 - span, y0 + span],
            transition={"duration": 1},
        )

    # on_select="rerun" makes a point selection rerun the script; the returned
    # value carries the selected point indices.
    click_event = st.plotly_chart(
        fig,
        use_container_width=True,
        on_select="rerun",
        selection_mode="points",
    )
184
+
185
# RIGHT: text input & preview
with col2:
    if click_event and click_event.get("selection", {}).get("point_indices"):
        curr_click = click_event["selection"]["point_indices"][0]
        # NOTE(review): a repeated click on the point already stored is treated
        # as "no change", so it does not take priority back from a newer text
        # selection - confirm this is intended.
        if curr_click != st.session_state.click_selection:
            # New click: record it and wipe any text query.  The wipe must
            # happen before the text_input below is created in this run.
            st.session_state.click_selection = SelectedIndex(curr_click)
            st.session_state.text_query = ""

    # Text input (will pick up cleared or existing text).
    text_query = st.text_input(
        "Search via text",
        key="text_query",
        help="Type a query and press Enter",
    )

    # If the user typed text (and pressed Enter), override the selection.
    if text_query:
        with torch.no_grad():
            model_input = preprocessor(text=[text_query])
            emb_out = model.get_text_embeddings(**model_input).text_proj.cpu().numpy()
        # k=1 without self-skipping yields exactly one id.
        idx_text, = get_nearest_ids(emb_out, k=1, ignore_self=False)
        if st.session_state.text_selection != idx_text:
            # New text result: record it and rerun so the scatter plot, which
            # was drawn earlier in this run, can zoom onto it.
            st.session_state.text_selection = SelectedIndex(idx_text)
            st.rerun()

    # Main preview.
    preview_slot = st.empty()
    most_recent, most_recent_modality = get_most_recent_selection()
    if most_recent is not None:
        preview_video(df, most_recent, preview_slot)
    else:
        preview_slot.write("⏳ Waiting for selection…")
219
 
220
# BOTTOM: 5 nearest neighbors
st.markdown("### 🎬 5 Closest Videos")
if most_recent is not None:
    # The selected video's own entry is skipped only for click selections.
    # NOTE(review): for a text selection the first neighbor is presumably the
    # previewed video itself - confirm that this duplication is wanted.
    ignore_self = most_recent_modality == "click"
    nn_ids = get_nearest_ids(embs[most_recent], k=5, ignore_self=ignore_self)
    cols = st.columns(5)
    for c, nid in zip(cols, nn_ids):
        preview_video(df, nid, c, height=180, margin_top=5, autoplay=False)
else:
    st.write("Use a click or a text query above to list neighbors.")