Shakshi3104 commited on
Commit
270a1bc
·
1 Parent(s): 9c0c936

[fix] implement vector search using DuckDB

Browse files
Files changed (2) hide show
  1. model/search/vector.py +21 -13
  2. requirements.txt +2 -0
model/search/vector.py CHANGED
@@ -13,7 +13,7 @@ from tqdm import tqdm
13
 
14
  import sentence_transformers as st
15
 
16
- import voyager
17
 
18
  from model.search.base import BaseSearchClient
19
  from model.utils.timer import stop_watch
@@ -80,7 +80,7 @@ class RuriEmbedder:
80
 
81
  class RuriVoyagerSearchClient(BaseSearchClient):
82
  def __init__(self, dataset: pd.DataFrame, target: str,
83
- index: voyager.Index,
84
  model: RuriEmbedder):
85
  load_dotenv()
86
  # オリジナルのコーパス
@@ -90,8 +90,8 @@ class RuriVoyagerSearchClient(BaseSearchClient):
90
  # 埋め込みモデル
91
  self.embedder = model
92
 
93
- # Voyagerインデックス
94
- self.index = index
95
 
96
  @classmethod
97
  @stop_watch
@@ -129,12 +129,12 @@ class RuriVoyagerSearchClient(BaseSearchClient):
129
  num_dim = embeddings.shape[1]
130
  logger.debug(f"🚦⚓️ [RuriVoyagerSearchClient] Number of dimensions of Embedding vector is {num_dim}")
131
 
132
- # Voyagerのインデックスを初期化
133
- index = voyager.Index(voyager.Space.Cosine, num_dimensions=num_dim)
134
- # indexにベクトルを追加
135
- _ = index.add_items(embeddings)
136
 
137
- return cls(_data, _target, index, embedder)
138
 
139
  @stop_watch
140
  def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
@@ -169,11 +169,19 @@ class RuriVoyagerSearchClient(BaseSearchClient):
169
  # ランキングtop-nをクエリ毎に取得
170
  result = []
171
  for embeddings_query in tqdm(embeddings_queries):
172
- # Voyagerのインデックスを探索
173
- neighbors_indices, distances = self.index.query(embeddings_query, k=n)
 
 
 
 
 
 
 
 
174
  # 類似度スコア
175
- df_res = deepcopy(self.dataset.iloc[neighbors_indices])
176
- df_res["score"] = distances
177
  # ランク
178
  df_res["rank"] = deepcopy(df_res.reset_index()).index
179
 
 
13
 
14
  import sentence_transformers as st
15
 
16
+ import duckdb
17
 
18
  from model.search.base import BaseSearchClient
19
  from model.utils.timer import stop_watch
 
80
 
81
  class RuriVoyagerSearchClient(BaseSearchClient):
82
  def __init__(self, dataset: pd.DataFrame, target: str,
83
+ vector_store_name: str,
84
  model: RuriEmbedder):
85
  load_dotenv()
86
  # オリジナルのコーパス
 
90
  # 埋め込みモデル
91
  self.embedder = model
92
 
93
+ # DuckDBのテーブル名
94
+ self.vector_store_name = vector_store_name
95
 
96
  @classmethod
97
  @stop_watch
 
129
  num_dim = embeddings.shape[1]
130
  logger.debug(f"🚦⚓️ [RuriVoyagerSearchClient] Number of dimensions of Embedding vector is {num_dim}")
131
 
132
+ # DuckDBに挿入
133
+ vector_store_name = "ruri_vector_index"
134
+ vdb = pd.DataFrame({"index": range(len(embeddings)), "embedding": embeddings.tolist()})
135
+ duckdb.register(vector_store_name, vdb)
136
 
137
+ return cls(_data, _target, vector_store_name,embedder)
138
 
139
  @stop_watch
140
  def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
 
169
  # ランキングtop-nをクエリ毎に取得
170
  result = []
171
  for embeddings_query in tqdm(embeddings_queries):
172
+ num_dim = len(embeddings_query)
173
+ distance = duckdb.sql(f"""
174
+ select
175
+ index,
176
+ array_cosine_distance(embedding::DOUBLE[{num_dim}], {embeddings_query.tolist()}::DOUBLE[{num_dim}]) as distance
177
+ from {self.vector_store_name}
178
+ order by distance
179
+ limit {n}
180
+ """).df()
181
+
182
  # 類似度スコア
183
+ df_res = deepcopy(self.dataset.iloc[distance["index"].tolist()])
184
+ df_res["score"] = distance["distance"].tolist()
185
  # ランク
186
  df_res["rank"] = deepcopy(df_res.reset_index()).index
187
 
requirements.txt CHANGED
@@ -10,6 +10,7 @@ python-dotenv
10
 
11
  # Visualization
12
  gradio
 
13
 
14
  tqdm>=4.65
15
  matplotlib>=3.7
@@ -30,6 +31,7 @@ emoji>=2.6.0
30
  # search
31
  rank_bm25
32
  voyager
 
33
 
34
  # NLP
35
  mecab-python3
 
10
 
11
  # Visualization
12
  gradio
13
+ streamlit
14
 
15
  tqdm>=4.65
16
  matplotlib>=3.7
 
31
  # search
32
  rank_bm25
33
  voyager
34
+ duckdb
35
 
36
  # NLP
37
  mecab-python3