Fangrui Liu committed on
Commit
aee10cf
·
1 Parent(s): 0b449a5

refined layout

Browse files
Files changed (1) hide show
  1. app.py +27 -17
app.py CHANGED
@@ -258,6 +258,7 @@ def init_clip_mlang():
258
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
259
  return tokenizer, clip
260
 
 
261
  @st.experimental_singleton(show_spinner=False)
262
  def init_clip_vanilla():
263
  """ Initialize CLIP Model
@@ -297,11 +298,13 @@ def prompt2vec_mlang(prompt: str, tokenizer, clip):
297
  xq = out.squeeze(0).cpu().detach().numpy().tolist()
298
  return xq
299
 
 
300
  def prompt2vec_vanilla(prompt: str, tokenizer, clip):
301
  inputs = tokenizer(prompt, return_tensors='pt')
302
  out = clip.get_text_features(**inputs)
303
  xq = out.squeeze(0).cpu().detach().numpy().tolist()
304
- return xq
 
305
 
306
  st.markdown("""
307
  <link
@@ -345,7 +348,7 @@ text_model_map = {
345
  'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
346
  'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
347
  }
348
- }
349
 
350
 
351
  with st.spinner("Connecting DB..."):
@@ -354,9 +357,11 @@ with st.spinner("Connecting DB..."):
354
  with st.spinner("Loading Models..."):
355
  # Initialize CLIP model
356
  if 'xq' not in st.session_state:
357
- text_model_map['Multi Lingual']['Vanilla CLIP'].append(init_clip_mlang())
 
358
  text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
359
- text_model_map['English']['CLIP finetuned on RSICD'].append(init_clip_rsicd())
 
360
  st.session_state.query_num = 0
361
 
362
  if 'xq' not in st.session_state:
@@ -372,30 +377,34 @@ if 'xq' not in st.session_state:
372
  del st.session_state.prompt
373
  st.title("Visual Dataset Explorer")
374
  start = [st.empty(), st.empty(), st.empty(), st.empty(),
375
- st.empty(), st.empty(), st.empty()]
376
  start[0].info(msg)
377
  start_col = start[1].columns(3)
378
- st.session_state.db_name_ref = start_col[0].selectbox("Select Database:", list(db_name_map.keys()))
379
- st.session_state.lang = start_col[1].selectbox("Select Language:", list(text_model_map.keys()))
380
- st.session_state.feat_name = start_col[2].selectbox("Select Image Feature:",
 
 
381
  list(text_model_map[st.session_state.lang].keys()))
382
  if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
383
- st.warning('If you are searching for Remote Sensing Images, \
384
  try to use prompt "An aerial photograph of <your-real-query>" \
385
  to obtain best search experience!')
386
- prompt = start[2].text_input(
387
- "Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
388
  if len(prompt) > 0:
389
  st.session_state.prompt = prompt.replace(' ', '_')
390
- start[3].markdown(
391
  '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
392
  <p>🌟 We also support multi-language search. Type any language you know to search! ⌨️ </p>',
393
  unsafe_allow_html=True)
394
- upld_model = start[5].file_uploader(
395
  "Or you can upload your previous run!", type='onnx')
396
- upld_btn = start[6].button(
397
- "Used Loaded Weights", disabled=upld_model is None)
398
- with start[4]:
 
 
 
 
399
  col = st.columns(8)
400
  has_no_prompt = (len(prompt) == 0 and upld_model is None)
401
  prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
@@ -418,7 +427,8 @@ if 'xq' not in st.session_state:
418
  assert len(weights) == 1
419
  xq = numpy_helper.to_array(weights[0]).tolist()
420
  assert len(xq) == DIMS
421
- st.session_state.prompt = upld_model.name.split(".onnx")[0].replace(' ', '_')
 
422
  else:
423
  print(f"Input prompt is {prompt}")
424
  # Tokenize the vectors
 
258
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
259
  return tokenizer, clip
260
 
261
+
262
  @st.experimental_singleton(show_spinner=False)
263
  def init_clip_vanilla():
264
  """ Initialize CLIP Model
 
298
  xq = out.squeeze(0).cpu().detach().numpy().tolist()
299
  return xq
300
 
301
+
302
  def prompt2vec_vanilla(prompt: str, tokenizer, clip):
303
  inputs = tokenizer(prompt, return_tensors='pt')
304
  out = clip.get_text_features(**inputs)
305
  xq = out.squeeze(0).cpu().detach().numpy().tolist()
306
+ return xq
307
+
308
 
309
  st.markdown("""
310
  <link
 
348
  'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
349
  'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
350
  }
351
+ }
352
 
353
 
354
  with st.spinner("Connecting DB..."):
 
357
  with st.spinner("Loading Models..."):
358
  # Initialize CLIP model
359
  if 'xq' not in st.session_state:
360
+ text_model_map['Multi Lingual']['Vanilla CLIP'].append(
361
+ init_clip_mlang())
362
  text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
363
+ text_model_map['English']['CLIP finetuned on RSICD'].append(
364
+ init_clip_rsicd())
365
  st.session_state.query_num = 0
366
 
367
  if 'xq' not in st.session_state:
 
377
  del st.session_state.prompt
378
  st.title("Visual Dataset Explorer")
379
  start = [st.empty(), st.empty(), st.empty(), st.empty(),
380
+ st.empty(), st.empty(), st.empty(), st.empty()]
381
  start[0].info(msg)
382
  start_col = start[1].columns(3)
383
+ st.session_state.db_name_ref = start_col[0].selectbox(
384
+ "Select Database:", list(db_name_map.keys()))
385
+ st.session_state.lang = start_col[1].selectbox(
386
+ "Select Language:", list(text_model_map.keys()))
387
+ st.session_state.feat_name = start_col[2].selectbox("Select Image Feature:",
388
  list(text_model_map[st.session_state.lang].keys()))
389
  if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
390
+ start[2].warning('If you are searching for Remote Sensing Images, \
391
  try to use prompt "An aerial photograph of <your-real-query>" \
392
  to obtain best search experience!')
 
 
393
  if len(prompt) > 0:
394
  st.session_state.prompt = prompt.replace(' ', '_')
395
+ start[4].markdown(
396
  '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
397
  <p>🌟 We also support multi-language search. Type any language you know to search! ⌨️ </p>',
398
  unsafe_allow_html=True)
399
+ upld_model = start[6].file_uploader(
400
  "Or you can upload your previous run!", type='onnx')
401
+ upld_btn = start[7].button(
402
+ "Use Loaded Weights", disabled=upld_model is None)
403
+ prompt = start[3].text_input(
404
+ "Prompt:",
405
+ value="An aerial photograph of "if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K" else "",
406
+ placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...",)
407
+ with start[5]:
408
  col = st.columns(8)
409
  has_no_prompt = (len(prompt) == 0 and upld_model is None)
410
  prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
 
427
  assert len(weights) == 1
428
  xq = numpy_helper.to_array(weights[0]).tolist()
429
  assert len(xq) == DIMS
430
+ st.session_state.prompt = upld_model.name.split(".onnx")[
431
+ 0].replace(' ', '_')
432
  else:
433
  print(f"Input prompt is {prompt}")
434
  # Tokenize the vectors