Spaces:
Sleeping
Sleeping
File size: 5,640 Bytes
0217f42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# core\n",
"\n",
"> Levenshtein-ratio similarity search over a JSON corpus of volume/page texts, with optional page-by-page matching of an XML file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp core"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"import gradio as gr\n",
"from Levenshtein import ratio\n",
"import json\n",
"import xml.etree.ElementTree as ET\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class ApiClient:\n",
"    \"\"\"Similarity search over a JSON corpus of page texts.\n",
"\n",
"    The corpus is loaded once at construction time. ``predict`` reads each\n",
"    entry as a mapping with ``vol``, ``page`` and ``text`` keys.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data_path: str):\n",
"        \"\"\"Load the corpus from the UTF-8 encoded JSON file at ``data_path``.\"\"\"\n",
"        with open(data_path, \"r\", encoding=\"utf-8\") as f:\n",
"            self.documents_data = json.load(f)\n",
"\n",
"    def extract_text_from_lines(self, element):\n",
"        \"\"\"Concatenate the text of all descendants of ``element`` typed as body text.\n",
"\n",
"        Descendants are selected by the attribute ``type='本文'``. A matched\n",
"        element without text (``.text`` is ``None``) contributes nothing\n",
"        instead of breaking the join.\n",
"        \"\"\"\n",
"        lines = element.findall(\".//*[@type='本文']\")\n",
"        return ''.join(line.text or '' for line in lines)\n",
"\n",
"    def format_prediction_result(self, result):\n",
"        \"\"\"Format the first hit of ``result`` as 'vol-page'.\n",
"\n",
"        Returns an empty string when ``result`` is empty (for example when no\n",
"        document belongs to the selected volumes), so that one unmatched page\n",
"        does not abort a whole XML run.\n",
"        \"\"\"\n",
"        if not result:\n",
"            return ''\n",
"        first_result = result[0]\n",
"        return f'{first_result[\"vol\"]}-{first_result[\"page\"]}'\n",
"\n",
"    def search_similar_texts(self, query, selected_vols, top_n=5, xml_file_path=None):\n",
"        \"\"\"Run a similarity search for a query text, or for every page of an XML file.\n",
"\n",
"        Args:\n",
"            query (str): Query text. Not used when ``xml_file_path`` is given.\n",
"            selected_vols (list): Volumes to search. Empty or None searches all volumes.\n",
"            top_n (int, optional): Number of hits returned for a text query. Defaults to 5.\n",
"            xml_file_path (str, optional): Path of an XML file. When given, every\n",
"                element with ``type='page'`` is matched against the corpus.\n",
"\n",
"        Returns:\n",
"            list: Always a one-element list. For a text query the element is the\n",
"                list of top hits. For an XML file it is a dict mapping the\n",
"                1-based page position (as a string) to the best hit formatted\n",
"                as 'vol-page'; the dict is empty when the file cannot be read\n",
"                or parsed.\n",
"        \"\"\"\n",
"        if xml_file_path is None:\n",
"            return [self.predict(query, selected_vols, top_n)]\n",
"\n",
"        # Only reading and parsing can raise the handled errors, so the try\n",
"        # body is limited to those two steps.\n",
"        try:\n",
"            with open(xml_file_path, \"r\", encoding=\"utf-8\") as f:\n",
"                root = ET.fromstring(f.read())\n",
"        except (ET.ParseError, FileNotFoundError, PermissionError) as e:\n",
"            print(f\"XMLファイルの処理中にエラーが発生しました: {str(e)}\")\n",
"            # Same one-element shape as the success path.\n",
"            return [{}]\n",
"\n",
"        # Collect the page elements.\n",
"        elements = root.findall(\".//*[@type='page']\")\n",
"\n",
"        # Predict the best match for each page. tqdm wraps the list itself\n",
"        # (not the enumerate object) so it can read the total length.\n",
"        predict_results = {}\n",
"        for i, element in enumerate(tqdm(elements), 1):\n",
"            text = self.extract_text_from_lines(element)\n",
"            top_results = self.predict(text, selected_vols, 1)\n",
"            predict_results[str(i)] = self.format_prediction_result(top_results)\n",
"\n",
"        return [predict_results]\n",
"\n",
"    def predict(self, query, selected_vols, top_n=5):\n",
"        \"\"\"Score the in-scope documents against ``query`` and return the best matches.\n",
"\n",
"        Args:\n",
"            query (str): Query text.\n",
"            selected_vols (list): Volumes to search. Empty or None searches all volumes.\n",
"                Entries are compared as strings, so ints and strings both work.\n",
"            top_n (int, optional): Number of hits to return. Defaults to 5.\n",
"\n",
"        Returns:\n",
"            list: Up to ``top_n`` dicts with ``vol``, ``page``, ``score`` and\n",
"                ``text`` keys, ordered by descending ``Levenshtein.ratio`` score.\n",
"        \"\"\"\n",
"        # Normalise the filter once so each membership test is a set lookup;\n",
"        # None means that no volume filter applies.\n",
"        allowed_vols = {str(vol) for vol in selected_vols} if selected_vols else None\n",
"\n",
"        results = [\n",
"            {\n",
"                \"vol\": doc[\"vol\"],\n",
"                \"page\": doc[\"page\"],\n",
"                \"score\": ratio(query, doc[\"text\"]),\n",
"                \"text\": doc[\"text\"],\n",
"            }\n",
"            for doc in self.documents_data\n",
"            # Restrict the search to the selected volumes.\n",
"            if allowed_vols is None or str(doc[\"vol\"]) in allowed_vols\n",
"        ]\n",
"\n",
"        results.sort(key=lambda hit: hit[\"score\"], reverse=True)\n",
"        # Keep only the number of hits requested by top_n.\n",
"        return results[:top_n]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"import nbdev; nbdev.nbdev_export()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|