Fangrui Liu committed · Commit aee10cf
1 Parent(s): 0b449a5
refined layout

app.py CHANGED
@@ -258,6 +258,7 @@ def init_clip_mlang():
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
     return tokenizer, clip
 
+
 @st.experimental_singleton(show_spinner=False)
 def init_clip_vanilla():
     """ Initialize CLIP Model
@@ -297,11 +298,13 @@ def prompt2vec_mlang(prompt: str, tokenizer, clip):
     xq = out.squeeze(0).cpu().detach().numpy().tolist()
     return xq
 
+
 def prompt2vec_vanilla(prompt: str, tokenizer, clip):
     inputs = tokenizer(prompt, return_tensors='pt')
     out = clip.get_text_features(**inputs)
     xq = out.squeeze(0).cpu().detach().numpy().tolist()
-    return xq
+    return xq
+
 
 st.markdown("""
 <link
@@ -345,7 +348,7 @@ text_model_map = {
     'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
                 'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
                 }
-}
+}
 
 
 with st.spinner("Connecting DB..."):
@@ -354,9 +357,11 @@ with st.spinner("Connecting DB..."):
 with st.spinner("Loading Models..."):
     # Initialize CLIP model
     if 'xq' not in st.session_state:
-        text_model_map['Multi Lingual']['Vanilla CLIP'].append(init_clip_mlang())
+        text_model_map['Multi Lingual']['Vanilla CLIP'].append(
+            init_clip_mlang())
         text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
-        text_model_map['English']['CLIP finetuned on RSICD'].append(init_clip_rsicd())
+        text_model_map['English']['CLIP finetuned on RSICD'].append(
+            init_clip_rsicd())
         st.session_state.query_num = 0
 
 if 'xq' not in st.session_state:
@@ -372,30 +377,34 @@ if 'xq' not in st.session_state:
         del st.session_state.prompt
     st.title("Visual Dataset Explorer")
     start = [st.empty(), st.empty(), st.empty(), st.empty(),
-             st.empty(), st.empty(), st.empty()]
+             st.empty(), st.empty(), st.empty(), st.empty()]
     start[0].info(msg)
     start_col = start[1].columns(3)
-    st.session_state.db_name_ref = start_col[0].selectbox(
-
-    st.session_state.
+    st.session_state.db_name_ref = start_col[0].selectbox(
+        "Select Database:", list(db_name_map.keys()))
+    st.session_state.lang = start_col[1].selectbox(
+        "Select Language:", list(text_model_map.keys()))
+    st.session_state.feat_name = start_col[2].selectbox("Select Image Feature:",
                                                         list(text_model_map[st.session_state.lang].keys()))
     if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
-
+        start[2].warning('If you are searching for Remote Sensing Images, \
            try to use prompt "An aerial photograph of <your-real-query>" \
            to obtain best search experience!')
-    prompt = start[2].text_input(
-        "Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
     if len(prompt) > 0:
         st.session_state.prompt = prompt.replace(' ', '_')
-    start[
+    start[4].markdown(
         '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
         <p>🌟 We also support multi-language search. Type any language you know to search! ⌨️ </p>',
         unsafe_allow_html=True)
-    upld_model = start[
+    upld_model = start[6].file_uploader(
        "Or you can upload your previous run!", type='onnx')
-    upld_btn = start[
-        "
-
+    upld_btn = start[7].button(
+        "Use Loaded Weights", disabled=upld_model is None)
+    prompt = start[3].text_input(
+        "Prompt:",
+        value="An aerial photograph of "if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K" else "",
+        placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...",)
+    with start[5]:
         col = st.columns(8)
     has_no_prompt = (len(prompt) == 0 and upld_model is None)
     prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
@@ -418,7 +427,8 @@ if 'xq' not in st.session_state:
         assert len(weights) == 1
         xq = numpy_helper.to_array(weights[0]).tolist()
         assert len(xq) == DIMS
-        st.session_state.prompt = upld_model.name.split(".onnx")[0].replace(' ', '_')
+        st.session_state.prompt = upld_model.name.split(".onnx")[
+            0].replace(' ', '_')
    else:
        print(f"Input prompt is {prompt}")
        # Tokenize the vectors
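The context lines of the last hunk show how an uploaded .onnx file is turned back into a query vector: the vector is stored as the sole initializer of the graph and read back with numpy_helper. Below is a minimal round-trip sketch under assumptions: the file name prompt.onnx and DIMS = 512 are illustrative, and only the read-back lines mirror what app.py does with the uploaded file.

import numpy as np
import onnx
from onnx import helper, numpy_helper

DIMS = 512  # assumed embedding width for this sketch

# Save: put the query vector into an otherwise empty ONNX graph as its only initializer
xq = np.random.rand(DIMS).astype(np.float32)
graph = helper.make_graph(nodes=[], name="query", inputs=[], outputs=[],
                          initializer=[numpy_helper.from_array(xq, name="xq")])
onnx.save(helper.make_model(graph), "prompt.onnx")

# Load: the same steps app.py applies after the upload
weights = onnx.load("prompt.onnx").graph.initializer
assert len(weights) == 1
restored = numpy_helper.to_array(weights[0]).tolist()
assert len(restored) == DIMS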
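The earlier hunks rewrap the pattern the app uses for model loading: each init_clip_* loader is cached with @st.experimental_singleton and appended next to its prompt2vec_* encoder in text_model_map on first run. A minimal sketch of that wiring follows; the MODEL_ID checkpoint, the CLIPModel import, and the body of init_clip_vanilla are assumptions, since the diff only shows the decorator, the function signatures, and the tokenizer line. (Newer Streamlit releases replace st.experimental_singleton with st.cache_resource.)

import streamlit as st
from transformers import AutoTokenizer, CLIPModel

MODEL_ID = "openai/clip-vit-base-patch32"  # illustrative checkpoint, not necessarily the one app.py uses


@st.experimental_singleton(show_spinner=False)
def init_clip_vanilla():
    # Runs once per server process; later reruns reuse the cached pair
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    clip = CLIPModel.from_pretrained(MODEL_ID)
    return tokenizer, clip


def prompt2vec_vanilla(prompt: str, tokenizer, clip):
    # Encode a text prompt into a flat query vector, as in the hunk above
    inputs = tokenizer(prompt, return_tensors='pt')
    out = clip.get_text_features(**inputs)
    return out.squeeze(0).cpu().detach().numpy().tolist()


# text_model_map keeps [encoder_fn, (tokenizer, model)] per feature choice,
# with the heavy loaders appended lazily on the first run of the script
text_model_map = {'English': {'Vanilla CLIP': [prompt2vec_vanilla, ]}}
if 'xq' not in st.session_state:
    text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())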