{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# BERT title classifier: test-set evaluation\n",
        "\n",
        "Loads a trained `BertForSequenceClassification` checkpoint from the project's shared Google Drive, predicts a class for every `title` in a test CSV, and reports accuracy plus a per-class classification report.\n",
        "\n",
        "**Stored run:** accuracy 0.70 on the 20-row random test subset (see the last cell).\n",
        "\n",
        "Written for Google Colab (it mounts Google Drive). Edit the configuration cell to evaluate a different test set or checkpoint."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S6jonMPunTP6"
      },
      "outputs": [],
      "source": [
        "# Pin the versions this evaluation was last run with so the saved checkpoint loads the same way.\n",
        "# %pip (rather than !pip) installs into the running kernel's environment.\n",
        "%pip install -q transformers==4.46.3 scikit-learn==1.5.2"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "import torch\n",
        "from google.colab import drive\n",
        "from sklearn.metrics import accuracy_score, classification_report\n",
        "from transformers import BertForSequenceClassification, BertTokenizer\n",
        "\n",
        "# The test data and the model checkpoint live on a shared Google Drive;\n",
        "# mount it so the paths below resolve on a fresh runtime.\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "id": "X7XPKTuusra_",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ec851a62-853a-4b8b-885a-9384a56b802d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Please adjust the data and model paths for your own testing.\n",
        "CSV_FILE_PATH = '/content/drive/Shared drives/5190_NLP_Project/test_data_random_subset.csv'\n",
        "MODEL_PATH = '/content/drive/Shared drives/5190_NLP_Project/Bert_trained_model'\n",
        "\n",
        "MAX_LENGTH = 128  # tokenizer truncation length, in tokens\n",
        "BATCH_SIZE = 32   # titles scored per forward pass; lower it if the GPU runs out of memory"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load test data\n",
        "\n",
        "The CSV must contain a `title` text column and a binary (0/1) `labels` column."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "test_df = pd.read_csv(CSV_FILE_PATH)\n",
        "\n",
        "assert len(test_df) > 0, 'test CSV is empty'\n",
        "assert {'title', 'labels'} <= set(test_df.columns), \"test CSV needs 'title' and 'labels' columns\"\n",
        "assert test_df['title'].notna().all(), 'test CSV has missing titles'\n",
        "assert test_df['labels'].isin([0, 1]).all(), 'expected binary 0/1 labels'\n",
        "\n",
        "titles = test_df['title'].astype(str).tolist()\n",
        "\n",
        "# Flip 0 <-> 1: scoring assumes the model's class ids are the inverse of the CSV's `labels`.\n",
        "# NOTE(review): confirm this convention against the training code.\n",
        "y_true = [1 if label == 0 else 0 for label in test_df['labels']]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)\n",
        "# eval() switches off dropout for inference.\n",
        "model = BertForSequenceClassification.from_pretrained(MODEL_PATH).to(device).eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Predict\n",
        "\n",
        "Titles are scored in mini-batches so larger test sets do not have to fit on the GPU as one tensor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def predict_labels(texts: list[str], tokenizer, model, device, batch_size: int, max_length: int) -> list[int]:\n",
        "    \"\"\"Return the argmax class id for each text, scoring `batch_size` texts per forward pass.\n",
        "\n",
        "    Encoding the whole test set as one padded tensor can exhaust GPU memory on\n",
        "    larger CSVs; mini-batches keep peak memory bounded by `batch_size`.\n",
        "    \"\"\"\n",
        "    predictions = []\n",
        "    with torch.no_grad():  # inference only: skip autograd bookkeeping\n",
        "        for start in range(0, len(texts), batch_size):\n",
        "            batch = tokenizer(\n",
        "                texts[start:start + batch_size],\n",
        "                padding=True,\n",
        "                truncation=True,\n",
        "                max_length=max_length,\n",
        "                return_tensors='pt',\n",
        "            ).to(device)\n",
        "            logits = model(**batch).logits\n",
        "            predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())\n",
        "    return predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "y_pred = predict_labels(titles, tokenizer, model, device, batch_size=BATCH_SIZE, max_length=MAX_LENGTH)\n",
        "assert len(y_pred) == len(y_true)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Results\n",
        "\n",
        "Class ids in the report are the model's output ids (i.e. after the label flip in *Load test data*)."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "accuracy = accuracy_score(y_true, y_pred)\n",
        "report = classification_report(y_true, y_pred)\n",
        "\n",
        "print(f\"Accuracy: {accuracy:.4f}\")\n",
        "print(\"\\nClassification Report:\\n\", report)"
      ],
      "metadata": {
        "id": "bVBlNhBopf-l",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f648822c-a928-4c78-9319-80dd8f26046f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Accuracy: 0.7000\n",
            "\n",
            "Classification Report:\n",
            " precision recall f1-score support\n",
            "\n",
            " 0 0.83 0.50 0.62 10\n",
            " 1 0.64 0.90 0.75 10\n",
            "\n",
            " accuracy 0.70 20\n",
            " macro avg 0.74 0.70 0.69 20\n",
            "weighted avg 0.74 0.70 0.69 20\n",
            "\n"
          ]
        }
      ]
    }
  ]
}