# %% [markdown]
# # NER inference with an exported ONNX model
#
# Loads a token-classification model exported to ONNX, tags SERVICE and
# LOCATION entities in a free-text travel request, and normalises the
# detected service phrase to a canonical service key.

# %%
import re
import unicodedata
from bisect import bisect_right

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import onnx
import onnxruntime
import numpy as np

# %% [markdown]
# ## Configuration and model loading

# %%
# Paths are relative to the notebook's working directory.
ONNX_MODEL_PATH = "ner_model.onnx"
TOKENIZER_DIR = "./results/best_model"

# Load the ONNX model
ort_session = onnxruntime.InferenceSession(ONNX_MODEL_PATH)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

# Class index produced by the model -> BIO tag.
id2label = {0: "O", 1: "B-SERVICE", 2: "I-SERVICE", 3: "B-LOCATION", 4: "I-LOCATION"}

# Canonical service key -> aliases (English and Vietnamese, with and without
# diacritics). Keys are checked in insertion order, so earlier keys win when
# a phrase contains aliases of several services (e.g. "vé máy bay" -> flight).
service_mapping = {
    "hotel": ["hotel", "hotels", "khách sạn", "khach san", "ks"],
    "flight": ["flight", "flights", "vé máy bay", "máy bay", "may bay"],
    "car rental": ["car rental", "car rentals", "thuê xe", "xe"],
    "ticket": ["ticket", "tickets", "vé", "vé tham quan", "ve", "ve tham quan"],
    "tour": ["tour", "tours", "du lịch", "du lich"]
}

# %% [markdown]
# ## Helpers

# %%
def _normalize(phrase):
    """Return `phrase` lower-cased and in Unicode NFC form.

    NFC makes precomposed and decomposed spellings of the same accented
    letter compare equal.
    """
    return unicodedata.normalize("NFC", phrase).lower()


def map_service(service):
    """Map a detected service phrase to a canonical key of `service_mapping`.

    An alias matches only as a whole word (or whole word sequence), never as
    a fragment of a longer word, so "ks" no longer matches inside "thanks"
    and "ve" no longer matches inside "travel".

    Args:
        service: The raw phrase tagged as SERVICE. Matching is
            case-insensitive.

    Returns:
        The first key of `service_mapping` (in insertion order) that has a
        matching alias, or None when `service` is empty/None or nothing
        matches.
    """
    if not service:
        return None
    phrase = _normalize(service)
    for key, aliases in service_mapping.items():
        for alias in aliases:
            # Lookarounds instead of \b so aliases containing spaces still
            # anchor on word edges at both ends.
            pattern = r"(?<!\w)" + re.escape(_normalize(alias)) + r"(?!\w)"
            if re.search(pattern, phrase):
                return key
    return None


def _align_labels_to_words(text, offsets, token_labels):
    """Give each whitespace-delimited word of `text` the label of its first token.

    Alignment is done on character offsets instead of the tokenizer's own
    word ids, because the tokenizer's notion of a "word" may differ from
    `str.split()` (for example when punctuation is split off).

    Args:
        text: The original input string.
        offsets: One (char_start, char_end) pair per token, relative to
            `text`. Tokens with an empty span are ignored.
        token_labels: One BIO tag per token, parallel to `offsets`.

    Returns:
        (words, labels): the whitespace-delimited words of `text` and one
        label per word. Words that no token starts in are labelled "O".
    """
    word_spans = [match.span() for match in re.finditer(r"\S+", text)]
    word_starts = [start for start, _ in word_spans]
    word_labels = ["O"] * len(word_spans)
    labelled = set()

    for (tok_start, tok_end), label in zip(offsets, token_labels):
        tok_start, tok_end = int(tok_start), int(tok_end)
        # A token span may begin on the whitespace preceding the word;
        # step over it so the position lands inside the word.
        while tok_start < tok_end and text[tok_start].isspace():
            tok_start += 1
        if tok_start >= tok_end:
            continue
        word_index = bisect_right(word_starts, tok_start) - 1
        if word_index < 0 or word_index in labelled:
            continue
        labelled.add(word_index)
        word_labels[word_index] = label

    words = [text[start:end] for start, end in word_spans]
    return words, word_labels


def _flush_entity(entities, entity_type, tokens):
    """Append the entity collected so far to `entities`; no-op if none is open.

    SERVICE phrases are stored as their canonical key from `map_service` and
    are dropped when no alias matches. Every other entity type is stored as
    the space-joined words.
    """
    if not entity_type:
        return
    phrase = " ".join(tokens)
    if entity_type == "SERVICE":
        mapped_service = map_service(phrase)
        if mapped_service:
            entities[entity_type].append(mapped_service)
    else:
        entities[entity_type].append(phrase)

# %% [markdown]
# ## Inference

# %%
def predict_onnx(text):
    """Extract SERVICE and LOCATION entities from `text` with the ONNX model.

    Args:
        text: A single input sentence.

    Returns:
        A dict {"SERVICE": [...], "LOCATION": [...]}. "SERVICE" holds at most
        one canonical service key (the first one detected); "LOCATION" holds
        every detected location phrase in order of appearance.
    """
    inputs = tokenizer(
        text,
        return_tensors="np",
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
    )

    # Run inference; only the two tensors below are fed to the session.
    ort_inputs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"]
    }
    ort_outputs = ort_session.run(None, ort_inputs)
    predictions = np.argmax(ort_outputs[0], axis=2)
    token_labels = [id2label[int(p)] for p in predictions[0]]

    words, word_labels = _align_labels_to_words(
        text, inputs["offset_mapping"][0], token_labels
    )

    # Extract entities by walking the BIO tags word by word.
    entities = {"SERVICE": [], "LOCATION": []}
    current_entity = None
    current_tokens = []

    for word, label in zip(words, word_labels):
        if label.startswith("B-"):
            _flush_entity(entities, current_entity, current_tokens)
            current_entity = label[2:]
            current_tokens = [word]
        elif label.startswith("I-") and current_entity:
            # An I- tag extends whatever entity is open; an I- tag with no
            # open entity falls through to the branch below and is ignored.
            current_tokens.append(word)
        else:
            _flush_entity(entities, current_entity, current_tokens)
            current_entity = None
            current_tokens = []

    _flush_entity(entities, current_entity, current_tokens)

    # Keep only the first service if multiple are detected
    if entities["SERVICE"]:
        entities["SERVICE"] = [entities["SERVICE"][0]]

    return entities

# %% [markdown]
# ## Smoke test

# %%
# Test function
def test_ner_onnx(text):
    """Print `text` and the entities predicted for it, then return them."""
    print(f"Input: {text}")
    result = predict_onnx(text)
    print("Output:", result)
    return result

# %%
# Test
sample_texts = [
    "DAT khách sạn ở Hà Nội",
    "flight to New York",
    "Thuê xe ở Đà Nẵng",
    "Đặt tour du lịch Hội An",
    "I need a ticket for the museum in Paris"
]

for text in sample_texts:
    test_ner_onnx(text)
    print()