haoyangli727 commited on
Commit
fc15c14
·
verified ·
1 Parent(s): a7af19d

Added evaluation pipeline for Bert model

Browse files
Files changed (1) hide show
  1. BertScript.ipynb +182 -0
BertScript.ipynb ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {
21
+ "id": "S6jonMPunTP6",
22
+ "colab": {
23
+ "base_uri": "https://localhost:8080/"
24
+ },
25
+ "outputId": "61f37b0f-fb6f-40e4-91bd-6c07d0583ff5"
26
+ },
27
+ "outputs": [
28
+ {
29
+ "output_type": "stream",
30
+ "name": "stdout",
31
+ "text": [
32
+ "Found existing installation: gcsfs 2024.10.0\n",
33
+ "Uninstalling gcsfs-2024.10.0:\n",
34
+ " Successfully uninstalled gcsfs-2024.10.0\n",
35
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.46.3)\n",
36
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.2.0)\n",
37
+ "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.5.2)\n",
38
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n",
39
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.26.5)\n",
40
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n",
41
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.2)\n",
42
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n",
43
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n",
44
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n",
45
+ "Requirement already satisfied: tokenizers<0.21,>=0.20 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.20.3)\n",
46
+ "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n",
47
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.6)\n",
48
+ "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n",
49
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n",
50
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n",
51
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n",
52
+ "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n",
53
+ "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n",
54
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.10)\n",
55
+ "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.13.1)\n",
56
+ "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.4.2)\n",
57
+ "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.5.0)\n",
58
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.4)\n",
59
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
60
+ "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
61
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n",
62
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n",
63
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n",
64
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.1)\n",
65
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.18.3)\n",
66
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n",
67
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4.0)\n",
68
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n",
69
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.2.3)\n",
70
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n",
71
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
72
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
73
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
74
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n"
75
+ ]
76
+ }
77
+ ],
78
+ "source": [
79
+ "!pip uninstall -y gcsfs\n",
80
+ "!pip install transformers datasets scikit-learn"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "source": [
86
+ "import pandas as pd\n",
87
+ "import torch\n",
88
+ "from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification\n",
89
+ "from sklearn.metrics import accuracy_score, classification_report"
90
+ ],
91
+ "metadata": {
92
+ "id": "X7XPKTuusra_",
93
+ "colab": {
94
+ "base_uri": "https://localhost:8080/"
95
+ },
96
+ "outputId": "ec851a62-853a-4b8b-885a-9384a56b802d"
97
+ },
98
+ "execution_count": null,
99
+ "outputs": [
100
+ {
101
+ "output_type": "stream",
102
+ "name": "stdout",
103
+ "text": [
104
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
105
+ ]
106
+ }
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "source": [
112
+ "# please manually adjust the data and model path for your customized testing\n",
113
+ "csv_file_path = '/content/drive/Shared drives/5190_NLP_Project/test_data_random_subset.csv'\n",
114
+ "model_path = '/content/drive/Shared drives/5190_NLP_Project/Bert_trained_model'\n",
115
+ "\n",
116
+ "data = pd.read_csv(csv_file_path)\n",
117
+ "\n",
118
+ "titles = data['title'].tolist()\n",
119
+ "labels = data['labels'].tolist()\n",
120
+ "\n",
121
+ "labels = [1 if label == 0 else 0 for label in labels]\n",
122
+ "\n",
123
+ "tokenizer = BertTokenizer.from_pretrained(model_path)\n",
124
+ "model = BertForSequenceClassification.from_pretrained(model_path)\n",
125
+ "model.eval()\n",
126
+ "\n",
127
+ "encodings = tokenizer(\n",
128
+ " titles,\n",
129
+ " padding=True,\n",
130
+ " truncation=True,\n",
131
+ " max_length=128,\n",
132
+ " return_tensors='pt'\n",
133
+ ")\n",
134
+ "\n",
135
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
136
+ "model.to(device)\n",
137
+ "for key in encodings:\n",
138
+ " encodings[key] = encodings[key].to(device)\n",
139
+ "\n",
140
+ "with torch.no_grad():\n",
141
+ " outputs = model(**encodings)\n",
142
+ " logits = outputs.logits\n",
143
+ "\n",
144
+ "predictions = torch.argmax(logits, dim=-1).cpu().numpy()\n",
145
+ "\n",
146
+ "accuracy = accuracy_score(labels, predictions)\n",
147
+ "report = classification_report(labels, predictions)\n",
148
+ "\n",
149
+ "print(f\"Accuracy: {accuracy:.4f}\")\n",
150
+ "print(\"\\nClassification Report:\\n\", report)"
151
+ ],
152
+ "metadata": {
153
+ "id": "bVBlNhBopf-l",
154
+ "colab": {
155
+ "base_uri": "https://localhost:8080/"
156
+ },
157
+ "outputId": "f648822c-a928-4c78-9319-80dd8f26046f"
158
+ },
159
+ "execution_count": null,
160
+ "outputs": [
161
+ {
162
+ "output_type": "stream",
163
+ "name": "stdout",
164
+ "text": [
165
+ "Accuracy: 0.7000\n",
166
+ "\n",
167
+ "Classification Report:\n",
168
+ " precision recall f1-score support\n",
169
+ "\n",
170
+ " 0 0.83 0.50 0.62 10\n",
171
+ " 1 0.64 0.90 0.75 10\n",
172
+ "\n",
173
+ " accuracy 0.70 20\n",
174
+ " macro avg 0.74 0.70 0.69 20\n",
175
+ "weighted avg 0.74 0.70 0.69 20\n",
176
+ "\n"
177
+ ]
178
+ }
179
+ ]
180
+ }
181
+ ]
182
+ }