Spaces:
Sleeping
Sleeping
Shakshi3104
commited on
Commit
·
270a1bc
1
Parent(s):
9c0c936
[fix] implement vector search using DuckDB
Browse files- model/search/vector.py +21 -13
- requirements.txt +2 -0
model/search/vector.py
CHANGED
@@ -13,7 +13,7 @@ from tqdm import tqdm
|
|
13 |
|
14 |
import sentence_transformers as st
|
15 |
|
16 |
-
import
|
17 |
|
18 |
from model.search.base import BaseSearchClient
|
19 |
from model.utils.timer import stop_watch
|
@@ -80,7 +80,7 @@ class RuriEmbedder:
|
|
80 |
|
81 |
class RuriVoyagerSearchClient(BaseSearchClient):
|
82 |
def __init__(self, dataset: pd.DataFrame, target: str,
|
83 |
-
|
84 |
model: RuriEmbedder):
|
85 |
load_dotenv()
|
86 |
# オリジナルのコーパス
|
@@ -90,8 +90,8 @@ class RuriVoyagerSearchClient(BaseSearchClient):
|
|
90 |
# 埋め込みモデル
|
91 |
self.embedder = model
|
92 |
|
93 |
-
#
|
94 |
-
self.
|
95 |
|
96 |
@classmethod
|
97 |
@stop_watch
|
@@ -129,12 +129,12 @@ class RuriVoyagerSearchClient(BaseSearchClient):
|
|
129 |
num_dim = embeddings.shape[1]
|
130 |
logger.debug(f"🚦⚓️ [RuriVoyagerSearchClient] Number of dimensions of Embedding vector is {num_dim}")
|
131 |
|
132 |
-
#
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
|
137 |
-
return cls(_data, _target,
|
138 |
|
139 |
@stop_watch
|
140 |
def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
|
@@ -169,11 +169,19 @@ class RuriVoyagerSearchClient(BaseSearchClient):
|
|
169 |
# ランキングtop-nをクエリ毎に取得
|
170 |
result = []
|
171 |
for embeddings_query in tqdm(embeddings_queries):
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
# 類似度スコア
|
175 |
-
df_res = deepcopy(self.dataset.iloc[
|
176 |
-
df_res["score"] =
|
177 |
# ランク
|
178 |
df_res["rank"] = deepcopy(df_res.reset_index()).index
|
179 |
|
|
|
13 |
|
14 |
import sentence_transformers as st
|
15 |
|
16 |
+
import duckdb
|
17 |
|
18 |
from model.search.base import BaseSearchClient
|
19 |
from model.utils.timer import stop_watch
|
|
|
80 |
|
81 |
class RuriVoyagerSearchClient(BaseSearchClient):
|
82 |
def __init__(self, dataset: pd.DataFrame, target: str,
|
83 |
+
vector_store_name: str,
|
84 |
model: RuriEmbedder):
|
85 |
load_dotenv()
|
86 |
# オリジナルのコーパス
|
|
|
90 |
# 埋め込みモデル
|
91 |
self.embedder = model
|
92 |
|
93 |
+
# DuckDBのテーブル名
|
94 |
+
self.vector_store_name = vector_store_name
|
95 |
|
96 |
@classmethod
|
97 |
@stop_watch
|
|
|
129 |
num_dim = embeddings.shape[1]
|
130 |
logger.debug(f"🚦⚓️ [RuriVoyagerSearchClient] Number of dimensions of Embedding vector is {num_dim}")
|
131 |
|
132 |
+
# DuckDBに挿入
|
133 |
+
vector_store_name = "ruri_vector_index"
|
134 |
+
vdb = pd.DataFrame({"index": range(len(embeddings)), "embedding": embeddings.tolist()})
|
135 |
+
duckdb.register(vector_store_name, vdb)
|
136 |
|
137 |
+
return cls(_data, _target, vector_store_name,embedder)
|
138 |
|
139 |
@stop_watch
|
140 |
def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
|
|
|
169 |
# ランキングtop-nをクエリ毎に取得
|
170 |
result = []
|
171 |
for embeddings_query in tqdm(embeddings_queries):
|
172 |
+
num_dim = len(embeddings_query)
|
173 |
+
distance = duckdb.sql(f"""
|
174 |
+
select
|
175 |
+
index,
|
176 |
+
array_cosine_distance(embedding::DOUBLE[{num_dim}], {embeddings_query.tolist()}::DOUBLE[{num_dim}]) as distance
|
177 |
+
from {self.vector_store_name}
|
178 |
+
order by distance
|
179 |
+
limit {n}
|
180 |
+
""").df()
|
181 |
+
|
182 |
# 類似度スコア
|
183 |
+
df_res = deepcopy(self.dataset.iloc[distance["index"].tolist()])
|
184 |
+
df_res["score"] = distance["distance"].tolist()
|
185 |
# ランク
|
186 |
df_res["rank"] = deepcopy(df_res.reset_index()).index
|
187 |
|
requirements.txt
CHANGED
@@ -10,6 +10,7 @@ python-dotenv
|
|
10 |
|
11 |
# Visualization
|
12 |
gradio
|
|
|
13 |
|
14 |
tqdm>=4.65
|
15 |
matplotlib>=3.7
|
@@ -30,6 +31,7 @@ emoji>=2.6.0
|
|
30 |
# search
|
31 |
rank_bm25
|
32 |
voyager
|
|
|
33 |
|
34 |
# NLP
|
35 |
mecab-python3
|
|
|
10 |
|
11 |
# Visualization
|
12 |
gradio
|
13 |
+
streamlit
|
14 |
|
15 |
tqdm>=4.65
|
16 |
matplotlib>=3.7
|
|
|
31 |
# search
|
32 |
rank_bm25
|
33 |
voyager
|
34 |
+
duckdb
|
35 |
|
36 |
# NLP
|
37 |
mecab-python3
|