{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Travel-query NER: SERVICE and LOCATION\n",
    "\n",
    "Fine-tunes `xlm-roberta-base` to tag English/Vietnamese travel queries word by word with BIO labels, evaluates it with seqeval, and wraps the saved model in a `predict()` helper.\n",
    "\n",
    "Configuration: `NER_DATA_DIR` points at the folder holding `train_dataset.json` and `eval_dataset.json`; W&B credentials come from `WANDB_API_KEY` (or a prior `wandb login`), never from the notebook itself."
   ],
   "id": "0c1d2e3f4a5b6c7d"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T02:11:36.941568Z",
     "start_time": "2024-10-18T02:11:33.662041Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import wandb\n",
    "from datasets import Dataset\n",
    "from seqeval.metrics import classification_report"
   ],
   "id": "5323bc048545031",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ebk/PycharmProjects/pythonProject/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T02:11:36.959310Z",
     "start_time": "2024-10-18T02:11:36.944303Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Folder holding train_dataset.json and eval_dataset.json (override with the NER_DATA_DIR env var)\n",
    "DATA_DIR = Path(os.environ.get(\"NER_DATA_DIR\", \"/home/ebk/Desktop/NER model\"))\n",
    "\n",
    "# Canonical service name -> surface forms (English and Vietnamese, with and without diacritics)\n",
    "SERVICE_MAPPING = {\n",
    "    \"hotel\": [\"hotel\", \"hotels\", \"khách sạn\", \"khach san\", \"ks\"],\n",
    "    \"flight\": [\"flight\", \"flights\", \"vé máy bay\", \"máy bay\", \"may bay\"],\n",
    "    \"car rental\": [\"car rental\", \"car rentals\", \"thuê xe\", \"xe\"],\n",
    "    \"ticket\": [\"ticket\", \"tickets\", \"vé\", \"vé tham quan\", \"ve\", \"ve tham quan\"],\n",
    "    \"tour\": [\"tour\", \"tours\", \"du lịch\", \"du lich\"],\n",
    "}\n",
    "\n",
    "\n",
    "def prepare_data(file_path):\n",
    "    \"\"\"Convert an annotated query file into a word-level BIO-tagged Dataset.\n",
    "\n",
    "    The JSON file holds a top-level \"queries\" list; each query has a \"text\" string and an\n",
    "    \"entities\" list of (start, end, type, text) entries. Entities are located by case-insensitive\n",
    "    whole-word matching on the whitespace-split text; the character offsets are not used.\n",
    "    A SERVICE entity's canonical name is expanded to every surface form in SERVICE_MAPPING.\n",
    "\n",
    "    Returns a datasets.Dataset with columns \"words\" (original casing) and \"labels\".\n",
    "    \"\"\"\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    processed_data = []\n",
    "    for query in data[\"queries\"]:\n",
    "        words = query[\"text\"].split()\n",
    "        labels = [\"O\"] * len(words)\n",
    "        lower_words = [w.lower() for w in words]\n",
    "\n",
    "        for _start, _end, entity_type, entity_text in query[\"entities\"]:\n",
    "            if entity_type == \"SERVICE\":\n",
    "                search_terms = SERVICE_MAPPING.get(entity_text.lower(), [entity_text.lower()])\n",
    "            else:\n",
    "                search_terms = [entity_text.lower()]\n",
    "            # Longest surface form first, so \"vé tham quan\" is tagged as one span\n",
    "            # instead of stopping at its one-word prefix \"vé\".\n",
    "            search_terms = sorted(search_terms, key=lambda t: len(t.split()), reverse=True)\n",
    "\n",
    "            found = False\n",
    "            for term in search_terms:\n",
    "                term_words = term.split()\n",
    "                n = len(term_words)\n",
    "                if n == 0:\n",
    "                    continue  # empty entity text: nothing to tag, reported as not found below\n",
    "                for i in range(len(lower_words) - n + 1):\n",
    "                    # Only claim still-unlabelled words so a later entity never overwrites an earlier one.\n",
    "                    if lower_words[i:i + n] == term_words and all(tag == \"O\" for tag in labels[i:i + n]):\n",
    "                        labels[i:i + n] = [f\"B-{entity_type}\"] + [f\"I-{entity_type}\"] * (n - 1)\n",
    "                        found = True\n",
    "                        break\n",
    "                if found:\n",
    "                    break\n",
    "\n",
    "            if not found:\n",
    "                print(f\"Warning: Entity '{entity_text}' not found in text '{query['text']}'\")\n",
    "\n",
    "        # Keep the original capitalization in 'words'\n",
    "        processed_data.append({\"words\": words, \"labels\": labels})\n",
    "\n",
    "    return Dataset.from_list(processed_data)\n",
    "\n",
    "\n",
    "# Load data\n",
    "train_dataset = prepare_data(DATA_DIR / \"train_dataset.json\")\n",
    "eval_dataset = prepare_data(DATA_DIR / \"eval_dataset.json\")"
   ],
   "id": "3bead420acaa0a20",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T02:27:59.106340Z",
     "start_time": "2024-10-18T02:11:37.016627Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback\n",
    "from transformers import DataCollatorForTokenClassification, set_seed\n",
    "import numpy as np\n",
    "\n",
    "SEED = 42\n",
    "set_seed(SEED)  # seed python/numpy/torch before the classifier head is randomly initialised\n",
    "\n",
    "# Label list and id2label / label2id mappings, defined before the model so they are stored in its config\n",
    "label_list = [\"O\", \"B-SERVICE\", \"I-SERVICE\", \"B-LOCATION\", \"I-LOCATION\"]\n",
    "id2label = {i: label for i, label in enumerate(label_list)}\n",
    "label2id = {label: i for i, label in id2label.items()}\n",
    "\n",
    "# Load tokenizer and model\n",
    "model_name = \"xlm-roberta-base\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForTokenClassification.from_pretrained(\n",
    "    model_name,\n",
    "    num_labels=len(label_list),\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")\n",
    "\n",
    "\n",
    "def tokenize_and_align_labels(examples):\n",
    "    \"\"\"Tokenize pre-split words and align the word-level labels to sub-word tokens.\n",
    "\n",
    "    Only the first sub-token of each word keeps that word's label; later sub-tokens and\n",
    "    special tokens get -100 so the loss ignores them. Padding is left to\n",
    "    DataCollatorForTokenClassification, which pads each batch dynamically.\n",
    "    \"\"\"\n",
    "    tokenized_inputs = tokenizer(examples[\"words\"], truncation=True, is_split_into_words=True, max_length=256)\n",
    "    labels = []\n",
    "    for i, word_labels in enumerate(examples[\"labels\"]):\n",
    "        previous_word_idx = None\n",
    "        label_ids = []\n",
    "        for word_idx in tokenized_inputs.word_ids(batch_index=i):\n",
    "            if word_idx is None or word_idx == previous_word_idx:\n",
    "                label_ids.append(-100)\n",
    "            else:\n",
    "                label_ids.append(label2id[word_labels[word_idx]])\n",
    "            previous_word_idx = word_idx\n",
    "        labels.append(label_ids)\n",
    "    tokenized_inputs[\"labels\"] = labels\n",
    "    return tokenized_inputs\n",
    "\n",
    "\n",
    "def augment_data(examples):\n",
    "    \"\"\"Batched map fn: keep every sentence and append a lowercased copy of it.\n",
    "\n",
    "    Users often type without capitals. A copy is only added when lowercasing actually\n",
    "    changes the sentence, so already-lowercase sentences are not double-weighted.\n",
    "    \"\"\"\n",
    "    words = list(examples[\"words\"])\n",
    "    labels = list(examples[\"labels\"])\n",
    "    for sentence, sentence_labels in zip(examples[\"words\"], examples[\"labels\"]):\n",
    "        lowered = [word.lower() for word in sentence]\n",
    "        if lowered != sentence:\n",
    "            words.append(lowered)\n",
    "            labels.append(sentence_labels)\n",
    "    return {\"words\": words, \"labels\": labels}\n",
    "\n",
    "\n",
    "# Apply data augmentation and tokenization.\n",
    "# NOTE: the eval split is augmented too, so eval metrics include the lowercased copies.\n",
    "augmented_train = train_dataset.map(augment_data, batched=True, remove_columns=train_dataset.column_names)\n",
    "augmented_eval = eval_dataset.map(augment_data, batched=True, remove_columns=eval_dataset.column_names)\n",
    "\n",
    "tokenized_train = augmented_train.map(tokenize_and_align_labels, batched=True)\n",
    "tokenized_eval = augmented_eval.map(tokenize_and_align_labels, batched=True)\n",
    "\n",
    "# W&B reads the API key from the WANDB_API_KEY environment variable or ~/.netrc (run `wandb login` once).\n",
    "# Never hard-code it here: the key previously committed in this notebook is exposed and must be revoked.\n",
    "wandb.login()\n",
    "run = wandb.init(\n",
    "    project=os.environ.get(\"WANDB_PROJECT\", \"ner-service-location\"),\n",
    "    job_type=\"training\",\n",
    "    anonymous=\"allow\",\n",
    ")\n",
    "\n",
    "# Define training arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./results\",\n",
    "    run_name=run.name,\n",
    "    eval_strategy=\"epoch\",  # `evaluation_strategy` is deprecated and removed in transformers 4.46\n",
    "    save_strategy=\"epoch\",  # must match eval_strategy for load_best_model_at_end\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=10,\n",
    "    weight_decay=0.01,\n",
    "    save_total_limit=2,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"eval_loss\",\n",
    "    greater_is_better=False,\n",
    "    warmup_steps=500,  # NOTE(review): the recorded run had 680 steps in total, so warmup covers most of training\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    logging_dir=\"./logs\",\n",
    "    logging_steps=100,\n",
    "    report_to=\"wandb\",\n",
    "    seed=SEED,\n",
    ")\n",
    "\n",
    "# Define data collator (pads inputs per batch and pads labels with -100)\n",
    "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)\n",
    "\n",
    "# Define Trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train,\n",
    "    eval_dataset=tokenized_eval,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
    ")\n",
    "\n",
    "# Train the model\n",
    "trainer.train()"
   ],
   "id": "4113c9be904ec16a",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Map: 100%|██████████| 544/544 [00:00<00:00, 62863.71 examples/s]\n",
      "Map: 100%|██████████| 48/48 [00:00<00:00, 23307.08 examples/s]\n",
      "Map: 100%|██████████| 1088/1088 [00:00<00:00, 35149.87 examples/s]\n",
      "Map: 100%|██████████| 96/96 [00:00<00:00, 22730.79 examples/s]\n",
      "\u001B[34m\u001B[1mwandb\u001B[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
      "\u001B[34m\u001B[1mwandb\u001B[0m: Currently logged in as: \u001B[33mprodranek007\u001B[0m (\u001B[33mprodranek007-eh\u001B[0m). 
Use \u001B[1m`wandb login --relogin`\u001B[0m to force relogin\n", "\u001B[34m\u001B[1mwandb\u001B[0m: \u001B[33mWARNING\u001B[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n", "\u001B[34m\u001B[1mwandb\u001B[0m: \u001B[33mWARNING\u001B[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n", "\u001B[34m\u001B[1mwandb\u001B[0m: Appending key for api.wandb.ai to your netrc file: /home/ebk/.netrc\n" ] }, { "data": { "text/plain": [ "" ], "text/html": [ "Tracking run with wandb version 0.18.3" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ], "text/html": [ "Run data is saved locally in /home/ebk/Desktop/NER model/wandb/run-20241018_091140-tmfygpmb" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ], "text/html": [ "Syncing run pleasant-violet-4 to Weights & Biases (docs)
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ], "text/html": [ " View project at https://wandb.ai/prodranek007-eh/uncategorized" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ], "text/html": [ " View run at https://wandb.ai/prodranek007-eh/uncategorized/runs/tmfygpmb" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/ebk/PycharmProjects/pythonProject/venv/lib/python3.12/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "/home/ebk/PycharmProjects/pythonProject/venv/lib/python3.12/site-packages/torch/cuda/__init__.py:654: UserWarning: Can't initialize NVML\n", " warnings.warn(\"Can't initialize NVML\")\n", "\u001B[34m\u001B[1mwandb\u001B[0m: \u001B[33mWARNING\u001B[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n" ] }, { "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " \n", " [680/680 16:15, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation Loss
1No log1.455081
21.4997000.568691
30.6614000.099232
40.6614000.032710
50.0900000.021099
60.0253000.016014
70.0253000.027342
80.0116000.000233
90.0056000.000714
100.0056000.000568

" ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=680, training_loss=0.3376634743283777, metrics={'train_runtime': 977.2101, 'train_samples_per_second': 11.134, 'train_steps_per_second': 0.696, 'total_flos': 99948720268800.0, 'train_loss': 0.3376634743283777, 'epoch': 10.0})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T02:28:00.239706Z",
     "start_time": "2024-10-18T02:27:59.142541Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Save the best checkpoint (reloaded by load_best_model_at_end) together with its tokenizer\n",
    "trainer.save_model(\"./results/best_model\")\n",
    "tokenizer.save_pretrained(\"./results/best_model\")"
   ],
   "id": "ac89d99896df0b18",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('./results/best_model/tokenizer_config.json',\n",
       " './results/best_model/special_tokens_map.json',\n",
       " './results/best_model/tokenizer.json')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T02:28:01.465325Z",
     "start_time": "2024-10-18T02:28:00.270346Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Evaluate the model\n",
    "eval_results = trainer.evaluate()\n",
    "print(f\"Evaluation results: {eval_results}\")\n",
    "\n",
    "\n",
    "def align_predictions(predictions, label_ids):\n",
    "    \"\"\"Convert logits and label ids into per-sentence lists of tag strings for seqeval.\n",
    "\n",
    "    Positions labelled -100 (special tokens, non-first sub-tokens, padding) are dropped,\n",
    "    so each returned list has one tag per labelled word.\n",
    "    Returns (preds_list, out_label_list).\n",
    "    \"\"\"\n",
    "    preds = np.argmax(predictions, axis=2)\n",
    "    preds_list, out_label_list = [], []\n",
    "    for pred_row, label_row in zip(preds, label_ids):\n",
    "        keep = label_row != -100\n",
    "        out_label_list.append([id2label[int(label_id)] for label_id in label_row[keep]])\n",
    "        preds_list.append([id2label[int(pred_id)] for pred_id in pred_row[keep]])\n",
    "    return preds_list, out_label_list\n",
    "\n",
    "\n",
    "# Get predictions\n",
    "test_results = trainer.predict(tokenized_eval)\n",
    "predictions, labels, _ = test_results\n",
    "preds_list, out_label_list = align_predictions(predictions, labels)\n",
    "\n",
    "# Print classification report\n",
    "print(classification_report(out_label_list, preds_list))"
   ],
   "id": "facbe70d87a460a",
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ],
      "text/html": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation results: {'eval_loss': 0.00023286677605938166, 'eval_runtime': 0.5843, 'eval_samples_per_second': 164.298, 'eval_steps_per_second': 10.269, 'epoch': 10.0}\n",
      " precision recall f1-score support\n",
      "\n",
      " LOCATION 1.00 1.00 1.00 96\n",
      " SERVICE 1.00 1.00 1.00 96\n",
      "\n",
      " micro avg 1.00 1.00 1.00 192\n",
      " macro avg 1.00 1.00 1.00 192\n",
      "weighted avg 1.00 1.00 1.00 192\n",
      "\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T02:28:48.953577Z",
     "start_time": "2024-10-18T02:28:48.262193Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
    "import torch\n",
    "\n",
    "# Load the saved model and tokenizer. The cell is self-contained so inference also works in a fresh\n",
    "# kernel; separate names avoid clobbering the `tokenizer` / `model` objects used for training above.\n",
    "model_path = \"./results/best_model\"\n",
    "ner_tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "ner_model = AutoModelForTokenClassification.from_pretrained(model_path)\n",
    "ner_model.eval()\n",
    "\n",
    "# Label ids, in the same order as label_list at training time\n",
    "ner_id2label = {0: \"O\", 1: \"B-SERVICE\", 2: \"I-SERVICE\", 3: \"B-LOCATION\", 4: \"I-LOCATION\"}\n",
    "\n",
    "# Canonical service name -> surface forms that may appear inside a predicted SERVICE span\n",
    "service_mapping = {\n",
    "    \"hotel\": [\"hotel\", \"hotels\", \"khách sạn\", \"khach san\", \"ks\"],\n",
    "    \"flight\": [\"flight\", \"flights\", \"vé máy bay\", \"máy bay\", \"may bay\"],\n",
    "    \"car rental\": [\"car rental\", \"car rentals\", \"thuê xe\", \"xe\"],\n",
    "    \"ticket\": [\"ticket\", \"tickets\", \"vé\", \"vé tham quan\", \"ve\", \"ve tham quan\"],\n",
    "    \"tour\": [\"tour\", \"tours\", \"du lịch\", \"du lich\"],\n",
    "}\n",
    "\n",
    "\n",
    "def map_service(service):\n",
    "    \"\"\"Return the canonical service name for a SERVICE span, or None if nothing matches.\n",
    "\n",
    "    An exact (case-insensitive) surface-form match wins; otherwise the first service, in\n",
    "    service_mapping order, with a surface form occurring as a substring of the span is used.\n",
    "    \"\"\"\n",
    "    service = service.lower()\n",
    "    for key, values in service_mapping.items():\n",
    "        if service in values:\n",
    "            return key\n",
    "    for key, values in service_mapping.items():\n",
    "        if any(v in service for v in values):\n",
    "            return key\n",
    "    return None\n",
    "\n",
    "\n",
    "def _flush_entity(entities, entity_type, tokens):\n",
    "    \"\"\"Append `tokens` to entities[entity_type] as one span; SERVICE spans are canonicalised, unmapped ones dropped.\"\"\"\n",
    "    if entity_type is None or not tokens:\n",
    "        return\n",
    "    span = \" \".join(tokens)\n",
    "    if entity_type == \"SERVICE\":\n",
    "        span = map_service(span)\n",
    "        if span is None:\n",
    "            return\n",
    "    entities.setdefault(entity_type, []).append(span)\n",
    "\n",
    "\n",
    "def predict(text):\n",
    "    \"\"\"Extract SERVICE and LOCATION entities from `text`.\n",
    "\n",
    "    The text is split on whitespace (as the training data was) and each word takes the\n",
    "    label of its first sub-token. BIO tags are decoded leniently: an I- tag that does not\n",
    "    continue an open entity of the same type starts a new entity.\n",
    "\n",
    "    Returns {\"SERVICE\": [...], \"LOCATION\": [...]} holding at most one canonical service\n",
    "    name (the first detected) and the location spans, de-duplicated in order of appearance.\n",
    "    \"\"\"\n",
    "    entities = {\"SERVICE\": [], \"LOCATION\": []}\n",
    "    words = text.split()\n",
    "    if not words:\n",
    "        return entities\n",
    "\n",
    "    # Tokenize the pre-split words exactly as tokenize_and_align_labels did during training\n",
    "    inputs = ner_tokenizer(words, is_split_into_words=True, return_tensors=\"pt\", truncation=True)\n",
    "    with torch.no_grad():\n",
    "        logits = ner_model(**inputs).logits\n",
    "    predicted_ids = logits.argmax(dim=-1)[0].tolist()\n",
    "\n",
    "    # Label of the first sub-token of each word; special tokens (word id None) are skipped\n",
    "    word_labels = {}\n",
    "    for word_id, label_id in zip(inputs.word_ids(), predicted_ids):\n",
    "        if word_id is not None and word_id not in word_labels:\n",
    "            word_labels[word_id] = ner_id2label[label_id]\n",
    "\n",
    "    current_entity, current_tokens = None, []\n",
    "    for idx, word in enumerate(words):\n",
    "        # Words cut off by truncation have no prediction and count as \"O\"\n",
    "        prefix, _, entity_type = word_labels.get(idx, \"O\").partition(\"-\")\n",
    "        if prefix == \"I\" and entity_type == current_entity:\n",
    "            current_tokens.append(word)\n",
    "            continue\n",
    "        _flush_entity(entities, current_entity, current_tokens)\n",
    "        if prefix in (\"B\", \"I\"):\n",
    "            current_entity, current_tokens = entity_type, [word]\n",
    "        else:\n",
    "            current_entity, current_tokens = None, []\n",
    "    _flush_entity(entities, current_entity, current_tokens)\n",
    "\n",
    "    # Keep only the first detected service and drop repeated locations\n",
    "    entities[\"SERVICE\"] = entities[\"SERVICE\"][:1]\n",
    "    entities[\"LOCATION\"] = list(dict.fromkeys(entities[\"LOCATION\"]))\n",
    "    return entities\n",
    "\n",
    "\n",
    "def test_ner(text):\n",
    "    \"\"\"Print `text` and its predicted entities, then return the prediction.\"\"\"\n",
    "    print(f\"Input: {text}\")\n",
    "    result = predict(text)\n",
    "    print(\"Output:\", result)\n",
    "    return result"
   ],
   "id": "bb4a2cc37bdd801e",
   "outputs": [],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-18T02:31:21.177339Z",
     "start_time": "2024-10-18T02:31:21.154825Z"
    }
   },
   "cell_type": "code",
   "source": [
    "test_texts = [\n",
    "    \"du lich china\",\n",
    "]\n",
    "\n",
    "for text in test_texts:\n",
    "    test_ner(text)\n",
    "    print()"
   ],
   "id": "9a7274e392b039d1",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: du lich china\n",
      "Output: {'SERVICE': ['tour'], 'LOCATION': ['china']}\n",
      "\n"
     ]
    }
   ],
   "execution_count": 17
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}