{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# core\n",
    "\n",
    "> Levenshtein-ratio similarity search over a JSON corpus of document pages, including batch prediction for the pages of an XML file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp core"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import json\n",
    "import xml.etree.ElementTree as ET\n",
    "\n",
    "import gradio as gr\n",
    "from Levenshtein import ratio\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class ApiClient:\n",
    "    \"\"\"Similarity search over a corpus of document pages loaded from JSON.\n",
    "\n",
    "    Each corpus record is read through the keys `\"vol\"`, `\"page\"` and `\"text\"`\n",
    "    (see `predict`).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_path: str):\n",
    "        \"\"\"Load the document corpus from the UTF-8 JSON file at `data_path`.\"\"\"\n",
    "        with open(data_path, \"r\", encoding=\"utf-8\") as f:\n",
    "            self.documents_data = json.load(f)\n",
    "\n",
    "    def extract_text_from_lines(self, element):\n",
    "        \"\"\"Concatenate the text of every descendant of `element` whose `type` is '本文' (body text).\"\"\"\n",
    "        lines = element.findall(\".//*[@type='本文']\")\n",
    "        # `.text` is None for elements without direct text; joining None would raise TypeError.\n",
    "        return ''.join(line.text or '' for line in lines)\n",
    "\n",
    "    def format_prediction_result(self, result):\n",
    "        \"\"\"Format the best hit of a `predict` result as 'vol-page'.\n",
    "\n",
    "        Returns an empty string when `result` is empty (no document in the selected volumes).\n",
    "        \"\"\"\n",
    "        if not result:\n",
    "            return ''\n",
    "        first_result = result[0]\n",
    "        return f'{first_result[\"vol\"]}-{first_result[\"page\"]}'\n",
    "\n",
    "    def search_similar_texts(self, query, selected_vols, top_n=5, xml_file_path=None):\n",
    "        \"\"\"Run a text similarity search for one query, or for every page of an XML file.\n",
    "\n",
    "        Args:\n",
    "            query (str): Query text. Ignored when `xml_file_path` is given.\n",
    "            selected_vols (list): Volumes to search, compared against `str(doc[\"vol\"])`.\n",
    "                An empty list (or None) searches every volume.\n",
    "            top_n (int, optional): Number of results for a plain query. Defaults to 5.\n",
    "            xml_file_path (str, optional): Path to an XML file. The text of each element with\n",
    "                `type='page'` is used as a query and only its best hit is kept.\n",
    "\n",
    "        Returns:\n",
    "            list: Always a one-element list.\n",
    "                With `xml_file_path`: `[predict_results]`, a dict mapping the 1-based page\n",
    "                index (as str) to the best match formatted as 'vol-page' ('' when nothing matched);\n",
    "                `[{}]` if the XML file cannot be read or parsed.\n",
    "                Otherwise: `[top_results]` as returned by `predict`.\n",
    "        \"\"\"\n",
    "        if xml_file_path is None:\n",
    "            return [self.predict(query, selected_vols, top_n)]\n",
    "\n",
    "        try:\n",
    "            # ET.parse reads bytes and honours the XML encoding declaration (UTF-8 by default).\n",
    "            root = ET.parse(xml_file_path).getroot()\n",
    "        except (ET.ParseError, OSError) as e:\n",
    "            print(f\"XMLファイルの処理中にエラーが発生しました: {str(e)}\")\n",
    "            # Same one-element shape as the success path.\n",
    "            return [{}]\n",
    "\n",
    "        # Page elements of the XML document\n",
    "        elements = root.findall(\".//*[@type='page']\")\n",
    "\n",
    "        # Predict the single best-matching corpus page for each XML page\n",
    "        predict_results = {}\n",
    "        for i, element in tqdm(enumerate(elements, 1), total=len(elements)):\n",
    "            text = self.extract_text_from_lines(element)\n",
    "            top_results = self.predict(text, selected_vols, 1)\n",
    "            predict_results[str(i)] = self.format_prediction_result(top_results)\n",
    "\n",
    "        return [predict_results]\n",
    "\n",
    "    def predict(self, query, selected_vols, top_n=5):\n",
    "        \"\"\"Score documents in the selected volumes against `query` and return the best matches.\n",
    "\n",
    "        Args:\n",
    "            query (str): Query text.\n",
    "            selected_vols (list): Volumes to search, compared against `str(doc[\"vol\"])`.\n",
    "                An empty list (or None) searches every volume.\n",
    "            top_n (int, optional): Number of results to return. Defaults to 5.\n",
    "\n",
    "        Returns:\n",
    "            list: Up to `top_n` dicts with keys \"vol\", \"page\", \"score\" and \"text\",\n",
    "                sorted by Levenshtein `ratio` score, highest first (ties keep corpus order).\n",
    "        \"\"\"\n",
    "        results = [\n",
    "            {\n",
    "                \"vol\": doc[\"vol\"],\n",
    "                \"page\": doc[\"page\"],\n",
    "                \"score\": ratio(query, doc[\"text\"]),\n",
    "                \"text\": doc[\"text\"],\n",
    "            }\n",
    "            for doc in self.documents_data\n",
    "            # Only search the selected volumes (all volumes when none are selected)\n",
    "            if not selected_vols or str(doc[\"vol\"]) in selected_vols\n",
    "        ]\n",
    "\n",
    "        # list.sort is stable, also with reverse=True\n",
    "        results.sort(key=lambda r: r[\"score\"], reverse=True)\n",
    "        return results[:top_n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}