{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Evaluate `CIS5190ml/bert4` on the `CIS5190ml/NewData` test split\n",
    "\n",
    "Loads the BERT sequence-classification checkpoint and its tokenizer from the Hugging Face Hub, tokenizes the `title` column of the test split, and reports accuracy.\n",
    "\n",
    "The last recorded run of this notebook reported **accuracy 0.9182** on the test split (95 batches at batch size 8). Restart the kernel and run all cells to reproduce it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to change lives here.\n",
    "MODEL_ID = \"CIS5190ml/bert4\"      # Hugging Face Hub id of the classifier and its tokenizer\n",
    "DATASET_ID = \"CIS5190ml/NewData\"  # Hugging Face Hub id of the evaluation dataset\n",
    "SPLIT = \"test\"                    # split to evaluate on\n",
    "TEXT_COLUMN = \"title\"             # column fed to the tokenizer\n",
    "LABEL_COLUMN = \"labels\"           # column holding the reference labels\n",
    "BATCH_SIZE = 8\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load",
   "metadata": {},
   "source": [
    "## Load model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "load-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
    "# `nn.Module.to` returns the module itself, so this both loads and places the model.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd95d0347ad1665a",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_ds = load_dataset(DATASET_ID)\n",
    "# Bind the split to a new name rather than overwriting `raw_ds`, so re-running this cell is safe.\n",
    "test_ds = raw_ds[SPLIT]\n",
    "test_ds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-tokenize",
   "metadata": {},
   "source": [
    "## Tokenize\n",
    "\n",
    "Each title is truncated/padded to the tokenizer's maximum length; only the model inputs and the label column are kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfefbe70a4ff8696",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of examples, truncating/padding each text to the tokenizer's max length.\"\"\"\n",
    "    return tokenizer(examples[TEXT_COLUMN], truncation=True, padding=\"max_length\")\n",
    "\n",
    "\n",
    "encoded_ds = test_ds.map(preprocess_function, batched=True)\n",
    "\n",
    "# Keep only the model inputs and the reference labels; fail loudly if any of them is absent.\n",
    "desired_cols = [\"input_ids\", \"attention_mask\", LABEL_COLUMN]\n",
    "missing_cols = set(desired_cols) - set(encoded_ds.column_names)\n",
    "assert not missing_cols, f\"encoded dataset is missing columns: {sorted(missing_cols)}\"\n",
    "encoded_ds = encoded_ds.remove_columns([col for col in encoded_ds.column_names if col not in desired_cols])\n",
    "encoded_ds.set_format(\"torch\")\n",
    "\n",
    "# No shuffling: a fixed order keeps the evaluation reproducible.\n",
    "test_loader = DataLoader(encoded_ds, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-evaluate",
   "metadata": {},
   "source": [
    "## Evaluate\n",
    "\n",
    "The metric is created inside this cell so that an interrupted and re-run loop never mixes in batches accumulated by an earlier attempt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6e4fd03bd73664f",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy = evaluate.load(\"accuracy\")\n",
    "\n",
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    for batch in tqdm(test_loader, desc=\"Evaluating\"):\n",
    "        outputs = model(\n",
    "            input_ids=batch[\"input_ids\"].to(device),\n",
    "            attention_mask=batch[\"attention_mask\"].to(device),\n",
    "        )\n",
    "        preds = outputs.logits.argmax(dim=-1)\n",
    "        # Hand the metric CPU tensors; the labels never need to visit `device` at all.\n",
    "        accuracy.add_batch(predictions=preds.cpu(), references=batch[LABEL_COLUMN])\n",
    "\n",
    "final_accuracy = accuracy.compute()\n",
    "print(\"Accuracy:\", final_accuracy[\"accuracy\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}