{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ner-inference-intro",
   "metadata": {},
   "source": [
    "# NER inference: travel services and locations\n",
    "\n",
    "Loads the token-classification model saved at `./results/best_model` and extracts `SERVICE` and `LOCATION` entities from free text.\n",
    "`SERVICE` mentions are canonicalised to a key of `service_mapping` (`hotel`, `flight`, `car rental`, `ticket`, `tour`); only the first one is returned.\n",
    "\n",
    "Run the first code cell once (it loads the model), then call `test_ner` / `predict` as often as needed."
   ]
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-10-18T02:33:03.992606Z",
     "start_time": "2024-10-18T02:33:03.198886Z"
    }
   },
   "source": [
    "import re\n",
    "import unicodedata\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
    "\n",
    "# Load the saved model and tokenizer.\n",
    "# predict() relies on word_ids()/word_to_chars(), which are only available on *fast* tokenizers.\n",
    "model_path = \"./results/best_model\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "model = AutoModelForTokenClassification.from_pretrained(model_path)\n",
    "\n",
    "# Canonical service -> aliases (English, and Vietnamese with/without diacritics).\n",
    "# Keys are tried in order, so \"vé máy bay\" resolves to \"flight\" before the bare \"vé\" alias of \"ticket\".\n",
    "service_mapping = {\n",
    "    \"hotel\": [\"hotel\", \"hotels\", \"khách sạn\", \"khach san\", \"ks\"],\n",
    "    \"flight\": [\"flight\", \"flights\", \"vé máy bay\", \"máy bay\", \"may bay\"],\n",
    "    \"car rental\": [\"car rental\", \"car rentals\", \"thuê xe\", \"xe\"],\n",
    "    \"ticket\": [\"ticket\", \"tickets\", \"vé\", \"vé tham quan\", \"ve\", \"ve tham quan\"],\n",
    "    \"tour\": [\"tour\", \"tours\", \"du lịch\", \"du lich\"],\n",
    "}\n",
    "\n",
    "# BIO tag for each class id. Assumed to match the training label order -- TODO confirm against the training config.\n",
    "id2label = {0: \"O\", 1: \"B-SERVICE\", 2: \"I-SERVICE\", 3: \"B-LOCATION\", 4: \"I-LOCATION\"}\n",
    "# Fail fast if the checkpoint's classification head has a different number of labels.\n",
    "assert model.config.num_labels == len(id2label), (\n",
    "    f\"model has {model.config.num_labels} labels but id2label has {len(id2label)}\"\n",
    ")\n",
    "\n",
    "\n",
    "def _normalize(text):\n",
    "    \"\"\"NFC-normalize, lowercase and collapse whitespace so aliases compare reliably.\"\"\"\n",
    "    return \" \".join(unicodedata.normalize(\"NFC\", text).lower().split())\n",
    "\n",
    "\n",
    "def map_service(service):\n",
    "    \"\"\"Return the `service_mapping` key whose alias occurs in `service`, or None.\n",
    "\n",
    "    Aliases must match as whole words, so \"xe\" matches \"thuê xe\" but not \"taxes\",\n",
    "    and \"ve\" does not fire inside \"travel\". Matching ignores case, NFC/NFD\n",
    "    differences in Vietnamese diacritics, and repeated whitespace. Keys are tried\n",
    "    in `service_mapping` order and the first match wins.\n",
    "    \"\"\"\n",
    "    normalized = _normalize(service)\n",
    "    for key, aliases in service_mapping.items():\n",
    "        for alias in aliases:\n",
    "            # (?<!\\w) / (?!\\w): the alias may not be glued to other letters or digits.\n",
    "            pattern = rf\"(?<!\\w){re.escape(_normalize(alias))}(?!\\w)\"\n",
    "            if re.search(pattern, normalized):\n",
    "                return key\n",
    "    return None\n",
    "\n",
    "\n",
    "def _word_labels(text):\n",
    "    \"\"\"Run the model on `text`; return one (char_start, char_end, label) per word.\n",
    "\n",
    "    Words are the tokenizer's own pre-tokenized words (from `word_ids()`), each\n",
    "    labelled with the prediction for its first sub-token. Offsets come from\n",
    "    `word_to_chars`, so they always index into `text` itself and cannot drift\n",
    "    out of sync with the labels (unlike pairing labels with `text.split()`).\n",
    "    Words beyond the model's max length are dropped by truncation.\n",
    "    \"\"\"\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\", truncation=True)\n",
    "    with torch.no_grad():\n",
    "        logits = model(**inputs).logits\n",
    "    label_ids = logits.argmax(dim=-1)[0].tolist()\n",
    "\n",
    "    words = []\n",
    "    previous_word_id = None\n",
    "    for word_id, label_id in zip(inputs.word_ids(), label_ids):\n",
    "        # Skip special tokens (word_id is None) and continuation sub-tokens.\n",
    "        if word_id is not None and word_id != previous_word_id:\n",
    "            span = inputs.word_to_chars(word_id)\n",
    "            words.append((span.start, span.end, id2label[label_id]))\n",
    "        previous_word_id = word_id\n",
    "    return words\n",
    "\n",
    "\n",
    "def _bio_spans(words):\n",
    "    \"\"\"Group (start, end, label) words into (entity_type, start, end) spans.\n",
    "\n",
    "    conlleval-style chunking: a span opens at B-X, or at I-X when the previous\n",
    "    word is not inside an X span (so an I-LOCATION right after a SERVICE starts a\n",
    "    new LOCATION instead of being glued onto the SERVICE); \"O\" closes the span.\n",
    "    \"\"\"\n",
    "    spans = []\n",
    "    current = None  # [entity_type, start, end] of the span being built\n",
    "    for start, end, label in words:\n",
    "        prefix, _, entity_type = label.partition(\"-\")\n",
    "        if prefix == \"I\" and current is not None and current[0] == entity_type:\n",
    "            current[2] = end  # continue the open span\n",
    "            continue\n",
    "        if current is not None:\n",
    "            spans.append(tuple(current))\n",
    "            current = None\n",
    "        if prefix in (\"B\", \"I\"):\n",
    "            current = [entity_type, start, end]\n",
    "    if current is not None:\n",
    "        spans.append(tuple(current))\n",
    "    return spans\n",
    "\n",
    "\n",
    "def predict(text):\n",
    "    \"\"\"Extract SERVICE and LOCATION entities from `text`.\n",
    "\n",
    "    Returns {\"SERVICE\": [...], \"LOCATION\": [...]}. Each entity is the substring\n",
    "    of `text` its words cover, so casing and diacritics are preserved. SERVICE\n",
    "    mentions are canonicalised with `map_service` (falling back to the raw text\n",
    "    when no alias matches), and only the first SERVICE is kept.\n",
    "    \"\"\"\n",
    "    entities = {\"SERVICE\": [], \"LOCATION\": []}\n",
    "    for entity_type, start, end in _bio_spans(_word_labels(text)):\n",
    "        # strip() is defensive: some tokenizers' offsets may include adjacent whitespace.\n",
    "        mention = text[start:end].strip()\n",
    "        if entity_type == \"SERVICE\":\n",
    "            mention = map_service(mention) or mention\n",
    "        entities[entity_type].append(mention)\n",
    "\n",
    "    entities[\"SERVICE\"] = entities[\"SERVICE\"][:1]  # keep only the first service\n",
    "    return entities\n",
    "\n",
    "\n",
    "# Test function\n",
    "def test_ner(text):\n",
    "    \"\"\"Print `text` and its predicted entities, then return the prediction.\"\"\"\n",
    "    print(f\"Input: {text}\")\n",
    "    result = predict(text)\n",
    "    print(\"Output:\", result)\n",
    "    return result"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T03:19:29.075358Z",
     "start_time": "2024-10-18T03:19:28.999181Z"
    }
   },
   "cell_type": "code",
   "source": [
    "test_texts = [\n",
    "    \"tour du lich gia re da lat\"\n",
    "]\n",
    "\n",
    "for text in test_texts:\n",
    "    test_ner(text)\n",
    "    print()"
   ],
   "id": "65a6a938b3590ad2",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: tour du lich gia re da lat\n",
      "Output: {'SERVICE': ['tour'], 'LOCATION': ['da lat']}\n",
      "\n"
     ]
    }
   ],
   "execution_count": 13
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}