nakamura196 commited on
Commit
ad0a74a
·
1 Parent(s): 49b45ce

feat: add paramters

Browse files
Files changed (2) hide show
  1. app.py +18 -4
  2. predict.ipynb +156 -0
app.py CHANGED
@@ -11,7 +11,11 @@ def load_index():
11
  allow_dangerous_deserialization=True
12
  )
13
 
14
- def search_documents(input_text, courseOfStudy, index):
 
 
 
 
15
  if not input_text:
16
  return []
17
 
@@ -20,7 +24,13 @@ def search_documents(input_text, courseOfStudy, index):
20
  metadata["学校種別"] = courseOfStudy
21
 
22
  try:
23
- docs_and_scores = index.similarity_search_with_score(input_text, filter=metadata)
 
 
 
 
 
 
24
  except Exception as e:
25
  print(f"Error during search: {e}")
26
  return []
@@ -61,6 +71,10 @@ def setup_gradio_interface():
61
  datatype=["html", "str", "str", "number", "str"],
62
  label="検索結果"
63
  )
 
 
 
 
64
 
65
  index = load_index()
66
 
@@ -72,8 +86,8 @@ def setup_gradio_interface():
72
  ]
73
 
74
  interface = gr.Interface(
75
- fn=lambda text, courseOfStudy: search_documents(text, courseOfStudy, index),
76
- inputs=[text_input, metadata_selector],
77
  outputs=[output_table, gr.JSON(label="JSON")],
78
  title="Japanese Course of Study Predictor",
79
  description="入力したテキストに類似するテキストを持つ学習指導要領コードを検索します。",
 
11
  allow_dangerous_deserialization=True
12
  )
13
 
14
+ default_k = 4
15
+ default_fetch_k = 500
16
+ default_threshold = 0.5
17
+
18
+ def search_documents(input_text, courseOfStudy, index, k = default_k, fetch_k = default_fetch_k, score_threshold = default_threshold):
19
  if not input_text:
20
  return []
21
 
 
24
  metadata["学校種別"] = courseOfStudy
25
 
26
  try:
27
+ docs_and_scores = index.similarity_search_with_score(
28
+ input_text,
29
+ filter=metadata,
30
+ k=k,
31
+ fetch_k=fetch_k,
32
+ score_threshold=score_threshold
33
+ )
34
  except Exception as e:
35
  print(f"Error during search: {e}")
36
  return []
 
71
  datatype=["html", "str", "str", "number", "str"],
72
  label="検索結果"
73
  )
74
+
75
+ k = gr.Number(value=default_k, label="検索結果の数", minimum=1, maximum=200)
76
+ fetch_k = gr.Number(value=default_fetch_k, label="検索対象の数", minimum=1, maximum=2000)
77
+ score_threshold = gr.Slider(value=default_threshold, label="スコアの閾値", minimum=0.0, maximum=1.0, step=0.01)
78
 
79
  index = load_index()
80
 
 
86
  ]
87
 
88
  interface = gr.Interface(
89
+ fn=lambda text, courseOfStudy, k, fetch_k, score_threshold: search_documents(text, courseOfStudy, index, k=k, fetch_k=fetch_k, score_threshold=score_threshold),
90
+ inputs=[text_input, metadata_selector, k, fetch_k, score_threshold],
91
  outputs=[output_table, gr.JSON(label="JSON")],
92
  title="Japanese Course of Study Predictor",
93
  description="入力したテキストに類似するテキストを持つ学習指導要領コードを検索します。",
predict.ipynb ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 8,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from langchain_community.vectorstores.faiss import FAISS\n",
10
+ "from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n",
11
+ "from pprint import pprint"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 11,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "def load_index():\n",
21
+ " index_path = \"./storage\"\n",
22
+ " embedding_model = HuggingFaceEmbeddings(model_name=\"intfloat/multilingual-e5-large\")\n",
23
+ " return FAISS.load_local(\n",
24
+ " folder_path=index_path, \n",
25
+ " embeddings=embedding_model,\n",
26
+ " allow_dangerous_deserialization=True\n",
27
+ " )\n",
28
+ " \n",
29
+ "index = load_index()"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 41,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "def search_documents(input_text, courseOfStudy, index, k = 4, fetch_k = 20, score_threshold=0.5):\n",
39
+ " if not input_text:\n",
40
+ " return []\n",
41
+ " \n",
42
+ " metadata = {}\n",
43
+ " if courseOfStudy:\n",
44
+ " metadata[\"学校種別\"] = courseOfStudy\n",
45
+ " \n",
46
+ " try:\n",
47
+ " docs_and_scores = index.similarity_search_with_score(\n",
48
+ " input_text, \n",
49
+ " filter=metadata,\n",
50
+ " fetch_k=fetch_k,\n",
51
+ " k = k,\n",
52
+ " score_threshold=score_threshold\n",
53
+ " )\n",
54
+ " except Exception as e:\n",
55
+ " print(f\"Error during search: {e}\")\n",
56
+ " return []\n",
57
+ " \n",
58
+ " rows = [\n",
59
+ " [\n",
60
+ " f\"<a href='https://w3id.org/jp-cos/{doc.metadata['id']}' target='_blank'>{doc.metadata['id']}</a>\", \n",
61
+ " doc.metadata[\"学校種別\"],\n",
62
+ " doc.metadata[\"教科等\"],\n",
63
+ " round(float(score), 3),\n",
64
+ " doc.page_content,\n",
65
+ " ]\n",
66
+ " for doc, score in docs_and_scores\n",
67
+ " ]\n",
68
+ "\n",
69
+ " json_data = [\n",
70
+ " {\n",
71
+ " \"dcterms:identifier\": doc.metadata['id'],\n",
72
+ " \"jp-cos:courseOfStudy\": doc.metadata[\"学校種別\"],\n",
73
+ " \"jp-cos:subjectArea\": doc.metadata[\"教科等\"],\n",
74
+ " \"score\": round(float(score), 3),\n",
75
+ " \"jp-cos:sectionText\": doc.page_content,\n",
76
+ " }\n",
77
+ " for doc, score in docs_and_scores\n",
78
+ " ]\n",
79
+ "\n",
80
+ " return [\n",
81
+ " rows,\n",
82
+ " json_data\n",
83
+ " ]"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 42,
89
+ "metadata": {},
90
+ "outputs": [
91
+ {
92
+ "name": "stdout",
93
+ "output_type": "stream",
94
+ "text": [
95
+ "[{'dcterms:identifier': '8362225540000000',\n",
96
+ " 'jp-cos:courseOfStudy': '中学校',\n",
97
+ " 'jp-cos:sectionText': 'アの(ウ)の㋑については,各器官の働きを中心に扱うこと。',\n",
98
+ " 'jp-cos:subjectArea': '理科',\n",
99
+ " 'score': 0.416},\n",
100
+ " {'dcterms:identifier': '8361235820000000',\n",
101
+ " 'jp-cos:courseOfStudy': '中学校',\n",
102
+ " 'jp-cos:sectionText': 'アの(ア)の㋑については,pHにも触れること。',\n",
103
+ " 'jp-cos:subjectArea': '理科',\n",
104
+ " 'score': 0.417},\n",
105
+ " {'dcterms:identifier': '8362235720000000',\n",
106
+ " 'jp-cos:courseOfStudy': '中学校',\n",
107
+ " 'jp-cos:sectionText': 'アの(ア)の㋑については,有性生殖の仕組みを減数分裂と関連付けて扱うこと。「無性生殖」については,単細胞生物の分裂や栄養生殖にも触れること。',\n",
108
+ " 'jp-cos:subjectArea': '理科',\n",
109
+ " 'score': 0.419},\n",
110
+ " {'dcterms:identifier': '8361225630000000',\n",
111
+ " 'jp-cos:courseOfStudy': '中学校',\n",
112
+ " 'jp-cos:sectionText': 'アの(イ)の㋑の「酸化や還元」については,簡単なものを扱うこと。',\n",
113
+ " 'jp-cos:subjectArea': '理科',\n",
114
+ " 'score': 0.422}]\n"
115
+ ]
116
+ }
117
+ ],
118
+ "source": [
119
+ "input_text = \"小さい体でジャンプするトビムシは、菌糸類を食べて糞にする。\"\n",
120
+ "\n",
121
+ "grade = \"中学校\"\n",
122
+ "\n",
123
+ "k = 4\n",
124
+ "\n",
125
+ "fetch_k = 500\n",
126
+ "\n",
127
+ "threshold = 0.5\n",
128
+ "\n",
129
+ "result = search_documents(input_text, grade, index, k, fetch_k, threshold)\n",
130
+ "\n",
131
+ "pprint(result[1])"
132
+ ]
133
+ }
134
+ ],
135
+ "metadata": {
136
+ "kernelspec": {
137
+ "display_name": ".venv",
138
+ "language": "python",
139
+ "name": "python3"
140
+ },
141
+ "language_info": {
142
+ "codemirror_mode": {
143
+ "name": "ipython",
144
+ "version": 3
145
+ },
146
+ "file_extension": ".py",
147
+ "mimetype": "text/x-python",
148
+ "name": "python",
149
+ "nbconvert_exporter": "python",
150
+ "pygments_lexer": "ipython3",
151
+ "version": "3.9.11"
152
+ }
153
+ },
154
+ "nbformat": 4,
155
+ "nbformat_minor": 2
156
+ }