Shakshi3104 commited on
Commit
3201f24
·
unverified ·
2 Parent(s): 6fff7f5 e725583

Merge pull request #1 from Shakshi3104/feature

Browse files
.gitignore CHANGED
@@ -5,6 +5,8 @@
5
  # Develop
6
  .venv/
7
  logs/
 
 
8
 
9
  # Default
10
  # Byte-compiled / optimized / DLL files
 
5
  # Develop
6
  .venv/
7
  logs/
8
+ data/
9
+ models/
10
 
11
  # Default
12
  # Byte-compiled / optimized / DLL files
cli_example.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ from model.search.hybrid import HybridSearchClient
4
+ from model.data.notion_db import fetch_sakurap_corpus
5
+
6
+
7
+ if __name__ == "__main__":
8
+ # Load dataset
9
+ sakurap_df = fetch_sakurap_corpus("./data/sakurap_corpus.csv")
10
+ # sakurap_df = pd.read_csv("./data/sakurap_corpus.csv")
11
+
12
+ # hybrid search
13
+ search_client = HybridSearchClient.from_dataframe(sakurap_df, "content")
14
+ results = search_client.search_top_n("嵐 5人の歴史")
model/data/notion_db.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import abc
4
+
5
+ import pandas as pd
6
+ from dotenv import load_dotenv
7
+
8
+ import notion_client as nt
9
+ from notion2md.exporter.block import StringExporter
10
+
11
+ from loguru import logger
12
+
13
+
14
+ class BaseNotionDatabase:
15
+ """
16
+ Notion DBからページのコンテンツを取り出すベースのクラス
17
+ """
18
+ def __init__(self):
19
+ load_dotenv()
20
+ self.notion_database_id = os.getenv("NOTION_DATABASE_ID")
21
+ self.integration_token = os.getenv("INTEGRATION_TOKEN")
22
+
23
+ # notion2mdの環境変数
24
+ os.environ["NOTION_TOKEN"] = os.getenv("INTEGRATION_TOKEN")
25
+
26
+ self.notion_client = nt.Client(auth=self.integration_token)
27
+
28
+ def load_database(self) -> list[dict]:
29
+ """
30
+ Notion DBのページ一覧を取得
31
+
32
+ Returns:
33
+
34
+ """
35
+ results = []
36
+ has_more = True
37
+ start_cursor = None
38
+
39
+ while has_more:
40
+ db = self.notion_client.databases.query(
41
+ **{
42
+ "database_id": self.notion_database_id,
43
+ "start_cursor": start_cursor
44
+ }
45
+ )
46
+ # 100件までしか1回に取得できない
47
+ # 100件以上ある場合 has_more = True
48
+ has_more = db["has_more"]
49
+ # 次のカーソル
50
+ start_cursor = db["next_cursor"]
51
+
52
+ # 取得結果
53
+ results += db["results"]
54
+
55
+ return results
56
+
57
+ @abc.abstractmethod
58
+ def load_content(self) -> list[dict]:
59
+ """
60
+ Notion DBのページの中身をdictで返す
61
+ Returns:
62
+
63
+ """
64
+ raise NotImplementedError
65
+
66
+
67
+ class SakurapDB(BaseNotionDatabase):
68
+ def load_database(self) -> list[dict]:
69
+ """
70
+ Notion DBのページ一覧を取得
71
+
72
+ Returns:
73
+ results:
74
+ list[dict]
75
+
76
+ """
77
+ results = []
78
+ has_more = True
79
+ start_cursor = None
80
+
81
+ while has_more:
82
+ # "Rap詞 : 櫻井翔"がTrueのもののみ取得
83
+ db = self.notion_client.databases.query(
84
+ **{
85
+ "database_id": self.notion_database_id,
86
+ "filter": {
87
+ "property": "Rap詞 : 櫻井翔",
88
+ "checkbox": {
89
+ "equals": True
90
+ }
91
+ },
92
+ "start_cursor": start_cursor
93
+ }
94
+ )
95
+ # 100件までしか1回に取得できない
96
+ # 100件以上ある場合 has_more = True
97
+ has_more = db["has_more"]
98
+ # 次のカーソル
99
+ start_cursor = db["next_cursor"]
100
+
101
+ # 取得結果
102
+ results += db["results"]
103
+
104
+ return results
105
+
106
+ def __load_blocks(self, block_id: str) -> str:
107
+ """
108
+ Notionのページをプレーンテキストで取得する (Notion Official API)
109
+
110
+ Parameters
111
+ ----------
112
+ block_id:
113
+ str, Block ID
114
+
115
+ Returns
116
+ -------
117
+ texts:
118
+ str
119
+ """
120
+ block = self.notion_client.blocks.children.list(
121
+ **{
122
+ "block_id": block_id
123
+ }
124
+ )
125
+
126
+ # プレーンテキストを繋げる
127
+ def join_plain_texts():
128
+ text = [blck["paragraph"]["rich_text"][0]["plain_text"] if len(blck["paragraph"]["rich_text"])
129
+ else "\n" for blck in block["results"]]
130
+
131
+ texts = "\n".join(text)
132
+ return texts
133
+
134
+ return join_plain_texts()
135
+
136
+ def load_content(self) -> list[dict]:
137
+ """
138
+ Notion DBのページの中身をdictで返す
139
+
140
+ Returns:
141
+ lyrics:
142
+ list[dict]
143
+ """
144
+
145
+ # DBのページ一覧を取得
146
+ db_results = self.load_database()
147
+ logger.info("🚦 [Notion] load database...")
148
+
149
+ # コンテンツ一覧
150
+ lyrics = []
151
+
152
+ logger.info("🚦 [Notion] start to load each page content ...")
153
+ # 各ページの処理
154
+ for result in db_results:
155
+ block_id = result["id"]
156
+ # rap_lyric = self.__load_blocks(block_id)
157
+
158
+ # Markdown形式でページを取得
159
+ rap_lyric = StringExporter(block_id=block_id).export()
160
+ # Markdownの修飾子を削除
161
+ rap_lyric = rap_lyric.replace("\n\n", "\n").replace("<br/>", "\n").replace("*", "")
162
+
163
+ lyrics.append(
164
+ {
165
+ "title": result["properties"]["名前"]["title"][0]["plain_text"],
166
+ "content": rap_lyric
167
+ }
168
+ )
169
+
170
+ logger.info("🚦 [Notion] Finish to load.")
171
+
172
+ return lyrics
173
+
174
+
175
+ def fetch_sakurap_corpus(filepath: str, refetch=False) -> pd.DataFrame:
176
+ """
177
+ サクラップのコーパスを取得する
178
+ CSVファイルが存在しないときにNotionから取得する
179
+
180
+ Parameters
181
+ ----------
182
+ filepath:
183
+ str
184
+ refetch:
185
+ bool
186
+
187
+ Returns
188
+ -------
189
+
190
+ """
191
+ filepath = Path(filepath)
192
+
193
+ if not filepath.exists() or refetch:
194
+ # CSVファイルを保存するディレクトリが存在しなかったら作成する
195
+ if not filepath.parent.exists():
196
+ logger.info(f"🚦 [Notion] mkdir {str(filepath.parent)} ...")
197
+ filepath.parent.mkdir(parents=True, exist_ok=True)
198
+
199
+ logger.info("🚦 [Notion] fetch from Notion DB ...")
200
+ # dictを取得
201
+ rap_db = SakurapDB()
202
+ lyrics = rap_db.load_content()
203
+
204
+ lyrics_df = pd.DataFrame(lyrics)
205
+ lyrics_df.to_csv(filepath, index=False)
206
+ else:
207
+ logger.info("🚦 [Notion] load CSV file.")
208
+
209
+ lyrics_df = pd.read_csv(filepath)
210
+
211
+ return lyrics_df
212
+
213
+
214
+ if __name__ == "__main__":
215
+ sakurap_db = SakurapDB()
216
+ lyrics = sakurap_db.load_content()
model/search/base.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List, Union
3
+
4
+ import pandas as pd
5
+
6
+
7
+ class BaseSearchClient:
8
+ """
9
+ 検査インタフェースクラス
10
+ """
11
+ corpus: pd.DataFrame | list | None = None
12
+
13
+ @classmethod
14
+ @abc.abstractmethod
15
+ def from_dataframe(cls, _data: pd.DataFrame, _target: str):
16
+ raise NotImplementedError()
17
+
18
+ @abc.abstractmethod
19
+ def search_top_n(self, _query: Union[List[str], str], n: int=10) -> List[pd.DataFrame]:
20
+ raise NotImplementedError()
model/search/hybrid.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+
3
+ import pandas as pd
4
+ from copy import deepcopy
5
+
6
+ from dotenv import load_dotenv
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+
10
+ from model.search.base import BaseSearchClient
11
+ from model.search.surface import BM25SearchClient
12
+ from model.search.vector import RuriVoyagerSearchClient
13
+
14
+ from model.utils.timer import stop_watch
15
+
16
+
17
+ def reciprocal_rank_fusion(sparse: pd.DataFrame, dense: pd.DataFrame, k=60) -> pd.DataFrame:
18
+ """
19
+ Reciprocal Rank Fusionを計算する
20
+
21
+ Notes
22
+ ----------
23
+ RRFの計算は以下の式
24
+
25
+ .. math:: RRF = \sum_{i=1}^n \frac{1}{k+r_i}
26
+
27
+ Parameters
28
+ ----------
29
+ sparse:
30
+ pd.DataFrame, 表層検索の検索結果
31
+ dense:
32
+ pd.DataFrame, ベクトル検索の結果
33
+ k:
34
+ int,
35
+
36
+ Returns
37
+ -------
38
+ rank_results:
39
+ pd.DataFrame, RRFによるリランク結果
40
+
41
+ """
42
+ # カラム名を変更
43
+ sparse = sparse.rename(columns={"rank": "rank_sparse"})
44
+ dense = dense.rename(columns={"rank": "rank_dense"})
45
+ # denseはランク以外を落として結合する
46
+ dense_ = dense["rank_dense"]
47
+
48
+ # 順位を1からスタートするようにする
49
+ sparse["rank_sparse"] += 1
50
+ dense_ += 1
51
+
52
+ # 文書のインデックスをキーに結合する
53
+ rank_results = pd.merge(sparse, dense_, how="left", left_index=True, right_index=True)
54
+
55
+ # RRFスコアの計算
56
+ rank_results["rrf_score"] = 1 / (rank_results["rank_dense"] + k) + 1 / (rank_results["rank_sparse"] + k)
57
+
58
+ # RRFスコアのスコアが大きい順にソート
59
+ rank_results = rank_results.sort_values(["rrf_score"], ascending=False)
60
+ rank_results["rank"] = deepcopy(rank_results.reset_index()).index
61
+
62
+ return rank_results
63
+
64
+
65
+ class HybridSearchClient(BaseSearchClient):
66
+ def __init__(self, dense_model: BaseSearchClient, sparse_model: BaseSearchClient):
67
+ self.dense_model = dense_model
68
+ self.sparse_model = sparse_model
69
+
70
+ @classmethod
71
+ @stop_watch
72
+ def from_dataframe(cls, _data: pd.DataFrame, _target: str):
73
+ """
74
+ 検索ドキュメントのpd.DataFrameから初期化する
75
+
76
+ Parameters
77
+ ----------
78
+ _data:
79
+ pd.DataFrame, 検索対象のDataFrame
80
+
81
+ _target:
82
+ str, 検索対象のカラム名
83
+
84
+ Returns
85
+ -------
86
+
87
+ """
88
+ # 表層検索の初期化
89
+ dense_model = BM25SearchClient.from_dataframe(_data, _target)
90
+ # ベクトル検索の初期化
91
+ sparse_model = RuriVoyagerSearchClient.from_dataframe(_data, _target)
92
+
93
+ return cls(dense_model, sparse_model)
94
+
95
+ @stop_watch
96
+ def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
97
+ """
98
+ クエリに対する検索結果をtop-n個取得する
99
+
100
+ Parameters
101
+ ----------
102
+ _query:
103
+ Union[List[str], str], 検索クエリ
104
+ n:
105
+ int, top-nの個数. デフォルト 10.
106
+
107
+ Returns
108
+ -------
109
+ results:
110
+ List[pd.DataFrame], ランキング結果
111
+ """
112
+
113
+ logger.info(f"🚦 [HybridSearchClient] Search top {n} | {_query}")
114
+
115
+ # 型チェック
116
+ if isinstance(_query, str):
117
+ _query = [_query]
118
+
119
+ # ランキングtop-nをクエリ毎に取得
120
+ result = []
121
+ for query in tqdm(_query):
122
+ assert len(self.sparse_model.corpus) == len(
123
+ self.dense_model.corpus), "The document counts do not match between sparse and dense!"
124
+
125
+ # ドキュメント数
126
+ doc_num = len(self.sparse_model.corpus)
127
+
128
+ # 表層検索
129
+ logger.info(f"🚦 [HybridSearchClient] run surface search ...")
130
+ sparse_res = self.sparse_model.search_top_n(query, n=doc_num)
131
+ # ベクトル検索
132
+ logger.info(f"🚦 [HybridSearchClient] run vector search ...")
133
+ dense_res = self.dense_model.search_top_n(query, n=doc_num)
134
+
135
+ # RRFスコアの計算
136
+ logger.info(f"🚦 [HybridSearchClient] compute RRF scores ...")
137
+ rrf_res = reciprocal_rank_fusion(sparse_res[0], dense_res[0])
138
+
139
+ # 結果をtop Nに絞る
140
+ top_num = 10
141
+ rrf_res = rrf_res.head(top_num)
142
+ logger.info(f"🚦 [HybridSearchClient] return {top_num} results")
143
+
144
+ result.append(rrf_res)
145
+
146
+ return result
model/search/surface.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from typing import List, Union
3
+
4
+ import pandas as pd
5
+ import numpy as np
6
+
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+
10
+ from rank_bm25 import BM25Okapi
11
+
12
+ from model.search.base import BaseSearchClient
13
+ from model.utils.tokenizer import MeCabTokenizer
14
+ from model.utils.timer import stop_watch
15
+
16
+
17
+ class BM25Wrapper(BM25Okapi):
18
+ def __init__(self, dataset: pd.DataFrame, target, tokenizer=None, k1=1.5, b=0.75, epsilon=0.25):
19
+ self.k1 = k1
20
+ self.b = b
21
+ self.epsilon = epsilon
22
+ self.dataset = dataset
23
+ corpus = dataset[target].values.tolist()
24
+ super().__init__(corpus, tokenizer)
25
+
26
+ def get_top_n(self, query, documents, n=5):
27
+ assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
28
+
29
+ scores = self.get_scores(query)
30
+ top_n = np.argsort(scores)[::-1][:n]
31
+
32
+ result = deepcopy(self.dataset.iloc[top_n])
33
+ result["score"] = scores[top_n]
34
+ return result
35
+
36
+
37
+ class BM25SearchClient(BaseSearchClient):
38
+ def __init__(self, _model: BM25Okapi, _corpus: List[List[str]]):
39
+ """
40
+
41
+ Parameters
42
+ ----------
43
+ _model:
44
+ BM25Okapi
45
+ _corpus:
46
+ List[List[str]], 検索対象の分かち書き後のフィールド
47
+ """
48
+ self.model = _model
49
+ self.corpus = _corpus
50
+
51
+ @staticmethod
52
+ def tokenize_ja(_text: List[str]):
53
+ """MeCab日本語分かち書きによるコーパス作成
54
+
55
+ Args:
56
+ _text (List[str]): コーパス文のリスト
57
+
58
+ Returns:
59
+ List[List[str]]: 分かち書きされたテキストのリスト
60
+ """
61
+
62
+ # MeCabで分かち書き
63
+ parser = MeCabTokenizer.from_tagger("-Owakati")
64
+
65
+ corpus = []
66
+ with tqdm(_text) as pbar:
67
+ for i, t in enumerate(pbar):
68
+ try:
69
+ # 分かち書きをする
70
+ corpus.append(parser.parse(t).split())
71
+ except TypeError as e:
72
+ if not isinstance(t, str):
73
+ logger.info(f"🚦 [BM25SearchClient] Corpus index of {i} is not instance of String.")
74
+ corpus.append(["[UNKNOWN]"])
75
+ else:
76
+ raise e
77
+ return corpus
78
+
79
+ @classmethod
80
+ def from_dataframe(cls, _data: pd.DataFrame, _target: str):
81
+ """
82
+ 検索ドキュメントのpd.DataFrameから初期化する
83
+
84
+ Parameters
85
+ ----------
86
+ _data:
87
+ pd.DataFrame, 検索対象のDataFrame
88
+
89
+ _target:
90
+ str, 検索対象のカラム名
91
+
92
+ Returns
93
+ -------
94
+
95
+ """
96
+
97
+ logger.info("🚦 [BM25SearchClient] Initialize from DataFrame")
98
+
99
+ search_field = _data[_target]
100
+ corpus = search_field.values.tolist()
101
+
102
+ # 分かち書きをする
103
+ corpus_tokenized = cls.tokenize_ja(corpus)
104
+ _data["tokenized"] = corpus_tokenized
105
+
106
+ bm25 = BM25Wrapper(_data, "tokenized")
107
+ return cls(bm25, corpus_tokenized)
108
+
109
+ @stop_watch
110
+ def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
111
+ """
112
+ クエリに対する検索結果をtop-n個取得する
113
+
114
+ Parameters
115
+ ----------
116
+ _query:
117
+ Union[List[str], str], 検索クエリ
118
+ n:
119
+ int, top-nの個数. デフォルト 10.
120
+
121
+ Returns
122
+ -------
123
+ results:
124
+ List[pd.DataFrame], ランキング結果
125
+ """
126
+
127
+ logger.info(f"🚦 [BM25SearchClient] Search top {n} | {_query}")
128
+
129
+ # 型チェック
130
+ if isinstance(_query, str):
131
+ _query = [_query]
132
+
133
+ # クエリを分かち書き
134
+ query_tokens = self.tokenize_ja(_query)
135
+
136
+ # ランキングtop-nをクエリ毎に取得
137
+ result = []
138
+ for query in tqdm(query_tokens):
139
+ df_res = self.model.get_top_n(query, self.corpus, n)
140
+ # ランク
141
+ df_res["rank"] = deepcopy(df_res.reset_index()).index
142
+ df_res = df_res.drop(columns=["tokenized"])
143
+ result.append(df_res)
144
+
145
+ logger.success(f"🚦 [BM25SearchClient] Executed")
146
+
147
+ return result
model/search/vector.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import List, Union, Optional
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from copy import deepcopy
10
+ from dotenv import load_dotenv
11
+ from loguru import logger
12
+ from tqdm import tqdm
13
+
14
+ import sentence_transformers as st
15
+
16
+ import voyager
17
+
18
+ from model.search.base import BaseSearchClient
19
+ from model.utils.timer import stop_watch
20
+
21
+
22
+ def array_to_string(array: np.ndarray) -> str:
23
+ """
24
+ np.ndarrayを文字列に変換する
25
+
26
+ Parameters
27
+ ----------
28
+ array:
29
+ np.ndarray
30
+
31
+ Returns
32
+ -------
33
+ array_string:
34
+ str
35
+ """
36
+ array_string = f"{array.tolist()}"
37
+ return array_string
38
+
39
+
40
+ class RuriEmbedder:
41
+ def __init__(self, model: Optional[st.SentenceTransformer] = None):
42
+
43
+ load_dotenv()
44
+
45
+ # モデルの保存先
46
+ self.model_dir = Path("models/ruri")
47
+ model_filepath = self.model_dir / "ruri-large"
48
+
49
+ # モデル
50
+ if model is None:
51
+ if model_filepath.exists():
52
+ logger.info(f"🚦 [RuriEmbedder] load ruri-large from local path: {model_filepath}")
53
+ self.model = st.SentenceTransformer(str(model_filepath))
54
+ else:
55
+ logger.info(f"🚦 [RuriEmbedder] load ruri-large from HuggingFace🤗")
56
+ token = os.getenv("HF_TOKEN")
57
+ self.model = st.SentenceTransformer("cl-nagoya/ruri-large", token=token)
58
+ # モデルを保存する
59
+ logger.info(f"🚦 [RuriEmbedder] save model ...")
60
+ self.model.save(str(model_filepath))
61
+ else:
62
+ self.model = model
63
+
64
+ def embed(self, text: Union[str, list[str]]) -> np.ndarray:
65
+ """
66
+
67
+ Parameters
68
+ ----------
69
+ text:
70
+ Union[str, list[str]], ベクトル化する文字列
71
+
72
+ Returns
73
+ -------
74
+ embedding:
75
+ np.ndarray, 埋め込み表現. トークンサイズ 1024
76
+ """
77
+ embedding = self.model.encode(text, convert_to_numpy=True)
78
+ return embedding
79
+
80
+
81
+ class RuriVoyagerSearchClient(BaseSearchClient):
82
+ def __init__(self, dataset: pd.DataFrame, target: str,
83
+ index: voyager.Index,
84
+ model: RuriEmbedder):
85
+ load_dotenv()
86
+ # オリジナルのコーパス
87
+ self.dataset = dataset
88
+ self.corpus = dataset[target].values.tolist()
89
+
90
+ # 埋め込みモデル
91
+ self.embedder = model
92
+
93
+ # Voyagerインデックス
94
+ self.index = index
95
+
96
+ @classmethod
97
+ @stop_watch
98
+ def from_dataframe(cls, _data: pd.DataFrame, _target: str):
99
+ """
100
+ 検索ドキュメントのpd.DataFrameから初期化する
101
+
102
+ Parameters
103
+ ----------
104
+ _data:
105
+ pd.DataFrame, 検索対象のDataFrame
106
+
107
+ _target:
108
+ str, 検索対象のカラム名
109
+
110
+ Returns
111
+ -------
112
+
113
+ """
114
+ logger.info("🚦 [RuriVoyagerSearchClient] Initialize from DataFrame")
115
+
116
+ search_field = _data[_target]
117
+ corpus = search_field.values.tolist()
118
+
119
+ # 埋め込みモデルの初期化
120
+ embedder = RuriEmbedder()
121
+
122
+ # Ruriの前処理
123
+ corpus = [f"文章: {c}" for c in corpus]
124
+
125
+ # ベクトル化する
126
+ embeddings = embedder.embed(corpus)
127
+
128
+ # 埋め込みベクトルの次元
129
+ num_dim = embeddings.shape[1]
130
+ logger.debug(f"🚦⚓️ [RuriVoyagerSearchClient] Number of dimensions of Embedding vector is {num_dim}")
131
+
132
+ # Voyagerのインデックスを初期化
133
+ index = voyager.Index(voyager.Space.Cosine, num_dimensions=num_dim)
134
+ # indexにベクトルを追加
135
+ _ = index.add_items(embeddings)
136
+
137
+ return cls(_data, _target, index, embedder)
138
+
139
+ @stop_watch
140
+ def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
141
+ """
142
+ クエリに対する検索結果をtop-n個取得する
143
+
144
+ Parameters
145
+ ----------
146
+ _query:
147
+ Union[List[str], str], 検索クエリ
148
+ n:
149
+ int, top-nの個数. デフォルト 10.
150
+
151
+ Returns
152
+ -------
153
+ results:
154
+ List[pd.DataFrame], ランキング結果
155
+ """
156
+
157
+ logger.info(f"🚦 [RuriVoyagerSearchClient] Search top {n} | {_query}")
158
+
159
+ # 型チェック
160
+ if isinstance(_query, str):
161
+ _query = [_query]
162
+
163
+ # Ruriの前処理
164
+ _query = [f"クエリ: {q}" for q in _query]
165
+
166
+ # ベクトル化
167
+ embeddings_queries = self.embedder.embed(_query)
168
+
169
+ # ランキングtop-nをクエリ毎に取得
170
+ result = []
171
+ for embeddings_query in tqdm(embeddings_queries):
172
+ # Voyagerのインデックスを探索
173
+ neighbors_indices, distances = self.index.query(embeddings_query, k=n)
174
+ # 類似度スコア
175
+ df_res = deepcopy(self.dataset.iloc[neighbors_indices])
176
+ df_res["score"] = distances
177
+ # ランク
178
+ df_res["rank"] = deepcopy(df_res.reset_index()).index
179
+
180
+ result.append(df_res)
181
+
182
+ return result
model/utils/timer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ import time
3
+
4
+ from loguru import logger
5
+
6
+
7
+ # https://qiita.com/hisatoshi/items/7354c76a4412dffc4fd7
8
+ def stop_watch(func):
9
+ """
10
+ 処理にかかる時間計測をするデコレータ
11
+ """
12
+ @wraps(func)
13
+ def wrapper(*args, **kargs):
14
+ logger.debug(f"🚦 [@stop_watch] measure time to run `{func.__name__}`.")
15
+ start = time.time()
16
+ result = func(*args, **kargs)
17
+ elapsed_time = time.time() - start
18
+ logger.debug(f"🚦 [@stop_watch] take {elapsed_time:.3f} sec to run `{func.__name__}`.")
19
+ return result
20
+ return wrapper
model/utils/tokenizer.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Optional
3
+
4
+ import MeCab
5
+ # from janome.tokenizer import Tokenizer
6
+
7
+
8
+ class BaseTokenizer:
9
+ @abc.abstractmethod
10
+ def parse(self, _text: str) -> str:
11
+ """
12
+ 分かち書きした結果を返す
13
+
14
+ Parameters
15
+ ----------
16
+ _text:
17
+ str, 入力文章
18
+
19
+ Returns
20
+ -------
21
+ parsed:
22
+ str, 分かち書き後の文章, スペース区切り
23
+ """
24
+ raise NotImplementedError
25
+
26
+
27
+ class MeCabTokenizer(BaseTokenizer):
28
+ def __init__(self, _parser: MeCab.Tagger) -> None:
29
+ self.parser = _parser
30
+
31
+ @classmethod
32
+ def from_tagger(cls, _tagger: Optional[str]):
33
+ parser = MeCab.Tagger(_tagger)
34
+ return cls(parser)
35
+
36
+ def parse(self, _text: str):
37
+ return self.parser.parse(_text)
38
+
39
+
40
+ # class JanomeTokenizer(BaseTokenizer):
41
+ # def __init__(self, _tokenizer: Tokenizer):
42
+ # self.tokenizer = _tokenizer
43
+ #
44
+ # @classmethod
45
+ # def from_user_simple_dictionary(cls, _dict_filepath: Optional[str] = None):
46
+ # """
47
+ # 簡易辞書フォーマットによるユーザー辞書によるイニシャライザー
48
+ #
49
+ # https://mocobeta.github.io/janome/#v0-2-7
50
+ #
51
+ # Parameters
52
+ # ----------
53
+ # _dict_filepath:
54
+ # str, 簡易辞書フォーマットで書かれたユーザー辞書 (CSVファイル)のファイルパス
55
+ # """
56
+ #
57
+ # if _dict_filepath is None:
58
+ # return cls(Tokenizer())
59
+ # else:
60
+ # return cls(Tokenizer(udic=_dict_filepath, udic_type='simpledic'))
61
+ #
62
+ # def parse(self, _text: str) -> str:
63
+ # return " ".join(list(self.tokenizer.tokenize(_text, wakati=True)))
requirements.txt CHANGED
@@ -9,9 +9,7 @@ tqdm
9
  python-dotenv
10
 
11
  # Visualization
12
- streamlit>=1.24
13
- st-pages
14
- streamlit-webrtc
15
 
16
  tqdm>=4.65
17
  matplotlib>=3.7
@@ -25,8 +23,6 @@ pandas>=2.0
25
  opencv-python>=4.8
26
  pillow>=9.5
27
 
28
- # LLM
29
-
30
  # Others
31
  python-magic==0.4.27
32
  emoji>=2.6.0
@@ -39,6 +35,7 @@ voyager
39
  mecab-python3
40
  unidic-lite
41
  fugashi
 
42
  sentence-transformers>=3.0
43
 
44
  # Notion
 
9
  python-dotenv
10
 
11
  # Visualization
12
+ gradio
 
 
13
 
14
  tqdm>=4.65
15
  matplotlib>=3.7
 
23
  opencv-python>=4.8
24
  pillow>=9.5
25
 
 
 
26
  # Others
27
  python-magic==0.4.27
28
  emoji>=2.6.0
 
35
  mecab-python3
36
  unidic-lite
37
  fugashi
38
+ sentencepiece
39
  sentence-transformers>=3.0
40
 
41
  # Notion