jyw3 committed on
Commit
ea7775b
·
verified ·
1 Parent(s): 2dd2c6f

Upload eval 2.ipynb

Browse files
Files changed (1) hide show
  1. eval 2.ipynb +212 -0
eval 2.ipynb ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "initial_id",
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "ExecuteTime": {
9
+ "end_time": "2024-12-16T01:56:57.350322Z",
10
+ "start_time": "2024-12-16T01:56:56.339076Z"
11
+ }
12
+ },
13
+ "source": [
14
+ "import pandas as pd\n",
15
+ "from datasets import Dataset\n",
16
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
17
+ "from torch.utils.data import DataLoader\n",
18
+ "import torch\n",
19
+ "import evaluate\n",
20
+ "from tqdm import tqdm\n",
21
+ "from datasets import load_dataset\n",
22
+ "\n",
23
+ "# 1. Load the model and tokenizer\n",
24
+ "tokenizer = AutoTokenizer.from_pretrained(\"CIS5190ml/bert4\")\n",
25
+ "model = AutoModelForSequenceClassification.from_pretrained(\"CIS5190ml/bert4\")\n",
26
+ "\n",
27
+ "# 2. Load the dataset\n",
28
+ "import pandas as pd \n",
29
+ "\n",
30
+ "ds = load_dataset(\"CIS5190ml/NewData\")\n"
31
+ ],
32
+ "outputs": [],
33
+ "execution_count": 44
34
+ },
35
+ {
36
+ "metadata": {
37
+ "ExecuteTime": {
38
+ "end_time": "2024-12-16T01:56:22.105429Z",
39
+ "start_time": "2024-12-16T01:56:22.089923Z"
40
+ }
41
+ },
42
+ "cell_type": "code",
43
+ "source": [
44
+ "#choose test dataset\n",
45
+ "ds = ds[\"test\"]"
46
+ ],
47
+ "id": "fd95d0347ad1665a",
48
+ "outputs": [],
49
+ "execution_count": 41
50
+ },
51
+ {
52
+ "metadata": {
53
+ "ExecuteTime": {
54
+ "end_time": "2024-12-16T01:56:24.245992Z",
55
+ "start_time": "2024-12-16T01:56:23.609377Z"
56
+ }
57
+ },
58
+ "cell_type": "code",
59
+ "source": [
60
+ "# Preprocessing function\n",
61
+ "def preprocess_function(examples):\n",
62
+ " return tokenizer(examples[\"title\"], truncation=True, padding=\"max_length\")\n",
63
+ "\n",
64
+ "encoded_ds = ds.map(preprocess_function, batched=True)\n",
65
+ "\n",
66
+ "# Keep only the necessary columns (input_ids, attention_mask, labels)\n",
67
+ "desired_cols = [\"input_ids\", \"attention_mask\", \"labels\"]\n",
68
+ "encoded_ds = encoded_ds.remove_columns([col for col in encoded_ds.column_names if col not in desired_cols])\n",
69
+ "encoded_ds.set_format(\"torch\")\n",
70
+ "\n",
71
+ "# Create DataLoader\n",
72
+ "test_loader = DataLoader(encoded_ds, batch_size=8)\n",
73
+ "\n",
74
+ "# Load accuracy metric\n",
75
+ "accuracy = evaluate.load(\"accuracy\")\n",
76
+ "\n",
77
+ "# Set device\n",
78
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
79
+ "model.to(device)\n"
80
+ ],
81
+ "id": "dfefbe70a4ff8696",
82
+ "outputs": [
83
+ {
84
+ "data": {
85
+ "text/plain": [
86
+ "BertForSequenceClassification(\n",
87
+ " (bert): BertModel(\n",
88
+ " (embeddings): BertEmbeddings(\n",
89
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
90
+ " (position_embeddings): Embedding(512, 768)\n",
91
+ " (token_type_embeddings): Embedding(2, 768)\n",
92
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
93
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
94
+ " )\n",
95
+ " (encoder): BertEncoder(\n",
96
+ " (layer): ModuleList(\n",
97
+ " (0-11): 12 x BertLayer(\n",
98
+ " (attention): BertAttention(\n",
99
+ " (self): BertSdpaSelfAttention(\n",
100
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
101
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
102
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
103
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
104
+ " )\n",
105
+ " (output): BertSelfOutput(\n",
106
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
107
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
108
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
109
+ " )\n",
110
+ " )\n",
111
+ " (intermediate): BertIntermediate(\n",
112
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
113
+ " (intermediate_act_fn): GELUActivation()\n",
114
+ " )\n",
115
+ " (output): BertOutput(\n",
116
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
117
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
118
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
119
+ " )\n",
120
+ " )\n",
121
+ " )\n",
122
+ " )\n",
123
+ " (pooler): BertPooler(\n",
124
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
125
+ " (activation): Tanh()\n",
126
+ " )\n",
127
+ " )\n",
128
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
129
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
130
+ ")"
131
+ ]
132
+ },
133
+ "execution_count": 42,
134
+ "metadata": {},
135
+ "output_type": "execute_result"
136
+ }
137
+ ],
138
+ "execution_count": 42
139
+ },
140
+ {
141
+ "metadata": {
142
+ "ExecuteTime": {
143
+ "end_time": "2024-12-16T01:56:35.444373Z",
144
+ "start_time": "2024-12-16T01:56:26.083442Z"
145
+ }
146
+ },
147
+ "cell_type": "code",
148
+ "source": [
149
+ "# Evaluate\n",
150
+ "model.eval()\n",
151
+ "for batch in tqdm(test_loader, desc=\"Evaluating\"):\n",
152
+ " input_ids = batch[\"input_ids\"].to(device)\n",
153
+ " attention_mask = batch[\"attention_mask\"].to(device)\n",
154
+ " labels = batch[\"labels\"].to(device)\n",
155
+ "\n",
156
+ " with torch.no_grad():\n",
157
+ " outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
158
+ " preds = torch.argmax(outputs.logits, dim=-1)\n",
159
+ " accuracy.add_batch(predictions=preds, references=labels)\n",
160
+ "\n",
161
+ "final_accuracy = accuracy.compute()\n",
162
+ "print(\"Accuracy:\", final_accuracy[\"accuracy\"])"
163
+ ],
164
+ "id": "c6e4fd03bd73664f",
165
+ "outputs": [
166
+ {
167
+ "name": "stderr",
168
+ "output_type": "stream",
169
+ "text": [
170
+ "Evaluating: 100%|██████████| 95/95 [00:09<00:00, 10.21it/s]"
171
+ ]
172
+ },
173
+ {
174
+ "name": "stdout",
175
+ "output_type": "stream",
176
+ "text": [
177
+ "Accuracy: 0.9182058047493403\n"
178
+ ]
179
+ },
180
+ {
181
+ "name": "stderr",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "\n"
185
+ ]
186
+ }
187
+ ],
188
+ "execution_count": 43
189
+ }
190
+ ],
191
+ "metadata": {
192
+ "kernelspec": {
193
+ "display_name": "Python 3",
194
+ "language": "python",
195
+ "name": "python3"
196
+ },
197
+ "language_info": {
198
+ "codemirror_mode": {
199
+ "name": "ipython",
200
+ "version": 2
201
+ },
202
+ "file_extension": ".py",
203
+ "mimetype": "text/x-python",
204
+ "name": "python",
205
+ "nbconvert_exporter": "python",
206
+ "pygments_lexer": "ipython2",
207
+ "version": "2.7.6"
208
+ }
209
+ },
210
+ "nbformat": 4,
211
+ "nbformat_minor": 5
212
+ }