TUEN-YUE committed
Commit e79e9a8 · verified · 1 Parent(s): f3e9a7f

Upload train&test.ipynb

Files changed (1)
  1. train&test.ipynb +1408 -0
train&test.ipynb ADDED
@@ -0,0 +1,1408 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": [
7
+ "# Installing dependencies\n",
8
+ "## Please make a copy of this notebook."
9
+ ],
10
+ "id": "13156d7ed48b282"
11
+ },
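+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "# Sketch of an install cell (not in the original notebook): the dependency list below is\n",
+ "# inferred from the imports used later in this notebook; versions are intentionally left unpinned.\n",
+ "# The Word2Vec step also expects GoogleNews-vectors-negative300.bin in the working directory.\n",
+ "# !pip install datasets transformers torch gensim scikit-learn pandas numpy tqdm huggingface_hub\n"
+ ],
+ "id": "install-deps-sketch",
+ "outputs": [],
+ "execution_count": null
+ },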
12
+ {
13
+ "metadata": {},
14
+ "cell_type": "markdown",
15
+ "source": [
16
+ "# Hugging Face login\n",
17
+ "You will need your personal access token."
18
+ ],
19
+ "id": "432a756039e6399"
20
+ },
21
+ {
22
+ "metadata": {},
23
+ "cell_type": "code",
24
+ "source": "# !huggingface-cli login",
25
+ "id": "2e73da09a7c6171e",
26
+ "outputs": [],
27
+ "execution_count": null
28
+ },
29
+ {
30
+ "metadata": {},
31
+ "cell_type": "markdown",
32
+ "source": "# Part 1: Load Data",
33
+ "id": "c731d9c1ebb477dc"
34
+ },
35
+ {
36
+ "metadata": {},
37
+ "cell_type": "markdown",
38
+ "source": "## Downloading the train and test dataset",
39
+ "id": "14070f20b547688f"
40
+ },
41
+ {
42
+ "metadata": {},
43
+ "cell_type": "markdown",
44
+ "source": "",
45
+ "id": "b8920847b7cc378d"
46
+ },
47
+ {
48
+ "metadata": {},
49
+ "cell_type": "code",
50
+ "source": [
51
+ "from datasets import load_dataset\n",
52
+ "\n",
53
+ "dataset_train = load_dataset(\"CISProject/FOX_NBC\", split=\"train\")\n",
54
+ "dataset_test = load_dataset(\"CISProject/FOX_NBC\", split=\"test\")\n",
55
+ "# dataset_test = load_dataset(\"CISProject/FOX_NBC\", split=\"test_data_random_subset\")\n"
56
+ ],
57
+ "id": "877c90c978d62b7d",
58
+ "outputs": [],
59
+ "execution_count": 32
60
+ },
61
+ {
62
+ "metadata": {
63
+ "ExecuteTime": {
64
+ "end_time": "2024-12-16T08:10:32.645990Z",
65
+ "start_time": "2024-12-16T08:10:32.636717Z"
66
+ }
67
+ },
68
+ "cell_type": "code",
69
+ "source": [
70
+ "import numpy as np\n",
71
+ "import torch\n",
72
+ "import re\n",
73
+ "from transformers import BertTokenizer\n",
74
+ "from transformers import RobertaTokenizer\n",
75
+ "from sklearn.feature_extraction.text import CountVectorizer\n",
76
+ "from gensim.models import KeyedVectors\n",
77
+ "from sklearn.feature_extraction.text import TfidfVectorizer\n",
78
+ "\n",
79
+ "def preprocess_data(data,\n",
80
+ " mode=\"train\",\n",
81
+ " vectorizer=None,\n",
82
+ " w2v_model=None,\n",
83
+ " max_features=4096,\n",
84
+ " max_seq_length=128,\n",
85
+ " num_proc=4):\n",
86
+ " if w2v_model is None:\n",
87
+ " raise ValueError(\"w2v_model must be provided for Word2Vec embeddings.\")\n",
88
+ "\n",
89
+ " # tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
90
+ " tokenizer = RobertaTokenizer.from_pretrained(\"roberta-base\")\n",
91
+ " # 1. Clean text once\n",
92
+ " def clean_text(examples):\n",
93
+ " import re\n",
94
+ " cleaned = []\n",
95
+ " for text in examples[\"title\"]:\n",
96
+ " text = text.lower()\n",
97
+ " text = re.sub(r'[^\\w\\s]', '', text)\n",
98
+ " text = text.strip()\n",
99
+ " cleaned.append(text)\n",
100
+ " return {\"clean_title\": cleaned}\n",
101
+ "\n",
102
+ " data = data.map(clean_text, batched=True, num_proc=num_proc)\n",
103
+ "\n",
104
+ " # 2. Fit CountVectorizer on training data if needed\n",
105
+ " if mode == \"train\" and vectorizer is None:\n",
106
+ " # Collect all cleaned titles to fit\n",
107
+ " all_titles = data[\"clean_title\"]\n",
108
+ " #vectorizer = CountVectorizer(max_features=max_features, ngram_range=(1,2))\n",
109
+ " vectorizer = TfidfVectorizer(max_features=max_features)\n",
110
+ " vectorizer.fit(all_titles)\n",
111
+ " print(\"vectorizer fitted on training data.\")\n",
112
+ "\n",
113
+ " # 3. Transform titles with vectorizer once\n",
114
+ " def vectorize_batch(examples):\n",
115
+ " import numpy as np\n",
116
+ " freq = vectorizer.transform(examples[\"clean_title\"]).toarray().astype(np.float32)\n",
117
+ " return {\"freq_inputs\": freq}\n",
118
+ "\n",
119
+ " data = data.map(vectorize_batch, batched=True, num_proc=num_proc)\n",
120
+ "\n",
121
+ " # 4. Tokenize with BERT once\n",
122
+ " def tokenize_batch(examples):\n",
123
+ " tokenized = tokenizer(\n",
124
+ " examples[\"title\"],\n",
125
+ " padding=\"max_length\",\n",
126
+ " truncation=True,\n",
127
+ " max_length=max_seq_length\n",
128
+ " )\n",
129
+ " return {\n",
130
+ " \"input_ids\": tokenized[\"input_ids\"],\n",
131
+ " \"attention_mask\": tokenized[\"attention_mask\"]\n",
132
+ " }\n",
133
+ "\n",
134
+ " data = data.map(tokenize_batch, batched=True, num_proc=num_proc)\n",
135
+ "\n",
136
+ " # 5. Convert titles into tokens for W2V\n",
137
+ " def split_tokens(examples):\n",
138
+ " tokens_list = [t.split() for t in examples[\"clean_title\"]]\n",
139
+ " return {\"tokens\": tokens_list}\n",
140
+ "\n",
141
+ " data = data.map(split_tokens, batched=True, num_proc=num_proc)\n",
142
+ "\n",
143
+ " # Build an embedding dictionary for all unique tokens (do this once before embedding map)\n",
144
+ " unique_tokens = set()\n",
145
+ " for tokens in data[\"tokens\"]:\n",
146
+ " unique_tokens.update(tokens)\n",
147
+ "\n",
148
+ " embedding_dim = w2v_model.vector_size\n",
149
+ " embedding_dict = {}\n",
150
+ " for tk in unique_tokens:\n",
151
+ " if tk in w2v_model:\n",
152
+ " embedding_dict[tk] = w2v_model[tk].astype(np.float32)\n",
153
+ " else:\n",
154
+ " embedding_dict[tk] = np.zeros((embedding_dim,), dtype=np.float32)\n",
155
+ "\n",
156
+ " def w2v_embedding_batch(examples):\n",
157
+ " import numpy as np\n",
158
+ " batch_w2v = []\n",
159
+ " for tokens in examples[\"tokens\"]:\n",
160
+ " vectors = [embedding_dict[tk] for tk in tokens[:max_seq_length]]\n",
161
+ " if len(vectors) < max_seq_length:\n",
162
+ " vectors += [np.zeros((embedding_dim,), dtype=np.float32)] * (max_seq_length - len(vectors))\n",
163
+ " batch_w2v.append(vectors)\n",
164
+ " return {\"pos_inputs\": batch_w2v}\n",
165
+ "\n",
166
+ "\n",
167
+ " data = data.map(w2v_embedding_batch, batched=True, batch_size=32, num_proc=num_proc)\n",
168
+ "\n",
169
+ " # 7. Create labels\n",
170
+ " def make_labels(examples):\n",
171
+ " labels = [1.0 if agency == \"fox\" else 0.0 for agency in examples[\"news\"]]\n",
172
+ " return {\"labels\": labels}\n",
173
+ "\n",
174
+ " data = data.map(make_labels, batched=True, num_proc=num_proc)\n",
175
+ "\n",
176
+ " # Convert freq_inputs and pos_inputs to torch tensors in a final map step\n",
177
+ " def to_tensors(examples):\n",
178
+ " import torch\n",
179
+ "\n",
180
+ " freq_inputs = torch.tensor(examples[\"freq_inputs\"], dtype=torch.float32)\n",
181
+ " input_ids = torch.tensor(examples[\"input_ids\"])\n",
182
+ " attention_mask = torch.tensor(examples[\"attention_mask\"])\n",
183
+ " pos_inputs = torch.tensor(examples[\"pos_inputs\"], dtype=torch.float32)\n",
184
+ " labels = torch.tensor(examples[\"labels\"],dtype=torch.long)\n",
185
+ "\n",
186
+ " # seq_inputs shape: (batch_size, 2, seq_len)\n",
187
+ " seq_inputs = torch.stack([input_ids, attention_mask], dim=1)\n",
188
+ "\n",
189
+ " return {\n",
190
+ " \"freq_inputs\": freq_inputs,\n",
191
+ " \"seq_inputs\": seq_inputs,\n",
192
+ " \"pos_inputs\": pos_inputs,\n",
193
+ " \"labels\": labels\n",
194
+ " }\n",
195
+ "\n",
196
+ " # Apply final conversion to tensor\n",
197
+ " processed_data = data.map(to_tensors, batched=True, num_proc=num_proc)\n",
198
+ "\n",
199
+ " return processed_data, vectorizer\n"
200
+ ],
201
+ "id": "dc2ba675ce880d6d",
202
+ "outputs": [],
203
+ "execution_count": 33
204
+ },
205
+ {
206
+ "metadata": {
207
+ "ExecuteTime": {
208
+ "end_time": "2024-12-16T08:11:25.586667Z",
209
+ "start_time": "2024-12-16T08:10:32.651505Z"
210
+ }
211
+ },
212
+ "cell_type": "code",
213
+ "source": [
214
+ "from gensim.models import KeyedVectors\n",
215
+ "w2v_model = KeyedVectors.load_word2vec_format(\"./GoogleNews-vectors-negative300.bin\", binary=True)\n",
216
+ "\n",
217
+ "dataset_train,vectorizer = preprocess_data(\n",
218
+ " data=dataset_train,\n",
219
+ " mode=\"train\",\n",
220
+ " w2v_model=w2v_model,\n",
221
+ " max_features=8192,\n",
222
+ " max_seq_length=128\n",
223
+ ")\n",
224
+ "\n",
225
+ "dataset_test, _ = preprocess_data(\n",
226
+ " data=dataset_test,\n",
227
+ " mode=\"test\",\n",
228
+ " vectorizer=vectorizer,\n",
229
+ " w2v_model=w2v_model,\n",
230
+ " max_features=8192,\n",
231
+ " max_seq_length=128\n",
232
+ ")"
233
+ ],
234
+ "id": "158b99950fb22d1",
235
+ "outputs": [
236
+ {
237
+ "name": "stdout",
238
+ "output_type": "stream",
239
+ "text": [
240
+ "vectorizer fitted on training data.\n"
241
+ ]
242
+ },
243
+ {
244
+ "data": {
245
+ "text/plain": [
246
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
247
+ ],
248
+ "application/vnd.jupyter.widget-view+json": {
249
+ "version_major": 2,
250
+ "version_minor": 0,
251
+ "model_id": "bfd4d90535a94d6089d9713c39003468"
252
+ }
253
+ },
254
+ "metadata": {},
255
+ "output_type": "display_data"
256
+ },
257
+ {
258
+ "data": {
259
+ "text/plain": [
260
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
261
+ ],
262
+ "application/vnd.jupyter.widget-view+json": {
263
+ "version_major": 2,
264
+ "version_minor": 0,
265
+ "model_id": "88c8f4ba83ba48c1b047549b7086be66"
266
+ }
267
+ },
268
+ "metadata": {},
269
+ "output_type": "display_data"
270
+ },
271
+ {
272
+ "data": {
273
+ "text/plain": [
274
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
275
+ ],
276
+ "application/vnd.jupyter.widget-view+json": {
277
+ "version_major": 2,
278
+ "version_minor": 0,
279
+ "model_id": "7b73903c642241bca21fabc4c37b9919"
280
+ }
281
+ },
282
+ "metadata": {},
283
+ "output_type": "display_data"
284
+ },
285
+ {
286
+ "data": {
287
+ "text/plain": [
288
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
289
+ ],
290
+ "application/vnd.jupyter.widget-view+json": {
291
+ "version_major": 2,
292
+ "version_minor": 0,
293
+ "model_id": "fa333143b6344de8bbab7ddb27f618cf"
294
+ }
295
+ },
296
+ "metadata": {},
297
+ "output_type": "display_data"
298
+ },
299
+ {
300
+ "data": {
301
+ "text/plain": [
302
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
303
+ ],
304
+ "application/vnd.jupyter.widget-view+json": {
305
+ "version_major": 2,
306
+ "version_minor": 0,
307
+ "model_id": "43ea10bbe58b43e4a630c05296f4a92d"
308
+ }
309
+ },
310
+ "metadata": {},
311
+ "output_type": "display_data"
312
+ },
313
+ {
314
+ "data": {
315
+ "text/plain": [
316
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
317
+ ],
318
+ "application/vnd.jupyter.widget-view+json": {
319
+ "version_major": 2,
320
+ "version_minor": 0,
321
+ "model_id": "8917e01c927e43c2adcf3fd9b9d8a602"
322
+ }
323
+ },
324
+ "metadata": {},
325
+ "output_type": "display_data"
326
+ },
327
+ {
328
+ "data": {
329
+ "text/plain": [
330
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
331
+ ],
332
+ "application/vnd.jupyter.widget-view+json": {
333
+ "version_major": 2,
334
+ "version_minor": 0,
335
+ "model_id": "b7ebd99939cf4bc48ba23029bdd1d651"
336
+ }
337
+ },
338
+ "metadata": {},
339
+ "output_type": "display_data"
340
+ },
341
+ {
342
+ "data": {
343
+ "text/plain": [
344
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
345
+ ],
346
+ "application/vnd.jupyter.widget-view+json": {
347
+ "version_major": 2,
348
+ "version_minor": 0,
349
+ "model_id": "99f71c080bbe41018345656f102529f7"
350
+ }
351
+ },
352
+ "metadata": {},
353
+ "output_type": "display_data"
354
+ },
355
+ {
356
+ "data": {
357
+ "text/plain": [
358
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
359
+ ],
360
+ "application/vnd.jupyter.widget-view+json": {
361
+ "version_major": 2,
362
+ "version_minor": 0,
363
+ "model_id": "27db7129c3b848ada119588691e6a602"
364
+ }
365
+ },
366
+ "metadata": {},
367
+ "output_type": "display_data"
368
+ }
369
+ ],
370
+ "execution_count": 34
371
+ },
372
+ {
373
+ "metadata": {
374
+ "ExecuteTime": {
375
+ "end_time": "2024-12-16T08:11:25.595796Z",
376
+ "start_time": "2024-12-16T08:11:25.593319Z"
377
+ }
378
+ },
379
+ "cell_type": "code",
380
+ "source": [
381
+ "print(dataset_train)\n",
382
+ "print(dataset_test)"
383
+ ],
384
+ "id": "edd80d33175c96a0",
385
+ "outputs": [
386
+ {
387
+ "name": "stdout",
388
+ "output_type": "stream",
389
+ "text": [
390
+ "Dataset({\n",
391
+ " features: ['title', 'news', 'index', 'url', 'clean_title', 'freq_inputs', 'input_ids', 'attention_mask', 'tokens', 'pos_inputs', 'labels', 'seq_inputs'],\n",
392
+ " num_rows: 3044\n",
393
+ "})\n",
394
+ "Dataset({\n",
395
+ " features: ['title', 'news', 'index', 'url', 'clean_title', 'freq_inputs', 'input_ids', 'attention_mask', 'tokens', 'pos_inputs', 'labels', 'seq_inputs'],\n",
396
+ " num_rows: 761\n",
397
+ "})\n"
398
+ ]
399
+ }
400
+ ],
401
+ "execution_count": 35
402
+ },
403
+ {
404
+ "metadata": {},
405
+ "cell_type": "markdown",
406
+ "source": "# Part 2: Model",
407
+ "id": "c9a49fc1fbca29d7"
408
+ },
409
+ {
410
+ "metadata": {},
411
+ "cell_type": "markdown",
412
+ "source": "## Defining the Custom Model",
413
+ "id": "aebe5e51f0e611cc"
414
+ },
415
+ {
416
+ "metadata": {},
417
+ "cell_type": "markdown",
418
+ "source": "",
419
+ "id": "f0eae08a025b6ed9"
420
+ },
421
+ {
422
+ "metadata": {
423
+ "ExecuteTime": {
424
+ "end_time": "2024-12-16T08:11:25.680033Z",
425
+ "start_time": "2024-12-16T08:11:25.672667Z"
426
+ }
427
+ },
428
+ "cell_type": "code",
429
+ "source": [
430
+ "# TODO: import all packages necessary for your custom model\n",
431
+ "import pandas as pd\n",
432
+ "import os\n",
433
+ "from torch.utils.data import DataLoader\n",
434
+ "from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel\n",
435
+ "import torch\n",
436
+ "import torch.nn as nn\n",
437
+ "from transformers import RobertaModel, RobertaConfig,RobertaForSequenceClassification, BertModel\n",
438
+ "from model.network import Classifier\n",
439
+ "from model.frequential import FreqNetwork\n",
440
+ "from model.sequential import SeqNetwork\n",
441
+ "from model.positional import PosNetwork\n",
442
+ "\n",
443
+ "class CustomConfig(PretrainedConfig):\n",
444
+ " model_type = \"headlineclassifier\"\n",
445
+ "\n",
446
+ " def __init__(\n",
447
+ " self,\n",
448
+ " base_exp_dir=\"./exp/fox_nbc/\",\n",
449
+ " # dataset={\"data_dir\": \"./data/CASE_NAME/data.csv\", \"transform\": True},\n",
450
+ " train={\n",
451
+ " \"learning_rate\": 2e-5,\n",
452
+ " \"learning_rate_alpha\": 0.05,\n",
453
+ " \"end_iter\": 10,\n",
454
+ " \"batch_size\": 32,\n",
455
+ " \"warm_up_end\": 2,\n",
456
+ " \"anneal_end\": 5,\n",
457
+ " \"save_freq\": 1,\n",
458
+ " \"val_freq\": 1,\n",
459
+ " },\n",
460
+ " model={\n",
461
+ " \"freq\": {\n",
462
+ " \"tfidf_input_dim\": 8145,\n",
463
+ " \"tfidf_output_dim\": 128,\n",
464
+ " \"tfidf_hidden_dim\": 512,\n",
465
+ " \"n_layers\": 2,\n",
466
+ " \"skip_in\": [80],\n",
467
+ " \"weight_norm\": True,\n",
468
+ " },\n",
469
+ " \"pos\": {\n",
470
+ " \"input_dim\": 300,\n",
471
+ " \"output_dim\": 128,\n",
472
+ " \"hidden_dim\": 256,\n",
473
+ " \"n_layers\": 2,\n",
474
+ " \"skip_in\": [80],\n",
475
+ " \"weight_norm\": True,\n",
476
+ " },\n",
477
+ " \"cls\": {\n",
478
+ " \"combined_input\": 1024, #1024\n",
479
+ " \"combined_dim\": 128,\n",
480
+ " \"num_classes\": 1,\n",
481
+ " \"n_layers\": 2,\n",
482
+ " \"skip_in\": [80],\n",
483
+ " \"weight_norm\": True,\n",
484
+ " },\n",
485
+ " },\n",
486
+ " **kwargs,\n",
487
+ " ):\n",
488
+ " super().__init__(**kwargs)\n",
489
+ "\n",
490
+ " self.base_exp_dir = base_exp_dir\n",
491
+ " # self.dataset = dataset\n",
492
+ " self.train = train\n",
493
+ " self.model = model\n",
494
+ "\n",
495
+ "# TODO: define all parameters needed for your model, as well as calling the model itself\n",
496
+ "class CustomModel(PreTrainedModel):\n",
497
+ " config_class = CustomConfig\n",
498
+ "\n",
499
+ " def __init__(self, config):\n",
500
+ " super().__init__(config)\n",
501
+ " self.conf = config\n",
502
+ " self.freq = FreqNetwork(**self.conf.model[\"freq\"])\n",
503
+ " self.pos = PosNetwork(**self.conf.model[\"pos\"])\n",
504
+ " self.fc = nn.Linear(self.conf.model[\"cls\"][\"combined_input\"],2)\n",
505
+ " self.seq = RobertaModel.from_pretrained(\"roberta-base\")\n",
506
+ " # self.seq = BertModel.from_pretrained(\"bert-base-uncased\")\n",
507
+ " #for param in self.roberta.parameters():\n",
508
+ " # param.requires_grad = False\n",
509
+ " self.dropout = nn.Dropout(0.1)\n",
510
+ "\n",
511
+ " def forward(self, x):\n",
512
+ " freq_inputs = x[\"freq_inputs\"]\n",
513
+ " seq_inputs = x[\"seq_inputs\"]\n",
514
+ " pos_inputs = x[\"pos_inputs\"]\n",
515
+ " seq_feature = self.seq(\n",
516
+ " input_ids=seq_inputs[:,0,:],\n",
517
+ " attention_mask=seq_inputs[:,1,:]\n",
518
+ " ).pooler_output # last_hidden_state[:, 0, :]\n",
519
+ " freq_feature = self.freq(freq_inputs) # Shape: (batch_size, 128)\n",
520
+ "\n",
521
+ " pos_feature = self.pos(pos_inputs) #Shape: (batch_size, 128)\n",
522
+ " inputs = torch.cat((seq_feature, freq_feature, pos_feature), dim=1) # Shape: (batch_size, 1024) = 768 + 128 + 128\n",
523
+ " # inputs = torch.cat((seq_feature, freq_feature), dim=1) # Shape: (batch_size,256)\n",
524
+ " # inputs = seq_feature\n",
525
+ "\n",
526
+ " x = inputs\n",
527
+ " x = self.dropout(x)\n",
528
+ " outputs = self.fc(x)\n",
529
+ "\n",
530
+ " return outputs\n",
531
+ "\n",
532
+ " def save_model(self, save_path):\n",
533
+ " \"\"\"Save the model locally using the Hugging Face format.\"\"\"\n",
534
+ " self.save_pretrained(save_path)\n",
535
+ "\n",
536
+ " def push_model(self, repo_name):\n",
537
+ " \"\"\"Push the model to the Hugging Face Hub.\"\"\"\n",
538
+ " self.push_to_hub(repo_name)"
539
+ ],
540
+ "id": "21f079d0c52d7d",
541
+ "outputs": [],
542
+ "execution_count": 36
543
+ },
544
+ {
545
+ "metadata": {
546
+ "ExecuteTime": {
547
+ "end_time": "2024-12-16T08:11:25.888720Z",
548
+ "start_time": "2024-12-16T08:11:25.683038Z"
549
+ }
550
+ },
551
+ "cell_type": "code",
552
+ "source": [
553
+ "from huggingface_hub import hf_hub_download\n",
554
+ "\n",
555
+ "AutoConfig.register(\"headlineclassifier\", CustomConfig)\n",
556
+ "AutoModel.register(CustomConfig, CustomModel)\n",
557
+ "config = CustomConfig()\n",
558
+ "model = CustomModel(config)\n",
559
+ "\n",
560
+ "REPO_NAME = \"CISProject/News-Headline-Classifier-Notebook\" # TODO: PROVIDE A STRING TO YOUR REPO ON HUGGINGFACE"
561
+ ],
562
+ "id": "b6ba3f96d3ce21",
563
+ "outputs": [
564
+ {
565
+ "name": "stderr",
566
+ "output_type": "stream",
567
+ "text": [
568
+ "C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
569
+ " WeightNorm.apply(module, name, dim)\n",
570
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
571
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
572
+ ]
573
+ }
574
+ ],
575
+ "execution_count": 37
576
+ },
577
+ {
578
+ "metadata": {
579
+ "ExecuteTime": {
580
+ "end_time": "2024-12-16T08:11:25.904773Z",
581
+ "start_time": "2024-12-16T08:11:25.895279Z"
582
+ }
583
+ },
584
+ "cell_type": "code",
585
+ "source": [
586
+ "import torch\n",
587
+ "from tqdm import tqdm\n",
588
+ "import os\n",
589
+ "\n",
590
+ "\n",
591
+ "class Trainer:\n",
592
+ " def __init__(self, model, train_loader, val_loader, config, device=\"cuda\"):\n",
593
+ " self.model = model.to(device)\n",
594
+ " self.train_loader = train_loader\n",
595
+ " self.val_loader = val_loader\n",
596
+ " self.device = device\n",
597
+ " self.conf = config\n",
598
+ "\n",
599
+ " self.end_iter = self.conf.train[\"end_iter\"]\n",
600
+ " self.save_freq = self.conf.train[\"save_freq\"]\n",
601
+ " self.val_freq = self.conf.train[\"val_freq\"]\n",
602
+ "\n",
603
+ " self.batch_size = self.conf.train['batch_size']\n",
604
+ " self.learning_rate = self.conf.train['learning_rate']\n",
605
+ " self.learning_rate_alpha = self.conf.train['learning_rate_alpha']\n",
606
+ " self.warm_up_end = self.conf.train['warm_up_end']\n",
607
+ " self.anneal_end = self.conf.train['anneal_end']\n",
608
+ "\n",
609
+ " self.optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)\n",
610
+ " #self.criterion = torch.nn.BCEWithLogitsLoss()\n",
611
+ " self.criterion = torch.nn.CrossEntropyLoss()\n",
612
+ " self.save_path = os.path.join(self.conf.base_exp_dir, \"checkpoints\")\n",
613
+ " os.makedirs(self.save_path, exist_ok=True)\n",
614
+ "\n",
615
+ " self.iter_step = 0\n",
616
+ "\n",
617
+ " self.val_loss = None\n",
618
+ "\n",
619
+ " def get_cos_anneal_ratio(self):\n",
620
+ " if self.anneal_end == 0.0:\n",
621
+ " return 1.0\n",
622
+ " else:\n",
623
+ " return np.min([1.0, self.iter_step / self.anneal_end])\n",
624
+ "\n",
625
+ " def update_learning_rate(self):\n",
626
+ " if self.iter_step < self.warm_up_end:\n",
627
+ " learning_factor = self.iter_step / self.warm_up_end\n",
628
+ " else:\n",
629
+ " alpha = self.learning_rate_alpha\n",
630
+ " progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end)\n",
631
+ " learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha\n",
632
+ "\n",
633
+ " for g in self.optimizer.param_groups:\n",
634
+ " g['lr'] = self.learning_rate * learning_factor\n",
635
+ "\n",
636
+ " def train(self):\n",
637
+ " for epoch in range(self.end_iter):\n",
638
+ " self.update_learning_rate()\n",
639
+ " self.model.train()\n",
640
+ " epoch_loss = 0.0\n",
641
+ " correct = 0\n",
642
+ " total = 0\n",
643
+ "\n",
644
+ " for batch_inputs, labels in tqdm(self.train_loader, desc=f\"Epoch {epoch + 1}/{self.end_iter}\"):\n",
645
+ " # Extract features\n",
646
+ "\n",
647
+ " freq_inputs = batch_inputs[\"freq_inputs\"].to(self.device)\n",
648
+ " seq_inputs = batch_inputs[\"seq_inputs\"].to(self.device)\n",
649
+ " pos_inputs = batch_inputs[\"pos_inputs\"].to(self.device)\n",
650
+ " # y_train = labels.to(self.device)[:,None]\n",
651
+ " y_train = labels.to(self.device)\n",
652
+ "\n",
653
+ " # Forward pass\n",
654
+ " preds = self.model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
655
+ " loss = self.criterion(preds, y_train)\n",
656
+ "\n",
657
+ " # preds = (torch.sigmoid(preds) > 0.5).int()\n",
658
+ " # Backward pass\n",
659
+ " self.optimizer.zero_grad()\n",
660
+ " loss.backward()\n",
661
+ " self.optimizer.step()\n",
662
+ " _, preds = torch.max(preds, dim=1)\n",
663
+ " # Metrics\n",
664
+ " epoch_loss += loss.item()\n",
665
+ " total += y_train.size(0)\n",
666
+ " # print(preds.shape)\n",
667
+ " correct += (preds == y_train).sum().item()\n",
668
+ "\n",
669
+ " # Log epoch metrics\n",
670
+ " print(f\"Train Loss: {epoch_loss / len(self.train_loader):.4f}\")\n",
671
+ " print(f\"Train Accuracy: {correct / total:.4f}\")\n",
672
+ "\n",
673
+ " # Validation and Save Checkpoints\n",
674
+ " if (epoch + 1) % self.val_freq == 0:\n",
675
+ " self.val()\n",
676
+ " if (epoch + 1) % self.save_freq == 0:\n",
677
+ " self.save_checkpoint(epoch + 1)\n",
678
+ "\n",
679
+ " # Update learning rate\n",
680
+ " self.iter_step += 1\n",
681
+ " self.update_learning_rate()\n",
682
+ "\n",
683
+ "\n",
684
+ " def val(self):\n",
685
+ " self.model.eval()\n",
686
+ " val_loss = 0.0\n",
687
+ " correct = 0\n",
688
+ " total = 0\n",
689
+ "\n",
690
+ " with torch.no_grad():\n",
691
+ " for batch_inputs, labels in tqdm(self.val_loader, desc=\"Validation\", leave=False):\n",
692
+ " freq_inputs = batch_inputs[\"freq_inputs\"].to(self.device)\n",
693
+ " seq_inputs = batch_inputs[\"seq_inputs\"].to(self.device)\n",
694
+ " pos_inputs = batch_inputs[\"pos_inputs\"].to(self.device)\n",
695
+ " y_val = labels.to(self.device)\n",
696
+ "\n",
697
+ " preds = self.model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
698
+ " loss = self.criterion(preds, y_val)\n",
699
+ " # preds = (torch.sigmoid(preds)>0.5).float()\n",
700
+ " _, preds = torch.max(preds, dim=1)\n",
701
+ " val_loss += loss.item()\n",
702
+ " total += y_val.size(0)\n",
703
+ " correct += (preds == y_val).sum().item()\n",
704
+ " if self.val_loss is None or val_loss < self.val_loss:\n",
705
+ " self.val_loss = val_loss\n",
706
+ " self.save_checkpoint(\"best\")\n",
707
+ " # Log validation metrics\n",
708
+ " print(f\"Validation Loss: {val_loss / len(self.val_loader):.4f}\")\n",
709
+ " print(f\"Validation Accuracy: {correct / total:.4f}\")\n",
710
+ "\n",
711
+ " def save_checkpoint(self, epoch):\n",
712
+ " \"\"\"Save model in Hugging Face format.\"\"\"\n",
713
+ " checkpoint_dir = os.path.join(self.save_path, f\"checkpoint_epoch_{epoch}\")\n",
714
+ " if epoch ==\"best\":\n",
715
+ " checkpoint_dir = os.path.join(self.save_path, \"best\")\n",
716
+ " self.model.save_pretrained(checkpoint_dir)\n",
717
+ " print(f\"Checkpoint saved at {checkpoint_dir}\")"
718
+ ],
719
+ "id": "7be377251b81a25d",
720
+ "outputs": [],
721
+ "execution_count": 38
722
+ },
723
+ {
724
+ "metadata": {
725
+ "ExecuteTime": {
726
+ "end_time": "2024-12-16T08:23:59.280460Z",
727
+ "start_time": "2024-12-16T08:11:25.911156Z"
728
+ }
729
+ },
730
+ "cell_type": "code",
731
+ "source": [
732
+ "from torch.utils.data import DataLoader\n",
733
+ "\n",
734
+ "# Define a collate function to handle the batched data\n",
735
+ "def collate_fn(batch):\n",
736
+ " freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
737
+ " seq_inputs = torch.stack([torch.tensor(item[\"seq_inputs\"]) for item in batch])\n",
738
+ " pos_inputs = torch.stack([torch.tensor(item[\"pos_inputs\"]) for item in batch])\n",
739
+ " labels = torch.tensor([torch.tensor(item[\"labels\"],dtype=torch.long) for item in batch])\n",
740
+ " return {\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs}, labels\n",
741
+ "\n",
742
+ "train_loader = DataLoader(dataset_train, batch_size=config.train[\"batch_size\"], shuffle=True,collate_fn=collate_fn)\n",
743
+ "test_loader = DataLoader(dataset_test, batch_size=config.train[\"batch_size\"], shuffle=False,collate_fn=collate_fn)\n",
744
+ "trainer = Trainer(model, train_loader, test_loader, config)\n",
745
+ "\n",
746
+ "# Train the model\n",
747
+ "trainer.train()\n",
748
+ "# Save the final model in Hugging Face format\n",
749
+ "final_save_path = os.path.join(config.base_exp_dir, \"checkpoints\")\n",
750
+ "model.save_pretrained(final_save_path)\n",
751
+ "print(f\"Final model saved at {final_save_path}\")\n"
752
+ ],
753
+ "id": "dd1749c306f148eb",
754
+ "outputs": [
755
+ {
756
+ "name": "stderr",
757
+ "output_type": "stream",
758
+ "text": [
759
+ "Epoch 1/10: 100%|██████████| 96/96 [01:00<00:00, 1.60it/s]\n"
760
+ ]
761
+ },
762
+ {
763
+ "name": "stdout",
764
+ "output_type": "stream",
765
+ "text": [
766
+ "Train Loss: 0.6920\n",
767
+ "Train Accuracy: 0.5217\n"
768
+ ]
769
+ },
770
+ {
771
+ "name": "stderr",
772
+ "output_type": "stream",
773
+ "text": [
774
+ " \r"
775
+ ]
776
+ },
777
+ {
778
+ "name": "stdout",
779
+ "output_type": "stream",
780
+ "text": [
781
+ "Checkpoint saved at ./exp/fox_nbc/checkpoints\\best\n",
782
+ "Validation Loss: 0.6914\n",
783
+ "Validation Accuracy: 0.5322\n"
784
+ ]
785
+ },
786
+ {
787
+ "name": "stderr",
788
+ "output_type": "stream",
789
+ "text": [
790
+ "Epoch 2/10: 100%|██████████| 96/96 [01:00<00:00, 1.59it/s]\n"
791
+ ]
792
+ },
793
+ {
794
+ "name": "stdout",
795
+ "output_type": "stream",
796
+ "text": [
797
+ "Train Loss: 0.6137\n",
798
+ "Train Accuracy: 0.6390\n"
799
+ ]
800
+ },
801
+ {
802
+ "name": "stderr",
803
+ "output_type": "stream",
804
+ "text": [
805
+ " \r"
806
+ ]
807
+ },
808
+ {
809
+ "name": "stdout",
810
+ "output_type": "stream",
811
+ "text": [
812
+ "Checkpoint saved at ./exp/fox_nbc/checkpoints\\best\n",
813
+ "Validation Loss: 0.4313\n",
814
+ "Validation Accuracy: 0.7937\n"
815
+ ]
816
+ },
817
+ {
818
+ "name": "stderr",
819
+ "output_type": "stream",
820
+ "text": [
821
+ "Epoch 3/10: 100%|██████████| 96/96 [01:00<00:00, 1.59it/s]\n"
822
+ ]
823
+ },
824
+ {
825
+ "name": "stdout",
826
+ "output_type": "stream",
827
+ "text": [
828
+ "Train Loss: 0.3599\n",
829
+ "Train Accuracy: 0.8466\n"
830
+ ]
831
+ },
832
+ {
833
+ "name": "stderr",
834
+ "output_type": "stream",
835
+ "text": [
836
+ " \r"
837
+ ]
838
+ },
839
+ {
840
+ "name": "stdout",
841
+ "output_type": "stream",
842
+ "text": [
843
+ "Validation Loss: 0.4837\n",
844
+ "Validation Accuracy: 0.7911\n"
845
+ ]
846
+ },
847
+ {
848
+ "name": "stderr",
849
+ "output_type": "stream",
850
+ "text": [
851
+ "Epoch 4/10: 100%|██████████| 96/96 [01:00<00:00, 1.59it/s]\n"
852
+ ]
853
+ },
854
+ {
855
+ "name": "stdout",
856
+ "output_type": "stream",
857
+ "text": [
858
+ "Train Loss: 0.2002\n",
859
+ "Train Accuracy: 0.9208\n"
860
+ ]
861
+ },
862
+ {
863
+ "name": "stderr",
864
+ "output_type": "stream",
865
+ "text": [
866
+ " \r"
867
+ ]
868
+ },
869
+ {
870
+ "name": "stdout",
871
+ "output_type": "stream",
872
+ "text": [
873
+ "Checkpoint saved at ./exp/fox_nbc/checkpoints\\best\n",
874
+ "Validation Loss: 0.3466\n",
875
+ "Validation Accuracy: 0.8647\n"
876
+ ]
877
+ },
878
+ {
879
+ "name": "stderr",
880
+ "output_type": "stream",
881
+ "text": [
882
+ "Epoch 5/10: 100%|██████████| 96/96 [01:00<00:00, 1.57it/s]\n"
883
+ ]
884
+ },
885
+ {
886
+ "name": "stdout",
887
+ "output_type": "stream",
888
+ "text": [
889
+ "Train Loss: 0.0939\n",
890
+ "Train Accuracy: 0.9655\n"
891
+ ]
892
+ },
893
+ {
894
+ "name": "stderr",
895
+ "output_type": "stream",
896
+ "text": [
897
+ " \r"
898
+ ]
899
+ },
900
+ {
901
+ "name": "stdout",
902
+ "output_type": "stream",
903
+ "text": [
904
+ "Validation Loss: 0.5384\n",
905
+ "Validation Accuracy: 0.8371\n",
906
+ "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_5\n"
907
+ ]
908
+ },
909
+ {
910
+ "name": "stderr",
911
+ "output_type": "stream",
912
+ "text": [
913
+ "Epoch 6/10: 100%|██████████| 96/96 [01:00<00:00, 1.58it/s]\n"
914
+ ]
915
+ },
916
+ {
917
+ "name": "stdout",
918
+ "output_type": "stream",
919
+ "text": [
920
+ "Train Loss: 0.0555\n",
921
+ "Train Accuracy: 0.9800\n"
922
+ ]
923
+ },
924
+ {
925
+ "name": "stderr",
926
+ "output_type": "stream",
927
+ "text": [
928
+ " \r"
929
+ ]
930
+ },
931
+ {
932
+ "name": "stdout",
933
+ "output_type": "stream",
934
+ "text": [
935
+ "Validation Loss: 0.4440\n",
936
+ "Validation Accuracy: 0.8857\n"
937
+ ]
938
+ },
939
+ {
940
+ "name": "stderr",
941
+ "output_type": "stream",
942
+ "text": [
943
+ "Epoch 7/10: 100%|██████████| 96/96 [01:00<00:00, 1.58it/s]\n"
944
+ ]
945
+ },
946
+ {
947
+ "name": "stdout",
948
+ "output_type": "stream",
949
+ "text": [
950
+ "Train Loss: 0.0422\n",
951
+ "Train Accuracy: 0.9855\n"
952
+ ]
953
+ },
954
+ {
955
+ "name": "stderr",
956
+ "output_type": "stream",
957
+ "text": [
958
+ " \r"
959
+ ]
960
+ },
961
+ {
962
+ "name": "stdout",
963
+ "output_type": "stream",
964
+ "text": [
965
+ "Validation Loss: 0.5519\n",
966
+ "Validation Accuracy: 0.8739\n"
967
+ ]
968
+ },
969
+ {
970
+ "name": "stderr",
971
+ "output_type": "stream",
972
+ "text": [
973
+ "Epoch 8/10: 100%|██████████| 96/96 [01:01<00:00, 1.56it/s]\n"
974
+ ]
975
+ },
976
+ {
977
+ "name": "stdout",
978
+ "output_type": "stream",
979
+ "text": [
980
+ "Train Loss: 0.0163\n",
981
+ "Train Accuracy: 0.9944\n"
982
+ ]
983
+ },
984
+ {
985
+ "name": "stderr",
986
+ "output_type": "stream",
987
+ "text": [
988
+ " \r"
989
+ ]
990
+ },
991
+ {
992
+ "name": "stdout",
993
+ "output_type": "stream",
994
+ "text": [
995
+ "Validation Loss: 0.6396\n",
996
+ "Validation Accuracy: 0.8581\n"
997
+ ]
998
+ },
999
+ {
1000
+ "name": "stderr",
1001
+ "output_type": "stream",
1002
+ "text": [
1003
+ "Epoch 9/10: 100%|██████████| 96/96 [01:01<00:00, 1.57it/s]\n"
1004
+ ]
1005
+ },
1006
+ {
1007
+ "name": "stdout",
1008
+ "output_type": "stream",
1009
+ "text": [
1010
+ "Train Loss: 0.0125\n",
1011
+ "Train Accuracy: 0.9974\n"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "name": "stderr",
1016
+ "output_type": "stream",
1017
+ "text": [
1018
+ " \r"
1019
+ ]
1020
+ },
1021
+ {
1022
+ "name": "stdout",
1023
+ "output_type": "stream",
1024
+ "text": [
1025
+ "Validation Loss: 0.4890\n",
1026
+ "Validation Accuracy: 0.8857\n"
1027
+ ]
1028
+ },
1029
+ {
1030
+ "name": "stderr",
1031
+ "output_type": "stream",
1032
+ "text": [
1033
+ "Epoch 10/10: 100%|██████████| 96/96 [01:17<00:00, 1.24it/s]\n"
1034
+ ]
1035
+ },
1036
+ {
1037
+ "name": "stdout",
1038
+ "output_type": "stream",
1039
+ "text": [
1040
+ "Train Loss: 0.0039\n",
1041
+ "Train Accuracy: 0.9997\n"
1042
+ ]
1043
+ },
1044
+ {
1045
+ "name": "stderr",
1046
+ "output_type": "stream",
1047
+ "text": [
1048
+ " \r"
1049
+ ]
1050
+ },
1051
+ {
1052
+ "name": "stdout",
1053
+ "output_type": "stream",
1054
+ "text": [
1055
+ "Validation Loss: 0.5024\n",
1056
+ "Validation Accuracy: 0.8883\n",
1057
+ "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_10\n",
1058
+ "Final model saved at ./exp/fox_nbc/checkpoints\n"
1059
+ ]
1060
+ }
1061
+ ],
1062
+ "execution_count": 39
1063
+ },
1064
+ {
1065
+ "metadata": {},
1066
+ "cell_type": "markdown",
1067
+ "source": "## Evaluate Model",
1068
+ "id": "4af000263dd99bca"
1069
+ },
1070
+ {
1071
+ "metadata": {
1072
+ "ExecuteTime": {
1073
+ "end_time": "2024-12-16T08:24:40.129577Z",
1074
+ "start_time": "2024-12-16T08:24:26.080416Z"
1075
+ }
1076
+ },
1077
+ "cell_type": "code",
1078
+ "source": [
1079
+ "from transformers import AutoConfig, AutoModel\n",
1080
+ "from sklearn.metrics import accuracy_score, classification_report\n",
1081
+ "def load_last_checkpoint(checkpoint_dir):\n",
1082
+ " # Find all checkpoints in the directory\n",
1083
+ " checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith(\"checkpoint_epoch_\")]\n",
1084
+ " if not checkpoints:\n",
1085
+ " raise FileNotFoundError(f\"No checkpoints found in {checkpoint_dir}!\")\n",
1086
+ " # Sort checkpoints by epoch number\n",
1087
+ " checkpoints.sort(key=lambda x: int(x.split(\"_\")[-1]))\n",
1088
+ "\n",
1089
+ " # Load the last checkpoint\n",
1090
+ " last_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])\n",
1091
+ " # print(f\"Loading checkpoint from {last_checkpoint}\")\n",
1092
+ " # Load the best checkpoint\n",
1093
+ " #if os.path.join(checkpoint_dir, \"best\") is not None:\n",
1094
+ " # last_checkpoint = os.path.join(checkpoint_dir, \"best\")\n",
1095
+ " print(f\"Loading checkpoint from {last_checkpoint}\")\n",
1096
+ " # Load model and config\n",
1097
+ " config = AutoConfig.from_pretrained(last_checkpoint)\n",
1098
+ " model = AutoModel.from_pretrained(last_checkpoint, config=config)\n",
1099
+ " return model\n",
1100
+ "\n",
1101
+ "# Step 1: Define paths and setup\n",
1102
+ "checkpoint_dir = os.path.join(config.base_exp_dir, \"checkpoints\") # Directory where checkpoints are stored\n",
1103
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1104
+ "model = load_last_checkpoint(checkpoint_dir)\n",
1105
+ "model.to(device)\n",
1106
+ "\n",
1107
+ "# criterion = torch.nn.BCEWithLogitsLoss()\n",
1108
+ "\n",
1109
+ "criterion = torch.nn.CrossEntropyLoss()\n",
1110
+ "\n",
1111
+ "def evaluate_model(model, val_loader, criterion, device=\"cuda\"):\n",
1112
+ " model.eval()\n",
1113
+ " val_loss = 0.0\n",
1114
+ " correct = 0\n",
1115
+ " total = 0\n",
1116
+ " all_preds = []\n",
1117
+ " all_labels = []\n",
1118
+ " with torch.no_grad():\n",
1119
+ " for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
1120
+ " freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
1121
+ " seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
1122
+ " pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
1123
+ " labels = labels.to(device)\n",
1124
+ "\n",
1125
+ " preds= model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
1126
+ " loss = criterion(preds, labels)\n",
1127
+ " _, preds = torch.max(preds, dim=1)\n",
1128
+ " # preds = (torch.sigmoid(preds) > 0.5).float()\n",
1129
+ " val_loss += loss.item()\n",
1130
+ " total += labels.size(0)\n",
1131
+ " # preds = (torch.sigmoid(preds) > 0.5).int()\n",
1132
+ " correct += (preds == labels).sum().item()\n",
1133
+ " all_preds.extend(preds.cpu().numpy())\n",
1134
+ " all_labels.extend(labels.cpu().numpy())\n",
1135
+ "\n",
1136
+ " return accuracy_score(all_labels, all_preds), classification_report(all_labels, all_preds)\n",
1137
+ "\n",
1138
+ "\n",
1139
+ "accuracy, report = evaluate_model(model, test_loader, criterion)\n",
1140
+ "print(f\"Accuracy: {accuracy:.4f}\")\n",
1141
+ "print(report)\n"
1142
+ ],
1143
+ "id": "b75d2dc8a300cdf6",
1144
+ "outputs": [
1145
+ {
1146
+ "name": "stderr",
1147
+ "output_type": "stream",
1148
+ "text": [
1149
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
1150
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
1151
+ "Some weights of the model checkpoint at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_10 were not used when initializing CustomModel: ['freq.lin0.parametrizations.weight.original0', 'freq.lin0.parametrizations.weight.original1', 'freq.lin1.parametrizations.weight.original0', 'freq.lin1.parametrizations.weight.original1', 'freq.lin2.parametrizations.weight.original0', 'freq.lin2.parametrizations.weight.original1', 'pos.lin0.parametrizations.weight.original0', 'pos.lin0.parametrizations.weight.original1', 'pos.lin1.parametrizations.weight.original0', 'pos.lin1.parametrizations.weight.original1', 'pos.lin2.parametrizations.weight.original0', 'pos.lin2.parametrizations.weight.original1']\n",
1152
+ "- This IS expected if you are initializing CustomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1153
+ "- This IS NOT expected if you are initializing CustomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
1154
+ ]
1155
+ },
1156
+ {
1157
+ "name": "stdout",
1158
+ "output_type": "stream",
1159
+ "text": [
1160
+ "Loading checkpoint from ./exp/fox_nbc/checkpoints\\checkpoint_epoch_10\n"
1161
+ ]
1162
+ },
1163
+ {
1164
+ "name": "stderr",
1165
+ "output_type": "stream",
1166
+ "text": [
1167
+ "Some weights of CustomModel were not initialized from the model checkpoint at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_10 and are newly initialized: ['freq.lin0.weight_g', 'freq.lin0.weight_v', 'freq.lin1.weight_g', 'freq.lin1.weight_v', 'freq.lin2.weight_g', 'freq.lin2.weight_v', 'pos.lin0.weight_g', 'pos.lin0.weight_v', 'pos.lin1.weight_g', 'pos.lin1.weight_v', 'pos.lin2.weight_g', 'pos.lin2.weight_v']\n",
1168
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
1169
+ " "
1170
+ ]
1171
+ },
1172
+ {
1173
+ "name": "stdout",
1174
+ "output_type": "stream",
1175
+ "text": [
1176
+ "Accuracy: 0.8883\n",
1177
+ " precision recall f1-score support\n",
1178
+ "\n",
1179
+ " 0 0.86 0.91 0.88 356\n",
1180
+ " 1 0.92 0.87 0.89 405\n",
1181
+ "\n",
1182
+ " accuracy 0.89 761\n",
1183
+ " macro avg 0.89 0.89 0.89 761\n",
1184
+ "weighted avg 0.89 0.89 0.89 761\n",
1185
+ "\n"
1186
+ ]
1187
+ },
1188
+ {
1189
+ "name": "stderr",
1190
+ "output_type": "stream",
1191
+ "text": [
1192
+ "\r"
1193
+ ]
1194
+ }
1195
+ ],
1196
+ "execution_count": 41
1197
+ },
1198
+ {
1199
+ "metadata": {},
1200
+ "cell_type": "markdown",
1201
+ "source": "# Part 3: Pushing the Model to the Hugging Face Hub",
1202
+ "id": "d2ffeb383ea00beb"
1203
+ },
1204
+ {
1205
+ "metadata": {
1206
+ "ExecuteTime": {
1207
+ "end_time": "2024-12-16T08:25:19.716888Z",
1208
+ "start_time": "2024-12-16T08:24:56.833580Z"
1209
+ }
1210
+ },
1211
+ "cell_type": "code",
1212
+ "source": "model.push_model(REPO_NAME)",
1213
+ "id": "f55c22b0a1b2a66b",
1214
+ "outputs": [
1215
+ {
1216
+ "data": {
1217
+ "text/plain": [
1218
+ "README.md: 0%| | 0.00/326 [00:00<?, ?B/s]"
1219
+ ],
1220
+ "application/vnd.jupyter.widget-view+json": {
1221
+ "version_major": 2,
1222
+ "version_minor": 0,
1223
+ "model_id": "7595a65d4b2f4332a7655b4270bb6f5b"
1224
+ }
1225
+ },
1226
+ "metadata": {},
1227
+ "output_type": "display_data"
1228
+ },
1229
+ {
1230
+ "data": {
1231
+ "text/plain": [
1232
+ "model.safetensors: 0%| | 0.00/517M [00:00<?, ?B/s]"
1233
+ ],
1234
+ "application/vnd.jupyter.widget-view+json": {
1235
+ "version_major": 2,
1236
+ "version_minor": 0,
1237
+ "model_id": "6ad5e72579ed463b9d6bbb1c3ef4d2a9"
1238
+ }
1239
+ },
1240
+ "metadata": {},
1241
+ "output_type": "display_data"
1242
+ }
1243
+ ],
1244
+ "execution_count": 42
1245
+ },
1246
+ {
1247
+ "metadata": {},
1248
+ "cell_type": "markdown",
1249
+ "source": "### NOTE: You need to ensure that your Hugging Face token has both read and write access to your repository and Hugging Face organization.",
1250
+ "id": "3826c0b6195a8fd5"
1251
+ },
1252
+ {
1253
+ "metadata": {
1254
+ "ExecuteTime": {
1255
+ "end_time": "2024-12-16T08:25:45.224733Z",
1256
+ "start_time": "2024-12-16T08:25:21.297050Z"
1257
+ }
1258
+ },
1259
+ "cell_type": "code",
1260
+ "source": [
1261
+ "# Load model directly\n",
1262
+ "from transformers import AutoModel, AutoConfig\n",
1263
+ "config = AutoConfig.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\")\n",
1264
+ "model = AutoModel.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\",config = config)"
1265
+ ],
1266
+ "id": "33a0ca269c24d700",
1267
+ "outputs": [
1268
+ {
1269
+ "data": {
1270
+ "text/plain": [
1271
+ "config.json: 0%| | 0.00/1.09k [00:00<?, ?B/s]"
1272
+ ],
1273
+ "application/vnd.jupyter.widget-view+json": {
1274
+ "version_major": 2,
1275
+ "version_minor": 0,
1276
+ "model_id": "11974267c2544c6885478d2f63a8c41c"
1277
+ }
1278
+ },
1279
+ "metadata": {},
1280
+ "output_type": "display_data"
1281
+ },
1282
+ {
1283
+ "data": {
1284
+ "text/plain": [
1285
+ "model.safetensors: 0%| | 0.00/517M [00:00<?, ?B/s]"
1286
+ ],
1287
+ "application/vnd.jupyter.widget-view+json": {
1288
+ "version_major": 2,
1289
+ "version_minor": 0,
1290
+ "model_id": "140a2fb1413a409782bb4322ec70e4d4"
1291
+ }
1292
+ },
1293
+ "metadata": {},
1294
+ "output_type": "display_data"
1295
+ },
1296
+ {
1297
+ "name": "stderr",
1298
+ "output_type": "stream",
1299
+ "text": [
1300
+ "C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
1301
+ " WeightNorm.apply(module, name, dim)\n",
1302
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
1303
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
1304
+ "Some weights of the model checkpoint at CISProject/News-Headline-Classifier-Notebook were not used when initializing CustomModel: ['freq.lin0.parametrizations.weight.original0', 'freq.lin0.parametrizations.weight.original1', 'freq.lin1.parametrizations.weight.original0', 'freq.lin1.parametrizations.weight.original1', 'freq.lin2.parametrizations.weight.original0', 'freq.lin2.parametrizations.weight.original1', 'pos.lin0.parametrizations.weight.original0', 'pos.lin0.parametrizations.weight.original1', 'pos.lin1.parametrizations.weight.original0', 'pos.lin1.parametrizations.weight.original1', 'pos.lin2.parametrizations.weight.original0', 'pos.lin2.parametrizations.weight.original1']\n",
1305
+ "- This IS expected if you are initializing CustomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1306
+ "- This IS NOT expected if you are initializing CustomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1307
+ "Some weights of CustomModel were not initialized from the model checkpoint at CISProject/News-Headline-Classifier-Notebook and are newly initialized: ['freq.lin0.weight_g', 'freq.lin0.weight_v', 'freq.lin1.weight_g', 'freq.lin1.weight_v', 'freq.lin2.weight_g', 'freq.lin2.weight_v', 'pos.lin0.weight_g', 'pos.lin0.weight_v', 'pos.lin1.weight_g', 'pos.lin1.weight_v', 'pos.lin2.weight_g', 'pos.lin2.weight_v']\n",
1308
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1309
+ ]
1310
+ }
1311
+ ],
1312
+ "execution_count": 43
1313
+ },
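+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "# Hypothetical single-headline inference sketch (added; not part of the original notebook).\n",
+ "# It assumes `vectorizer`, `w2v_model`, and the `model` just loaded from the Hub are still in memory,\n",
+ "# and mirrors the preprocessing in preprocess_data(); the example headline is made up.\n",
+ "headline = \"Senate passes sweeping budget bill after overnight session\"\n",
+ "tok = RobertaTokenizer.from_pretrained(\"roberta-base\")\n",
+ "clean = re.sub(r'[^\\w\\s]', '', headline.lower()).strip()\n",
+ "freq = torch.tensor(vectorizer.transform([clean]).toarray(), dtype=torch.float32)\n",
+ "enc = tok(headline, padding=\"max_length\", truncation=True, max_length=128, return_tensors=\"pt\")\n",
+ "seq = torch.stack([enc[\"input_ids\"], enc[\"attention_mask\"]], dim=1)\n",
+ "vecs = [w2v_model[t] if t in w2v_model else np.zeros(300, dtype=np.float32) for t in clean.split()[:128]]\n",
+ "vecs += [np.zeros(300, dtype=np.float32)] * (128 - len(vecs))\n",
+ "pos = torch.tensor(np.array(vecs), dtype=torch.float32).unsqueeze(0)\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ "    logits = model({\"freq_inputs\": freq, \"seq_inputs\": seq, \"pos_inputs\": pos})\n",
+ "print(\"fox\" if logits.argmax(dim=1).item() == 1 else \"nbc\") # label 1 = fox in make_labels\n"
+ ],
+ "id": "single-headline-inference-sketch",
+ "outputs": [],
+ "execution_count": null
+ },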
1314
+ {
1315
+ "metadata": {
1316
+ "ExecuteTime": {
1317
+ "end_time": "2024-12-16T08:26:00.658712Z",
1318
+ "start_time": "2024-12-16T08:25:47.304227Z"
1319
+ }
1320
+ },
1321
+ "cell_type": "code",
1322
+ "source": [
1323
+ "from transformers import AutoConfig, AutoModel\n",
1324
+ "from sklearn.metrics import accuracy_score, classification_report\n",
1325
+ "\n",
1326
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1327
+ "model.to(device)\n",
1328
+ "\n",
1329
+ "#criterion = torch.nn.BCEWithLogitsLoss()\n",
1330
+ "\n",
1331
+ "criterion = torch.nn.CrossEntropyLoss()\n",
1332
+ "def evaluate_model(model, val_loader, criterion, device=\"cuda\"):\n",
1333
+ " model.eval()\n",
1334
+ " val_loss = 0.0\n",
1335
+ " correct = 0\n",
1336
+ " total = 0\n",
1337
+ " all_preds = []\n",
1338
+ " all_labels = []\n",
1339
+ " with torch.no_grad():\n",
1340
+ " for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
1341
+ " freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
1342
+ " seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
1343
+ " pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
1344
+ " labels = labels.to(device)\n",
1345
+ "\n",
1346
+ " preds = model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
1347
+ " loss = criterion(preds, labels)\n",
1348
+ " _, preds = torch.max(preds, dim=1)\n",
1349
+ " # preds = (torch.sigmoid(preds) > 0.5).float()\n",
1350
+ " val_loss += loss.item()\n",
1351
+ " total += labels.size(0)\n",
1352
+ " correct += (preds == labels).sum().item()\n",
1353
+ " all_preds.extend(preds.cpu().numpy())\n",
1354
+ " all_labels.extend(labels.cpu().numpy())\n",
1355
+ "\n",
1356
+ " return accuracy_score(all_labels, all_preds), classification_report(all_labels, all_preds)\n",
1357
+ "\n",
1358
+ "\n",
1359
+ "accuracy, report = evaluate_model(model, test_loader, criterion)\n",
1360
+ "print(f\"Accuracy: {accuracy:.4f}\")\n",
1361
+ "print(report)\n"
1362
+ ],
1363
+ "id": "cc313b4396f87690",
1364
+ "outputs": [
1365
+ {
1366
+ "name": "stderr",
1367
+ "output_type": "stream",
1368
+ "text": [
1369
+ " "
1370
+ ]
1371
+ },
1372
+ {
1373
+ "name": "stdout",
1374
+ "output_type": "stream",
1375
+ "text": [
1376
+ "Accuracy: 0.8883\n",
1377
+ " precision recall f1-score support\n",
1378
+ "\n",
1379
+ " 0 0.86 0.91 0.88 356\n",
1380
+ " 1 0.92 0.87 0.89 405\n",
1381
+ "\n",
1382
+ " accuracy 0.89 761\n",
1383
+ " macro avg 0.89 0.89 0.89 761\n",
1384
+ "weighted avg 0.89 0.89 0.89 761\n",
1385
+ "\n"
1386
+ ]
1387
+ },
1388
+ {
1389
+ "name": "stderr",
1390
+ "output_type": "stream",
1391
+ "text": [
1392
+ "\r"
1393
+ ]
1394
+ }
1395
+ ],
1396
+ "execution_count": 44
1397
+ }
1398
+ ],
1399
+ "metadata": {
1400
+ "kernelspec": {
1401
+ "name": "python3",
1402
+ "language": "python",
1403
+ "display_name": "Python 3 (ipykernel)"
1404
+ }
1405
+ },
1406
+ "nbformat": 4,
1407
+ "nbformat_minor": 5
1408
+ }