Fangrui Liu commited on
Commit
eed1a5c
·
1 Parent(s): 05893e6

switch to new ch client

Browse files
Files changed (2) hide show
  1. app.py +13 -12
  2. requirements.txt +2 -1
app.py CHANGED
@@ -7,7 +7,8 @@ from transformers import CLIPTokenizerFast, AutoTokenizer, CLIPModel
7
  import torch
8
  import logging
9
  from os import environ
10
- from myscaledb import Client
 
11
  environ['TOKENIZERS_PARALLELISM'] = 'true'
12
 
13
 
@@ -36,10 +37,10 @@ def init_db():
36
  meta_field: Meta field that records if an image is viewed or not
37
  client: Database connection object
38
  """
39
- client = Client(
40
- url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
41
- # We can check if the connection is alive
42
- assert client.is_alive()
43
  meta_field = {}
44
  return meta_field, client
45
 
@@ -72,18 +73,18 @@ def query(xq, top_k=10):
72
  [f'\'{i}\'' for i, v in st.session_state.meta.items() if v >= 1])
73
  print("Excluded:", exclude_list)
74
  # Using PREWHERE allows you to do column filter before vector search
75
- xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
76
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
77
  FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
78
- PREWHERE id NOT IN ({exclude_list})")
79
  else:
80
- xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
81
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
82
- FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
83
- real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
84
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
85
- FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
86
- top_k = real_xc
87
  xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
88
  st.session_state.meta[xi['id']] < 1]
89
  logging.info(
 
7
  import torch
8
  import logging
9
  from os import environ
10
+ from parse import parse
11
+ from clickhouse_connect import get_client
12
  environ['TOKENIZERS_PARALLELISM'] = 'true'
13
 
14
 
 
37
  meta_field: Meta field that records if an image is viewed or not
38
  client: Database connection object
39
  """
40
+ r = parse("{http_pre}://{host}:{port}", st.secrets["DB_URL"])
41
+ client = get_client(
42
+ host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"]
43
+ )
44
  meta_field = {}
45
  return meta_field, client
46
 
 
73
  [f'\'{i}\'' for i, v in st.session_state.meta.items() if v >= 1])
74
  print("Excluded:", exclude_list)
75
  # Using PREWHERE allows you to do column filter before vector search
76
+ xc = st.session_state.index.query(f"SELECT id, url, vector,\
77
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
78
  FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
79
+ PREWHERE id NOT IN ({exclude_list})").named_results()
80
  else:
81
+ xc = st.session_state.index.query(f"SELECT id, url, vector,\
82
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
83
+ FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}").named_results()
84
+ real_xc = st.session_state.index.query(f"SELECT id, url, vector,\
85
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
86
+ FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}").named_results()
87
+ top_k = [{k: v for k, v in r.items()} for r in real_xc]
88
  xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
89
  st.session_state.meta[xi['id']] < 1]
90
  logging.info(
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  transformers
2
  tqdm
3
- myscaledb-client==1.1.7
 
4
  streamlit
5
  multilingual-clip
6
  numpy
 
1
  transformers
2
  tqdm
3
+ parse
4
+ clickhouse-connect
5
  streamlit
6
  multilingual-clip
7
  numpy