Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -66,17 +66,6 @@ def retrieval_results(results):
|
|
66 |
|
67 |
|
68 |
|
69 |
-
def demo_classification():
|
70 |
-
with st.form("clsform"):
|
71 |
-
#load_data = misc_utils.input_3d_shape('cls')
|
72 |
-
cats = st.text_input("Custom Categories (64 max, separated with comma)")
|
73 |
-
cats = [a.strip() for a in cats.split(',')]
|
74 |
-
if len(cats) > 64:
|
75 |
-
st.error('Maximum 64 custom categories supported in the demo')
|
76 |
-
return
|
77 |
-
lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
|
78 |
-
custom_run = st.form_submit_button("Run Classification on Custom Categories")
|
79 |
-
|
80 |
def demo_captioning():
|
81 |
with st.form("capform"):
|
82 |
cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')
|
@@ -146,6 +135,43 @@ def demo_retrieval():
|
|
146 |
prog.progress(1.0, "Idle")
|
147 |
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
def retrieval_pc(load_data, k, sim_th, filter_fn):
|
150 |
pc = load_data(prog)
|
151 |
prog.progress(0.49, "Computing Embeddings")
|
@@ -208,21 +234,21 @@ try:
|
|
208 |
'Choose the source of categories',
|
209 |
("LVIS Categories", "Custom Categories")
|
210 |
)
|
211 |
-
|
212 |
if cls_mode == "LVIS Categories":
|
|
|
|
|
213 |
if st.sidebar.button("submit"):
|
214 |
-
|
215 |
-
prog = st.progress(0.0, "Idle")
|
216 |
-
|
217 |
elif cls_mode == "Custom Categories":
|
|
|
|
|
218 |
cats = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
|
219 |
cats = [a.strip() for a in cats.split(',')]
|
220 |
if len(cats) > 64:
|
221 |
st.error('Maximum 64 custom categories supported in the demo')
|
222 |
if st.sidebar.button("submit"):
|
223 |
-
|
224 |
-
prog = st.progress(0.0, "Idle")
|
225 |
-
|
226 |
elif task == "Cross-modal retrieval":
|
227 |
input_mode = st.sidebar.selectbox(
|
228 |
'Choose an input modality',
|
@@ -231,22 +257,22 @@ try:
|
|
231 |
k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
|
232 |
sim_th, filter_fn = retrieval_filter_expand()
|
233 |
if input_mode == "Point Cloud":
|
|
|
|
|
234 |
load_data = utils.input_3d_shape('rpcinput')
|
235 |
if st.sidebar.button("submit"):
|
236 |
-
st.title("Retrieval with Point Cloud")
|
237 |
-
prog = st.progress(0.0, "Idle")
|
238 |
retrieval_pc(load_data, k, sim_th, filter_fn)
|
239 |
elif input_mode == "Image":
|
|
|
|
|
240 |
pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
|
241 |
if st.sidebar.button("submit"):
|
242 |
-
st.title("Retrieval with Image")
|
243 |
-
prog = st.progress(0.0, "Idle")
|
244 |
retrieval_img(pic, k, sim_th, filter_fn)
|
245 |
elif input_mode == "Text":
|
|
|
|
|
246 |
text = st.sidebar.text_input("Input Text", key='rtextinput')
|
247 |
if st.sidebar.button("submit"):
|
248 |
-
st.title("Retrieval with Text")
|
249 |
-
prog = st.progress(0.0, "Idle")
|
250 |
retrieval_text(text, k, sim_th, filter_fn)
|
251 |
elif task == "Cross-modal generation":
|
252 |
generation_mode = st.sidebar.selectbox(
|
@@ -255,14 +281,15 @@ try:
|
|
255 |
)
|
256 |
pc = st.sidebar.text_input("Input pc", key='rtextinput')
|
257 |
if generation_mode == "PointCloud-to-Image":
|
|
|
|
|
258 |
if st.sidebar.button("submit"):
|
259 |
-
st.
|
260 |
-
prog = st.progress(0.0, "Idle")
|
261 |
-
|
262 |
elif generation_mode == "PointCloud-to-Text":
|
|
|
|
|
263 |
if st.sidebar.button("submit"):
|
264 |
-
st.
|
265 |
-
prog = st.progress(0.0, "Idle")
|
266 |
|
267 |
except Exception:
|
268 |
import traceback
|
|
|
66 |
|
67 |
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def demo_captioning():
|
70 |
with st.form("capform"):
|
71 |
cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')
|
|
|
135 |
prog.progress(1.0, "Idle")
|
136 |
|
137 |
|
138 |
+
def classification_lvis(load_data):
    """Classify a 3D shape against the LVIS categories and show the top 5.

    Reads the progress bar ``prog`` from the enclosing (script) scope rather
    than taking it as a parameter, so ``prog`` must be assigned before this
    function is called.

    Args:
        load_data: Callable invoked as ``load_data(prog)``; its return value
            is used as the point cloud. The indexing below assumes a 2-D
            array with at least six columns -- TODO confirm against
            ``utils.input_3d_shape``.
    """
    pc = load_data(prog)
    col2 = utils.render_pc(pc)
    prog.progress(0.5, "Running Classification")
    ref_dev = next(model_g14.parameters()).device

    # Pure inference: disable autograd so no graph is recorded for the encoder.
    with torch.no_grad():
        # Columns 1 and 2 are swapped (presumably a y/z axis convention --
        # confirm), then the array is transposed and given a leading batch axis.
        points = torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)
        enc = model_g14(points).cpu()
        # Cosine similarity between every LVIS feature row and the shape embedding.
        sim = torch.matmul(
            torch.nn.functional.normalize(lvis.feats, dim=-1),
            torch.nn.functional.normalize(enc, dim=-1).squeeze(),
        )

    # Plain ints make the list lookups and the bounds check ordinary Python.
    order = torch.argsort(sim, descending=True).tolist()
    n_cats = len(lvis.categories)
    # Feature rows without a matching category name are skipped.
    pred = OrderedDict((lvis.categories[i], sim[i]) for i in order if i < n_cats)

    with col2:
        # zip with range(5) caps the output at the five best matches.
        for _, (cat, score) in zip(range(5), pred.items()):
            st.text(cat)
            st.caption("Similarity %.4f" % score)
    prog.progress(1.0, "Idle")
|
153 |
+
|
154 |
+
def classification_custom(load_data, cats):
    """Classify a 3D shape against user-supplied categories and show the top 5.

    Reads the progress bar ``prog`` from the enclosing (script) scope rather
    than taking it as a parameter, so ``prog`` must be assigned before this
    function is called.

    Args:
        load_data: Callable invoked as ``load_data(prog)``; its return value
            is used as the point cloud. The indexing below assumes a 2-D
            array with at least six columns -- TODO confirm against
            ``utils.input_3d_shape``.
        cats: Iterable of category name strings. Blank entries are ignored;
            if nothing remains an error is shown and nothing is computed.
    """
    # An empty text box arrives as [''] after the caller's split(','); drop
    # blanks so they are not embedded and ranked as a category.
    cats = [name for name in (c.strip() for c in cats) if name]
    if not cats:
        st.error('Please enter at least one custom category')
        return

    pc = load_data(prog)
    col2 = utils.render_pc(pc)
    prog.progress(0.5, "Computing Category Embeddings")
    device = clip_model.device

    # Pure inference: disable autograd for both the text and shape encoders.
    with torch.no_grad():
        tn = clip_prep(
            text=cats, return_tensors='pt', truncation=True, max_length=76, padding=True
        ).to(device)
        feats = clip_model.get_text_features(**tn).float().cpu()

        # Advance the bar for the second stage (it previously stayed at 0.5).
        prog.progress(0.75, "Running Classification")
        ref_dev = next(model_g14.parameters()).device
        # Columns 1 and 2 are swapped (presumably a y/z axis convention --
        # confirm), then the array is transposed and given a leading batch axis.
        points = torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)
        enc = model_g14(points).cpu()
        # Cosine similarity between every category embedding and the shape embedding.
        sim = torch.matmul(
            torch.nn.functional.normalize(feats, dim=-1),
            torch.nn.functional.normalize(enc, dim=-1).squeeze(),
        )

    # Plain ints make the list lookups and the bounds check ordinary Python.
    order = torch.argsort(sim, descending=True).tolist()
    n_cats = len(cats)
    # Keyed by name, so a category typed twice is listed once.
    pred = OrderedDict((cats[i], sim[i]) for i in order if i < n_cats)

    with col2:
        # zip with range(5) caps the output at the five best matches.
        for _, (cat, score) in zip(range(5), pred.items()):
            st.text(cat)
            st.caption("Similarity %.4f" % score)
    prog.progress(1.0, "Idle")
|
173 |
+
|
174 |
+
|
175 |
def retrieval_pc(load_data, k, sim_th, filter_fn):
|
176 |
pc = load_data(prog)
|
177 |
prog.progress(0.49, "Computing Embeddings")
|
|
|
234 |
'Choose the source of categories',
|
235 |
("LVIS Categories", "Custom Categories")
|
236 |
)
|
237 |
+
load_data = utils.input_3d_shape('rpcinput')
|
238 |
if cls_mode == "LVIS Categories":
|
239 |
+
st.title("Classification with LVIS Categories")
|
240 |
+
prog = st.progress(0.0, "Idle")
|
241 |
if st.sidebar.button("submit"):
|
242 |
+
classification_lvis(load_data)
|
|
|
|
|
243 |
elif cls_mode == "Custom Categories":
|
244 |
+
st.title("Classification with Custom Categories")
|
245 |
+
prog = st.progress(0.0, "Idle")
|
246 |
cats = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
|
247 |
cats = [a.strip() for a in cats.split(',')]
|
248 |
if len(cats) > 64:
|
249 |
st.error('Maximum 64 custom categories supported in the demo')
|
250 |
if st.sidebar.button("submit"):
|
251 |
+
classification_custom(load_data, cats)
|
|
|
|
|
252 |
elif task == "Cross-modal retrieval":
|
253 |
input_mode = st.sidebar.selectbox(
|
254 |
'Choose an input modality',
|
|
|
257 |
k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
|
258 |
sim_th, filter_fn = retrieval_filter_expand()
|
259 |
if input_mode == "Point Cloud":
|
260 |
+
st.title("Retrieval with Point Cloud")
|
261 |
+
prog = st.progress(0.0, "Idle")
|
262 |
load_data = utils.input_3d_shape('rpcinput')
|
263 |
if st.sidebar.button("submit"):
|
|
|
|
|
264 |
retrieval_pc(load_data, k, sim_th, filter_fn)
|
265 |
elif input_mode == "Image":
|
266 |
+
st.title("Retrieval with Image")
|
267 |
+
prog = st.progress(0.0, "Idle")
|
268 |
pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
|
269 |
if st.sidebar.button("submit"):
|
|
|
|
|
270 |
retrieval_img(pic, k, sim_th, filter_fn)
|
271 |
elif input_mode == "Text":
|
272 |
+
st.title("Retrieval with Text")
|
273 |
+
prog = st.progress(0.0, "Idle")
|
274 |
text = st.sidebar.text_input("Input Text", key='rtextinput')
|
275 |
if st.sidebar.button("submit"):
|
|
|
|
|
276 |
retrieval_text(text, k, sim_th, filter_fn)
|
277 |
elif task == "Cross-modal generation":
|
278 |
generation_mode = st.sidebar.selectbox(
|
|
|
281 |
)
|
282 |
pc = st.sidebar.text_input("Input pc", key='rtextinput')
|
283 |
if generation_mode == "PointCloud-to-Image":
|
284 |
+
st.title("Image Generation")
|
285 |
+
prog = st.progress(0.0, "Idle")
|
286 |
if st.sidebar.button("submit"):
|
287 |
+
pc = st.text_input("Input pc", key='rtextinput')
|
|
|
|
|
288 |
elif generation_mode == "PointCloud-to-Text":
|
289 |
+
st.title("Text Generation")
|
290 |
+
prog = st.progress(0.0, "Idle")
|
291 |
if st.sidebar.button("submit"):
|
292 |
+
pc = st.text_input("Input pc", key='rtextinput')
|
|
|
293 |
|
294 |
except Exception:
|
295 |
import traceback
|