TUEN-YUE committed on
Commit 8e86e93 · verified
1 Parent(s): d6cf153

Delete train&test.ipynb

Files changed (1)
  1. train&test.ipynb +0 -1309
train&test.ipynb DELETED
@@ -1,1309 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "metadata": {},
5
- "cell_type": "markdown",
6
- "source": [
7
- "# Installing dependencies\n",
8
- "## Please make a copy of this notebook."
9
- ],
10
- "id": "13156d7ed48b282"
11
- },
12
- {
13
- "metadata": {},
14
- "cell_type": "markdown",
15
- "source": [
16
- "# Huggingface login\n",
17
- "You will require your personal token."
18
- ],
19
- "id": "432a756039e6399"
20
- },
21
- {
22
- "metadata": {},
23
- "cell_type": "code",
24
- "source": "source": [
25
- "!pip install geopy > delete.txt\n",
26
- "!pip install datasets > delete.txt\n",
27
- "!pip install torch torchvision datasets > delete.txt\n",
28
- "!pip install huggingface_hub > delete.txt\n",
29
- "!pip install pyhocon > delete.txt\n",
30
- "!pip install transformers > delete.txt\n",
31
- "!pip install gensim > delete.txt\n",
32
- "!rm delete.txt"
33
- ],
34
- "id": "2e73da09a7c6171e",
35
- "outputs": [],
36
- "execution_count": null
37
- },
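Note: the "Huggingface login" cell above says a personal token is required, but no login call appears anywhere in this listing. A minimal sketch, assuming the standard `huggingface_hub` login helpers, might look like:

```python
# Hypothetical login cell (not part of the original notebook).
from huggingface_hub import notebook_login

# Prompts interactively for a personal access token with write access;
# huggingface_hub.login(token=...) can be used in non-interactive environments.
notebook_login()
```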
38
- {
39
- "metadata": {},
40
- "cell_type": "markdown",
41
- "source": "# Part 1: Load Data",
42
- "id": "c731d9c1ebb477dc"
43
- },
44
- {
45
- "metadata": {},
46
- "cell_type": "markdown",
47
- "source": "## Downloading the train and test dataset",
48
- "id": "14070f20b547688f"
49
- },
50
- {
51
- "metadata": {},
52
- "cell_type": "markdown",
53
- "source": "",
54
- "id": "b8920847b7cc378d"
55
- },
56
- {
57
- "metadata": {},
58
- "cell_type": "code",
59
- "source": [
60
- "from datasets import load_dataset\n",
61
- "\n",
62
- "dataset_train = load_dataset(\"CISProject/FOX_NBC\", split=\"train\")\n",
63
- "dataset_test = load_dataset(\"CISProject/FOX_NBC\", split=\"test\")\n",
64
- "# dataset_test = load_dataset(\"CISProject/FOX_NBC\", split=\"test_data_random_subset\")\n"
65
- ],
66
- "id": "877c90c978d62b7d",
67
- "outputs": [],
68
- "execution_count": 12
69
- },
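As a quick sanity check of the loaded splits, one could peek at a single example; the `title` and `labels` columns assumed here are the ones shown in the dataset printout later in this notebook.

```python
# Illustrative only: inspect one training example and the dataset schema.
print(dataset_train[0]["title"], dataset_train[0]["labels"])
print(dataset_train.features)
```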
70
- {
71
- "metadata": {
72
- "ExecuteTime": {
73
- "end_time": "2024-12-16T18:33:00.318956Z",
74
- "start_time": "2024-12-16T18:33:00.310428Z"
75
- }
76
- },
77
- "cell_type": "code",
78
- "source": [
79
- "import numpy as np\n",
80
- "import torch\n",
81
- "import re\n",
82
- "from transformers import BertTokenizer\n",
83
- "from transformers import RobertaTokenizer\n",
84
- "from sklearn.feature_extraction.text import CountVectorizer\n",
85
- "from gensim.models import KeyedVectors\n",
86
- "from sklearn.feature_extraction.text import TfidfVectorizer\n",
87
- "\n",
88
- "def preprocess_data(data,\n",
89
- " mode=\"train\",\n",
90
- " vectorizer=None,\n",
91
- " w2v_model=None,\n",
92
- " max_features=4096,\n",
93
- " max_seq_length=128,\n",
94
- " num_proc=4):\n",
95
- " if w2v_model is None:\n",
96
- " raise ValueError(\"w2v_model must be provided for Word2Vec embeddings.\")\n",
97
- "\n",
98
- " # tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
99
- " tokenizer = RobertaTokenizer.from_pretrained(\"roberta-base\")\n",
100
- " # 1. Clean text once\n",
101
- " def clean_text(examples):\n",
102
- " import re\n",
103
- " cleaned = []\n",
104
- " for text in examples[\"title\"]:\n",
105
- " text = text.lower()\n",
106
- " text = re.sub(r'[^\\w\\s]', '', text)\n",
107
- " text = text.strip()\n",
108
- " cleaned.append(text)\n",
109
- " return {\"clean_title\": cleaned}\n",
110
- "\n",
111
- " data = data.map(clean_text, batched=True, num_proc=num_proc)\n",
112
- "\n",
113
- " # 2. Fit CountVectorizer on training data if needed\n",
114
- " if mode == \"train\" and vectorizer is None:\n",
115
- " # Collect all cleaned titles to fit\n",
116
- " all_titles = data[\"clean_title\"]\n",
117
- " #vectorizer = CountVectorizer(max_features=max_features, ngram_range=(1,2))\n",
118
- " vectorizer = TfidfVectorizer(max_features=max_features)\n",
119
- " vectorizer.fit(all_titles)\n",
120
- " print(\"vectorizer fitted on training data.\")\n",
121
- "\n",
122
- " # 3. Transform titles with vectorizer once\n",
123
- " def vectorize_batch(examples):\n",
124
- " import numpy as np\n",
125
- " freq = vectorizer.transform(examples[\"clean_title\"]).toarray().astype(np.float32)\n",
126
- " return {\"freq_inputs\": freq}\n",
127
- "\n",
128
- " data = data.map(vectorize_batch, batched=True, num_proc=num_proc)\n",
129
- "\n",
130
- " # 4. Tokenize with BERT once\n",
131
- " def tokenize_batch(examples):\n",
132
- " tokenized = tokenizer(\n",
133
- " examples[\"title\"],\n",
134
- " padding=\"max_length\",\n",
135
- " truncation=True,\n",
136
- " max_length=max_seq_length\n",
137
- " )\n",
138
- " return {\n",
139
- " \"input_ids\": tokenized[\"input_ids\"],\n",
140
- " \"attention_mask\": tokenized[\"attention_mask\"]\n",
141
- " }\n",
142
- "\n",
143
- " data = data.map(tokenize_batch, batched=True, num_proc=num_proc)\n",
144
- "\n",
145
- " # 5. Convert titles into tokens for W2V\n",
146
- " def split_tokens(examples):\n",
147
- " tokens_list = [t.split() for t in examples[\"clean_title\"]]\n",
148
- " return {\"tokens\": tokens_list}\n",
149
- "\n",
150
- " data = data.map(split_tokens, batched=True, num_proc=num_proc)\n",
151
- "\n",
152
- " # Build an embedding dictionary for all unique tokens (do this once before embedding map)\n",
153
- " unique_tokens = set()\n",
154
- " for tokens in data[\"tokens\"]:\n",
155
- " unique_tokens.update(tokens)\n",
156
- "\n",
157
- " embedding_dim = w2v_model.vector_size\n",
158
- " embedding_dict = {}\n",
159
- " for tk in unique_tokens:\n",
160
- " if tk in w2v_model:\n",
161
- " embedding_dict[tk] = w2v_model[tk].astype(np.float32)\n",
162
- " else:\n",
163
- " embedding_dict[tk] = np.zeros((embedding_dim,), dtype=np.float32)\n",
164
- "\n",
165
- " def w2v_embedding_batch(examples):\n",
166
- " import numpy as np\n",
167
- " batch_w2v = []\n",
168
- " for tokens in examples[\"tokens\"]:\n",
169
- " vectors = [embedding_dict[tk] for tk in tokens[:max_seq_length]]\n",
170
- " if len(vectors) < max_seq_length:\n",
171
- " vectors += [np.zeros((embedding_dim,), dtype=np.float32)] * (max_seq_length - len(vectors))\n",
172
- " batch_w2v.append(vectors)\n",
173
- " return {\"pos_inputs\": batch_w2v}\n",
174
- "\n",
175
- "\n",
176
- " data = data.map(w2v_embedding_batch, batched=True, batch_size=32, num_proc=num_proc)\n",
177
- "\n",
178
- " # 7. Create labels\n",
179
- " def make_labels(examples):\n",
180
- " labels = examples[\"labels\"]\n",
181
- " return {\"labels\": labels}\n",
182
- "\n",
183
- " data = data.map(make_labels, batched=True, num_proc=num_proc)\n",
184
- "\n",
185
- " # Convert freq_inputs and pos_inputs to torch tensors in a final map step\n",
186
- " def to_tensors(examples):\n",
187
- " import torch\n",
188
- "\n",
189
- " freq_inputs = torch.tensor(examples[\"freq_inputs\"], dtype=torch.float32)\n",
190
- " input_ids = torch.tensor(examples[\"input_ids\"])\n",
191
- " attention_mask = torch.tensor(examples[\"attention_mask\"])\n",
192
- " pos_inputs = torch.tensor(examples[\"pos_inputs\"], dtype=torch.float32)\n",
193
- " labels = torch.tensor(examples[\"labels\"],dtype=torch.long)\n",
194
- "\n",
195
- " # seq_inputs shape: (batch_size, 2, seq_len)\n",
196
- " seq_inputs = torch.stack([input_ids, attention_mask], dim=1)\n",
197
- "\n",
198
- " return {\n",
199
- " \"freq_inputs\": freq_inputs,\n",
200
- " \"seq_inputs\": seq_inputs,\n",
201
- " \"pos_inputs\": pos_inputs,\n",
202
- " \"labels\": labels\n",
203
- " }\n",
204
- "\n",
205
- " # Apply final conversion to tensor\n",
206
- " processed_data = data.map(to_tensors, batched=True, num_proc=num_proc)\n",
207
- "\n",
208
- " return processed_data, vectorizer\n"
209
- ],
210
- "id": "dc2ba675ce880d6d",
211
- "outputs": [],
212
- "execution_count": 13
213
- },
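The next cell loads `./GoogleNews-vectors-negative300.bin` from disk, but the notebook never downloads it. A possible alternative, assuming gensim's downloader mirror of the GoogleNews vectors is an acceptable substitute for the local `.bin` file, is to fetch them programmatically:

```python
# Assumption: gensim's hosted copy of the GoogleNews vectors (~1.6 GB download)
# can stand in for the local GoogleNews-vectors-negative300.bin used below.
import gensim.downloader as api

w2v_model = api.load("word2vec-google-news-300")  # returns a KeyedVectors instance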
214
- {
215
- "metadata": {
216
- "ExecuteTime": {
217
- "end_time": "2024-12-16T18:33:26.890102Z",
218
- "start_time": "2024-12-16T18:33:00.323837Z"
219
- }
220
- },
221
- "cell_type": "code",
222
- "source": [
223
- "from gensim.models import KeyedVectors\n",
224
- "w2v_model = KeyedVectors.load_word2vec_format(\"./GoogleNews-vectors-negative300.bin\", binary=True)\n",
225
- "\n",
226
- "dataset_train,vectorizer = preprocess_data(\n",
227
- " data=dataset_train,\n",
228
- " mode=\"train\",\n",
229
- " w2v_model=w2v_model,\n",
230
- " max_features=8192,\n",
231
- " max_seq_length=128\n",
232
- ")\n",
233
- "\n",
234
- "dataset_test, _ = preprocess_data(\n",
235
- " data=dataset_test,\n",
236
- " mode=\"test\",\n",
237
- " vectorizer=vectorizer,\n",
238
- " w2v_model=w2v_model,\n",
239
- " max_features=8192,\n",
240
- " max_seq_length=128\n",
241
- ")"
242
- ],
243
- "id": "158b99950fb22d1",
244
- "outputs": [
245
- {
246
- "name": "stdout",
247
- "output_type": "stream",
248
- "text": [
249
- "vectorizer fitted on training data.\n"
250
- ]
251
- }
252
- ],
253
- "execution_count": 14
254
- },
255
- {
256
- "metadata": {
257
- "ExecuteTime": {
258
- "end_time": "2024-12-16T18:33:26.904401Z",
259
- "start_time": "2024-12-16T18:33:26.899278Z"
260
- }
261
- },
262
- "cell_type": "code",
263
- "source": [
264
- "print(dataset_train)\n",
265
- "print(dataset_test)"
266
- ],
267
- "id": "edd80d33175c96a0",
268
- "outputs": [
269
- {
270
- "name": "stdout",
271
- "output_type": "stream",
272
- "text": [
273
- "Dataset({\n",
274
- " features: ['title', 'outlet', 'index', 'url', 'labels', 'clean_title', 'freq_inputs', 'input_ids', 'attention_mask', 'tokens', 'pos_inputs', 'seq_inputs'],\n",
275
- " num_rows: 3044\n",
276
- "})\n",
277
- "Dataset({\n",
278
- " features: ['title', 'outlet', 'index', 'url', 'labels', 'clean_title', 'freq_inputs', 'input_ids', 'attention_mask', 'tokens', 'pos_inputs', 'seq_inputs'],\n",
279
- " num_rows: 761\n",
280
- "})\n"
281
- ]
282
- }
283
- ],
284
- "execution_count": 15
285
- },
286
- {
287
- "metadata": {},
288
- "cell_type": "markdown",
289
- "source": "# Part 2: Model",
290
- "id": "c9a49fc1fbca29d7"
291
- },
292
- {
293
- "metadata": {},
294
- "cell_type": "markdown",
295
- "source": "## Defining the Custom Model",
296
- "id": "aebe5e51f0e611cc"
297
- },
298
- {
299
- "metadata": {},
300
- "cell_type": "markdown",
301
- "source": "",
302
- "id": "f0eae08a025b6ed9"
303
- },
304
- {
305
- "metadata": {
306
- "ExecuteTime": {
307
- "end_time": "2024-12-16T18:33:26.937874Z",
308
- "start_time": "2024-12-16T18:33:26.926248Z"
309
- }
310
- },
311
- "cell_type": "code",
312
- "source": [
313
- "# TODO: import all packages necessary for your custom model\n",
314
- "import pandas as pd\n",
315
- "import os\n",
316
- "from torch.utils.data import DataLoader\n",
317
- "from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel\n",
318
- "import torch\n",
319
- "import torch.nn as nn\n",
320
- "from transformers import RobertaModel, RobertaConfig,RobertaForSequenceClassification, BertModel\n",
321
- "from model.network import Classifier\n",
322
- "from model.frequential import FreqNetwork\n",
323
- "from model.sequential import SeqNetwork\n",
324
- "from model.positional import PosNetwork\n",
325
- "\n",
326
- "class CustomConfig(PretrainedConfig):\n",
327
- " model_type = \"headlineclassifier\"\n",
328
- "\n",
329
- " def __init__(\n",
330
- " self,\n",
331
- " base_exp_dir=\"./exp/fox_nbc/\",\n",
332
- " # dataset={\"data_dir\": \"./data/CASE_NAME/data.csv\", \"transform\": True},\n",
333
- " train={\n",
334
- " \"learning_rate\": 2e-5,\n",
335
- " \"learning_rate_alpha\": 0.05,\n",
336
- " \"end_iter\": 10,\n",
337
- " \"batch_size\": 32,\n",
338
- " \"warm_up_end\": 2,\n",
339
- " \"anneal_end\": 5,\n",
340
- " \"save_freq\": 1,\n",
341
- " \"val_freq\": 1,\n",
342
- " },\n",
343
- " model={\n",
344
- " \"freq\": {\n",
345
- " \"tfidf_input_dim\": 8145,\n",
346
- " \"tfidf_output_dim\": 128,\n",
347
- " \"tfidf_hidden_dim\": 512,\n",
348
- " \"n_layers\": 2,\n",
349
- " \"skip_in\": [80],\n",
350
- " \"weight_norm\": True,\n",
351
- " },\n",
352
- " \"pos\": {\n",
353
- " \"input_dim\": 300,\n",
354
- " \"output_dim\": 128,\n",
355
- " \"hidden_dim\": 256,\n",
356
- " \"n_layers\": 2,\n",
357
- " \"skip_in\": [80],\n",
358
- " \"weight_norm\": True,\n",
359
- " },\n",
360
- " \"cls\": {\n",
361
- " \"combined_input\": 1024, #1024\n",
362
- " \"combined_dim\": 128,\n",
363
- " \"num_classes\": 2,\n",
364
- " \"n_layers\": 2,\n",
365
- " \"skip_in\": [80],\n",
366
- " \"weight_norm\": True,\n",
367
- " },\n",
368
- " },\n",
369
- " **kwargs,\n",
370
- " ):\n",
371
- " super().__init__(**kwargs)\n",
372
- "\n",
373
- " self.base_exp_dir = base_exp_dir\n",
374
- " # self.dataset = dataset\n",
375
- " self.train = train\n",
376
- " self.model = model\n",
377
- "\n",
378
- "# TODO: define all parameters needed for your model, as well as calling the model itself\n",
379
- "class CustomModel(PreTrainedModel):\n",
380
- " config_class = CustomConfig\n",
381
- "\n",
382
- " def __init__(self, config):\n",
383
- " super().__init__(config)\n",
384
- " self.conf = config\n",
385
- " self.freq = FreqNetwork(**self.conf.model[\"freq\"])\n",
386
- " self.pos = PosNetwork(**self.conf.model[\"pos\"])\n",
387
- " self.cls = Classifier(**self.conf.model[\"cls\"])\n",
388
- " self.fc = nn.Linear(self.conf.model[\"cls\"][\"combined_input\"],2)\n",
389
- " self.seq = RobertaModel.from_pretrained(\"roberta-base\")\n",
390
- " # self.seq = BertModel.from_pretrained(\"bert-base-uncased\")\n",
391
- " #for param in self.roberta.parameters():\n",
392
- " # param.requires_grad = False\n",
393
- " self.dropout = nn.Dropout(0.2)\n",
394
- "\n",
395
- " def forward(self, x):\n",
396
- " freq_inputs = x[\"freq_inputs\"]\n",
397
- " seq_inputs = x[\"seq_inputs\"]\n",
398
- " pos_inputs = x[\"pos_inputs\"]\n",
399
- " seq_feature = self.seq(\n",
400
- " input_ids=seq_inputs[:,0,:],\n",
401
- " attention_mask=seq_inputs[:,1,:]\n",
402
- " ).pooler_output # last_hidden_state[:, 0, :]\n",
403
- " freq_feature = self.freq(freq_inputs) # Shape: (batch_size, 128)\n",
404
- "\n",
405
- " pos_feature = self.pos(pos_inputs) #Shape: (batch_size, 128)\n",
406
- " inputs = torch.cat((seq_feature, freq_feature, pos_feature), dim=1) # Shape: (batch_size, 384)\n",
407
- " # inputs = torch.cat((seq_feature, freq_feature), dim=1) # Shape: (batch_size,256)\n",
408
- " # inputs = seq_feature\n",
409
- "\n",
410
- " x = inputs\n",
411
- " x = self.dropout(x)\n",
412
- " outputs = self.fc(x)\n",
413
- "\n",
414
- " return outputs\n",
415
- "\n",
416
- " def save_model(self, save_path):\n",
417
- " \"\"\"Save the model locally using the Hugging Face format.\"\"\"\n",
418
- " self.save_pretrained(save_path)\n",
419
- "\n",
420
- " def push_model(self, repo_name):\n",
421
- " \"\"\"Push the model to the Hugging Face Hub.\"\"\"\n",
422
- " self.push_to_hub(repo_name)"
423
- ],
424
- "id": "21f079d0c52d7d",
425
- "outputs": [],
426
- "execution_count": 16
427
- },
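As a quick check on the concatenation in `forward`, assuming roberta-base's 768-dimensional `pooler_output`, the three feature streams should add up to `combined_input`:

```python
# Sketch, not from the original notebook: 768 (RoBERTa) + 128 (freq) + 128 (pos) == 1024.
cfg = CustomConfig()
total = 768 + cfg.model["freq"]["tfidf_output_dim"] + cfg.model["pos"]["output_dim"]
assert total == cfg.model["cls"]["combined_input"], total
```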
428
- {
429
- "metadata": {
430
- "ExecuteTime": {
431
- "end_time": "2024-12-16T18:33:27.235482Z",
432
- "start_time": "2024-12-16T18:33:26.951564Z"
433
- }
434
- },
435
- "cell_type": "code",
436
- "source": [
437
- "from huggingface_hub import hf_hub_download\n",
438
- "\n",
439
- "AutoConfig.register(\"headlineclassifier\", CustomConfig)\n",
440
- "AutoModel.register(CustomConfig, CustomModel)\n",
441
- "config = CustomConfig()\n",
442
- "model = CustomModel(config)\n",
443
- "\n",
444
- "REPO_NAME = \"CISProject/News-Headline-Classifier-Notebook\" # TODO: PROVIDE A STRING TO YOUR REPO ON HUGGINGFACE"
445
- ],
446
- "id": "b6ba3f96d3ce21",
447
- "outputs": [
448
- {
449
- "name": "stderr",
450
- "output_type": "stream",
451
- "text": [
452
- "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
453
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
454
- ]
455
- }
456
- ],
457
- "execution_count": 17
458
- },
459
- {
460
- "metadata": {
461
- "ExecuteTime": {
462
- "end_time": "2024-12-16T18:33:27.279248Z",
463
- "start_time": "2024-12-16T18:33:27.261675Z"
464
- }
465
- },
466
- "cell_type": "code",
467
- "source": [
468
- "import torch\n",
469
- "from tqdm import tqdm\n",
470
- "import os\n",
471
- "\n",
472
- "\n",
473
- "class Trainer:\n",
474
- " def __init__(self, model, train_loader, val_loader, config, device=\"cuda\"):\n",
475
- " self.model = model.to(device)\n",
476
- " self.train_loader = train_loader\n",
477
- " self.val_loader = val_loader\n",
478
- " self.device = device\n",
479
- " self.conf = config\n",
480
- "\n",
481
- " self.end_iter = self.conf.train[\"end_iter\"]\n",
482
- " self.save_freq = self.conf.train[\"save_freq\"]\n",
483
- " self.val_freq = self.conf.train[\"val_freq\"]\n",
484
- "\n",
485
- " self.batch_size = self.conf.train['batch_size']\n",
486
- " self.learning_rate = self.conf.train['learning_rate']\n",
487
- " self.learning_rate_alpha = self.conf.train['learning_rate_alpha']\n",
488
- " self.warm_up_end = self.conf.train['warm_up_end']\n",
489
- " self.anneal_end = self.conf.train['anneal_end']\n",
490
- "\n",
491
- " self.optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)\n",
492
- " #self.criterion = torch.nn.BCEWithLogitsLoss()\n",
493
- " self.criterion = torch.nn.CrossEntropyLoss()\n",
494
- " self.save_path = os.path.join(self.conf.base_exp_dir, \"checkpoints\")\n",
495
- " os.makedirs(self.save_path, exist_ok=True)\n",
496
- "\n",
497
- " self.iter_step = 0\n",
498
- "\n",
499
- " self.val_loss = None\n",
500
- "\n",
501
- " def get_cos_anneal_ratio(self):\n",
502
- " if self.anneal_end == 0.0:\n",
503
- " return 1.0\n",
504
- " else:\n",
505
- " return np.min([1.0, self.iter_step / self.anneal_end])\n",
506
- "\n",
507
- " def update_learning_rate(self):\n",
508
- " if self.iter_step < self.warm_up_end:\n",
509
- " learning_factor = self.iter_step / self.warm_up_end\n",
510
- " else:\n",
511
- " alpha = self.learning_rate_alpha\n",
512
- " progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end)\n",
513
- " learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha\n",
514
- "\n",
515
- " for g in self.optimizer.param_groups:\n",
516
- " g['lr'] = self.learning_rate * learning_factor\n",
517
- "\n",
518
- " def train(self):\n",
519
- " for epoch in range(self.end_iter):\n",
520
- " self.update_learning_rate()\n",
521
- " self.model.train()\n",
522
- " epoch_loss = 0.0\n",
523
- " correct = 0\n",
524
- " total = 0\n",
525
- "\n",
526
- " for batch_inputs, labels in tqdm(self.train_loader, desc=f\"Epoch {epoch + 1}/{self.end_iter}\"):\n",
527
- " # Extract features\n",
528
- "\n",
529
- " freq_inputs = batch_inputs[\"freq_inputs\"].to(self.device)\n",
530
- " seq_inputs = batch_inputs[\"seq_inputs\"].to(self.device)\n",
531
- " pos_inputs = batch_inputs[\"pos_inputs\"].to(self.device)\n",
532
- " # y_train = labels.to(self.device)[:,None]\n",
533
- " y_train = labels.to(self.device)\n",
534
- "\n",
535
- " # Forward pass\n",
536
- " preds = self.model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
537
- " loss = self.criterion(preds, y_train)\n",
538
- "\n",
539
- " # preds = (torch.sigmoid(preds) > 0.5).int()\n",
540
- " # Backward pass\n",
541
- " self.optimizer.zero_grad()\n",
542
- " loss.backward()\n",
543
- " self.optimizer.step()\n",
544
- " _, preds = torch.max(preds, dim=1)\n",
545
- " # Metrics\n",
546
- " epoch_loss += loss.item()\n",
547
- " total += y_train.size(0)\n",
548
- " # print(preds.shape)\n",
549
- " correct += (preds == y_train).sum().item()\n",
550
- "\n",
551
- " # Log epoch metrics\n",
552
- " print(f\"Train Loss: {epoch_loss / len(self.train_loader):.4f}\")\n",
553
- " print(f\"Train Accuracy: {correct / total:.4f}\")\n",
554
- "\n",
555
- " # Validation and Save Checkpoints\n",
556
- " if (epoch + 1) % self.val_freq == 0:\n",
557
- " self.val()\n",
558
- " if (epoch + 1) % self.save_freq == 0:\n",
559
- " self.save_checkpoint(epoch + 1)\n",
560
- "\n",
561
- " # Update learning rate\n",
562
- " self.iter_step += 1\n",
563
- " self.update_learning_rate()\n",
564
- "\n",
565
- "\n",
566
- " def val(self):\n",
567
- " self.model.eval()\n",
568
- " val_loss = 0.0\n",
569
- " correct = 0\n",
570
- " total = 0\n",
571
- "\n",
572
- " with torch.no_grad():\n",
573
- " for batch_inputs, labels in tqdm(self.val_loader, desc=\"Validation\", leave=False):\n",
574
- " freq_inputs = batch_inputs[\"freq_inputs\"].to(self.device)\n",
575
- " seq_inputs = batch_inputs[\"seq_inputs\"].to(self.device)\n",
576
- " pos_inputs = batch_inputs[\"pos_inputs\"].to(self.device)\n",
577
- " y_val = labels.to(self.device)\n",
578
- "\n",
579
- " preds = self.model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
580
- " loss = self.criterion(preds, y_val)\n",
581
- " # preds = (torch.sigmoid(preds)>0.5).float()\n",
582
- " _, preds = torch.max(preds, dim=1)\n",
583
- " val_loss += loss.item()\n",
584
- " total += y_val.size(0)\n",
585
- " correct += (preds == y_val).sum().item()\n",
586
- " if self.val_loss is None or val_loss < self.val_loss:\n",
587
- " self.val_loss = val_loss\n",
588
- " self.save_checkpoint(\"best\")\n",
589
- " # Log validation metrics\n",
590
- " print(f\"Validation Loss: {val_loss / len(self.val_loader):.4f}\")\n",
591
- " print(f\"Validation Accuracy: {correct / total:.4f}\")\n",
592
- "\n",
593
- " def save_checkpoint(self, epoch):\n",
594
- " \"\"\"Save model in Hugging Face format.\"\"\"\n",
595
- " checkpoint_dir = os.path.join(self.save_path, f\"checkpoint_epoch_{epoch}\")\n",
596
- " if epoch ==\"best\":\n",
597
- " checkpoint_dir = os.path.join(self.save_path, \"best\")\n",
598
- " self.model.save_pretrained(checkpoint_dir)\n",
599
- " print(f\"Checkpoint saved at {checkpoint_dir}\")"
600
- ],
601
- "id": "7be377251b81a25d",
602
- "outputs": [],
603
- "execution_count": 18
604
- },
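To make the schedule in `update_learning_rate` concrete, here is a standalone sketch using the same hyperparameters as `CustomConfig` (`warm_up_end=2`, `end_iter=10`, `learning_rate_alpha=0.05`); it prints the multiplier applied to the base learning rate at each epoch:

```python
import numpy as np

def lr_factor(step, warm_up_end=2, end_iter=10, alpha=0.05):
    # Linear warm-up for the first warm_up_end epochs...
    if step < warm_up_end:
        return step / warm_up_end
    # ...then cosine decay from 1.0 down to the floor alpha.
    progress = (step - warm_up_end) / (end_iter - warm_up_end)
    return (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha

print([round(lr_factor(s), 3) for s in range(10)])  # roughly [0.0, 0.5, 1.0, 0.96, 0.86, ...]
```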
605
- {
606
- "metadata": {
607
- "ExecuteTime": {
608
- "end_time": "2024-12-16T18:49:49.983176Z",
609
- "start_time": "2024-12-16T18:33:27.283252Z"
610
- }
611
- },
612
- "cell_type": "code",
613
- "source": [
614
- "from torch.utils.data import DataLoader\n",
615
- "\n",
616
- "# Define a collate function to handle the batched data\n",
617
- "def collate_fn(batch):\n",
618
- " freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
619
- " seq_inputs = torch.stack([torch.tensor(item[\"seq_inputs\"]) for item in batch])\n",
620
- " pos_inputs = torch.stack([torch.tensor(item[\"pos_inputs\"]) for item in batch])\n",
621
- " labels = torch.tensor([torch.tensor(item[\"labels\"],dtype=torch.long) for item in batch])\n",
622
- " return {\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs}, labels\n",
623
- "\n",
624
- "train_loader = DataLoader(dataset_train, batch_size=config.train[\"batch_size\"], shuffle=True,collate_fn=collate_fn)\n",
625
- "test_loader = DataLoader(dataset_test, batch_size=config.train[\"batch_size\"], shuffle=False,collate_fn=collate_fn)\n",
626
- "trainer = Trainer(model, train_loader, test_loader, config)\n",
627
- "\n",
628
- "# Train the model\n",
629
- "trainer.train()\n",
630
- "# Save the final model in Hugging Face format\n",
631
- "final_save_path = os.path.join(config.base_exp_dir, \"checkpoints\")\n",
632
- "model.save_pretrained(final_save_path)\n",
633
- "print(f\"Final model saved at {final_save_path}\")\n"
634
- ],
635
- "id": "dd1749c306f148eb",
636
- "outputs": [
637
- {
638
- "name": "stderr",
639
- "output_type": "stream",
640
- "text": [
641
- "Epoch 1/10: 100%|██████████| 96/96 [02:28<00:00, 1.55s/it]\n"
642
- ]
643
- },
644
- {
645
- "name": "stdout",
646
- "output_type": "stream",
647
- "text": [
648
- "Train Loss: 0.6943\n",
649
- "Train Accuracy: 0.4947\n"
650
- ]
651
- },
652
- {
653
- "name": "stderr",
654
- "output_type": "stream",
655
- "text": [
656
- " \r"
657
- ]
658
- },
659
- {
660
- "name": "stdout",
661
- "output_type": "stream",
662
- "text": [
663
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\best\n",
664
- "Validation Loss: 0.6931\n",
665
- "Validation Accuracy: 0.4980\n",
666
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_1\n"
667
- ]
668
- },
669
- {
670
- "name": "stderr",
671
- "output_type": "stream",
672
- "text": [
673
- "Epoch 2/10: 100%|██████████| 96/96 [01:34<00:00, 1.01it/s]\n"
674
- ]
675
- },
676
- {
677
- "name": "stdout",
678
- "output_type": "stream",
679
- "text": [
680
- "Train Loss: 0.6006\n",
681
- "Train Accuracy: 0.6597\n"
682
- ]
683
- },
684
- {
685
- "name": "stderr",
686
- "output_type": "stream",
687
- "text": [
688
- " \r"
689
- ]
690
- },
691
- {
692
- "name": "stdout",
693
- "output_type": "stream",
694
- "text": [
695
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\best\n",
696
- "Validation Loss: 0.4140\n",
697
- "Validation Accuracy: 0.8252\n",
698
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_2\n"
699
- ]
700
- },
701
- {
702
- "name": "stderr",
703
- "output_type": "stream",
704
- "text": [
705
- "Epoch 3/10: 100%|██████████| 96/96 [01:31<00:00, 1.05it/s]\n"
706
- ]
707
- },
708
- {
709
- "name": "stdout",
710
- "output_type": "stream",
711
- "text": [
712
- "Train Loss: 0.3597\n",
713
- "Train Accuracy: 0.8469\n"
714
- ]
715
- },
716
- {
717
- "name": "stderr",
718
- "output_type": "stream",
719
- "text": [
720
- " \r"
721
- ]
722
- },
723
- {
724
- "name": "stdout",
725
- "output_type": "stream",
726
- "text": [
727
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\best\n",
728
- "Validation Loss: 0.3259\n",
729
- "Validation Accuracy: 0.8541\n",
730
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_3\n"
731
- ]
732
- },
733
- {
734
- "name": "stderr",
735
- "output_type": "stream",
736
- "text": [
737
- "Epoch 4/10: 100%|██████████| 96/96 [01:00<00:00, 1.58it/s]\n"
738
- ]
739
- },
740
- {
741
- "name": "stdout",
742
- "output_type": "stream",
743
- "text": [
744
- "Train Loss: 0.2143\n",
745
- "Train Accuracy: 0.9205\n"
746
- ]
747
- },
748
- {
749
- "name": "stderr",
750
- "output_type": "stream",
751
- "text": [
752
- " \r"
753
- ]
754
- },
755
- {
756
- "name": "stdout",
757
- "output_type": "stream",
758
- "text": [
759
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\best\n",
760
- "Validation Loss: 0.2619\n",
761
- "Validation Accuracy: 0.8988\n",
762
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_4\n"
763
- ]
764
- },
765
- {
766
- "name": "stderr",
767
- "output_type": "stream",
768
- "text": [
769
- "Epoch 5/10: 100%|██████████| 96/96 [01:24<00:00, 1.13it/s]\n"
770
- ]
771
- },
772
- {
773
- "name": "stdout",
774
- "output_type": "stream",
775
- "text": [
776
- "Train Loss: 0.1113\n",
777
- "Train Accuracy: 0.9573\n"
778
- ]
779
- },
780
- {
781
- "name": "stderr",
782
- "output_type": "stream",
783
- "text": [
784
- " \r"
785
- ]
786
- },
787
- {
788
- "name": "stdout",
789
- "output_type": "stream",
790
- "text": [
791
- "Validation Loss: 0.4198\n",
792
- "Validation Accuracy: 0.8555\n",
793
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_5\n"
794
- ]
795
- },
796
- {
797
- "name": "stderr",
798
- "output_type": "stream",
799
- "text": [
800
- "Epoch 6/10: 100%|██████████| 96/96 [01:01<00:00, 1.56it/s]\n"
801
- ]
802
- },
803
- {
804
- "name": "stdout",
805
- "output_type": "stream",
806
- "text": [
807
- "Train Loss: 0.0643\n",
808
- "Train Accuracy: 0.9770\n"
809
- ]
810
- },
811
- {
812
- "name": "stderr",
813
- "output_type": "stream",
814
- "text": [
815
- " \r"
816
- ]
817
- },
818
- {
819
- "name": "stdout",
820
- "output_type": "stream",
821
- "text": [
822
- "Validation Loss: 0.3937\n",
823
- "Validation Accuracy: 0.8725\n",
824
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_6\n"
825
- ]
826
- },
827
- {
828
- "name": "stderr",
829
- "output_type": "stream",
830
- "text": [
831
- "Epoch 7/10: 100%|██████████| 96/96 [01:01<00:00, 1.57it/s]\n"
832
- ]
833
- },
834
- {
835
- "name": "stdout",
836
- "output_type": "stream",
837
- "text": [
838
- "Train Loss: 0.0294\n",
839
- "Train Accuracy: 0.9915\n"
840
- ]
841
- },
842
- {
843
- "name": "stderr",
844
- "output_type": "stream",
845
- "text": [
846
- " \r"
847
- ]
848
- },
849
- {
850
- "name": "stdout",
851
- "output_type": "stream",
852
- "text": [
853
- "Validation Loss: 0.4704\n",
854
- "Validation Accuracy: 0.8725\n",
855
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_7\n"
856
- ]
857
- },
858
- {
859
- "name": "stderr",
860
- "output_type": "stream",
861
- "text": [
862
- "Epoch 8/10: 100%|██████████| 96/96 [01:01<00:00, 1.56it/s]\n"
863
- ]
864
- },
865
- {
866
- "name": "stdout",
867
- "output_type": "stream",
868
- "text": [
869
- "Train Loss: 0.0128\n",
870
- "Train Accuracy: 0.9970\n"
871
- ]
872
- },
873
- {
874
- "name": "stderr",
875
- "output_type": "stream",
876
- "text": [
877
- " \r"
878
- ]
879
- },
880
- {
881
- "name": "stdout",
882
- "output_type": "stream",
883
- "text": [
884
- "Validation Loss: 0.5717\n",
885
- "Validation Accuracy: 0.8633\n",
886
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_8\n"
887
- ]
888
- },
889
- {
890
- "name": "stderr",
891
- "output_type": "stream",
892
- "text": [
893
- "Epoch 9/10: 100%|██████████| 96/96 [01:02<00:00, 1.54it/s]\n"
894
- ]
895
- },
896
- {
897
- "name": "stdout",
898
- "output_type": "stream",
899
- "text": [
900
- "Train Loss: 0.0088\n",
901
- "Train Accuracy: 0.9970\n"
902
- ]
903
- },
904
- {
905
- "name": "stderr",
906
- "output_type": "stream",
907
- "text": [
908
- " \r"
909
- ]
910
- },
911
- {
912
- "name": "stdout",
913
- "output_type": "stream",
914
- "text": [
915
- "Validation Loss: 0.5458\n",
916
- "Validation Accuracy: 0.8739\n",
917
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_9\n"
918
- ]
919
- },
920
- {
921
- "name": "stderr",
922
- "output_type": "stream",
923
- "text": [
924
- "Epoch 10/10: 100%|██████████| 96/96 [01:06<00:00, 1.45it/s]\n"
925
- ]
926
- },
927
- {
928
- "name": "stdout",
929
- "output_type": "stream",
930
- "text": [
931
- "Train Loss: 0.0056\n",
932
- "Train Accuracy: 0.9984\n"
933
- ]
934
- },
935
- {
936
- "name": "stderr",
937
- "output_type": "stream",
938
- "text": [
939
- " \r"
940
- ]
941
- },
942
- {
943
- "name": "stdout",
944
- "output_type": "stream",
945
- "text": [
946
- "Validation Loss: 0.4930\n",
947
- "Validation Accuracy: 0.8804\n",
948
- "Checkpoint saved at ./exp/fox_nbc/checkpoints\\checkpoint_epoch_10\n",
949
- "Final model saved at ./exp/fox_nbc/checkpoints\n"
950
- ]
951
- }
952
- ],
953
- "execution_count": 19
954
- },
955
- {
956
- "metadata": {},
957
- "cell_type": "markdown",
958
- "source": "## Evaluate Model",
959
- "id": "4af000263dd99bca"
960
- },
961
- {
962
- "metadata": {
963
- "ExecuteTime": {
964
- "end_time": "2024-12-16T18:50:16.035455Z",
965
- "start_time": "2024-12-16T18:50:02.434452Z"
966
- }
967
- },
968
- "cell_type": "code",
969
- "source": [
970
- "from transformers import AutoConfig, AutoModel\n",
971
- "from sklearn.metrics import accuracy_score, classification_report\n",
972
- "def load_last_checkpoint(checkpoint_dir):\n",
973
- " # Find all checkpoints in the directory\n",
974
- " checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith(\"checkpoint_epoch_\")]\n",
975
- " if not checkpoints:\n",
976
- " raise FileNotFoundError(f\"No checkpoints found in {checkpoint_dir}!\")\n",
977
- " # Sort checkpoints by epoch number\n",
978
- " checkpoints.sort(key=lambda x: int(x.split(\"_\")[-1]))\n",
979
- "\n",
980
- " # Load the last checkpoint\n",
981
- " last_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])\n",
982
- " # print(f\"Loading checkpoint from {last_checkpoint}\")\n",
983
- " # Load the best checkpoint\n",
984
- " if os.path.join(checkpoint_dir, \"best\") is not None:\n",
985
- " last_checkpoint = os.path.join(checkpoint_dir, \"best\")\n",
986
- " print(f\"Loading checkpoint from {last_checkpoint}\")\n",
987
- " # Load model and config\n",
988
- " config = AutoConfig.from_pretrained(last_checkpoint)\n",
989
- " model = AutoModel.from_pretrained(last_checkpoint, config=config)\n",
990
- " return model\n",
991
- "\n",
992
- "# Step 1: Define paths and setup\n",
993
- "checkpoint_dir = os.path.join(config.base_exp_dir, \"checkpoints\") # Directory where checkpoints are stored\n",
994
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
995
- "model = load_last_checkpoint(checkpoint_dir)\n",
996
- "model.to(device)\n",
997
- "\n",
998
- "# criterion = torch.nn.BCEWithLogitsLoss()\n",
999
- "\n",
1000
- "criterion = torch.nn.CrossEntropyLoss()\n",
1001
- "\n",
1002
- "def evaluate_model(model, val_loader, criterion, device=\"cuda\"):\n",
1003
- " model.eval()\n",
1004
- " val_loss = 0.0\n",
1005
- " correct = 0\n",
1006
- " total = 0\n",
1007
- " all_preds = []\n",
1008
- " all_labels = []\n",
1009
- " with torch.no_grad():\n",
1010
- " for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
1011
- " freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
1012
- " seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
1013
- " pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
1014
- " labels = labels.to(device)\n",
1015
- "\n",
1016
- " preds= model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
1017
- " loss = criterion(preds, labels)\n",
1018
- " _, preds = torch.max(preds, dim=1)\n",
1019
- " # preds = (torch.sigmoid(preds) > 0.5).float()\n",
1020
- " val_loss += loss.item()\n",
1021
- " total += labels.size(0)\n",
1022
- " # preds = (torch.sigmoid(preds) > 0.5).int()\n",
1023
- " correct += (preds == labels).sum().item()\n",
1024
- " all_preds.extend(preds.cpu().numpy())\n",
1025
- " all_labels.extend(labels.cpu().numpy())\n",
1026
- "\n",
1027
- " return accuracy_score(all_labels, all_preds), classification_report(all_labels, all_preds)\n",
1028
- "\n",
1029
- "\n",
1030
- "accuracy, report = evaluate_model(model, test_loader, criterion)\n",
1031
- "print(f\"Accuracy: {accuracy:.4f}\")\n",
1032
- "print(report)\n"
1033
- ],
1034
- "id": "b75d2dc8a300cdf6",
1035
- "outputs": [
1036
- {
1037
- "name": "stderr",
1038
- "output_type": "stream",
1039
- "text": [
1040
- "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
1041
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1042
- ]
1043
- },
1044
- {
1045
- "name": "stdout",
1046
- "output_type": "stream",
1047
- "text": [
1048
- "Loading checkpoint from ./exp/fox_nbc/checkpoints\\best\n"
1049
- ]
1050
- },
1051
- {
1052
- "name": "stderr",
1053
- "output_type": "stream",
1054
- "text": [
1055
- "Some weights of the model checkpoint at ./exp/fox_nbc/checkpoints\\best were not used when initializing CustomModel: ['cls.lin0.parametrizations.weight.original0', 'cls.lin0.parametrizations.weight.original1', 'cls.lin1.parametrizations.weight.original0', 'cls.lin1.parametrizations.weight.original1', 'cls.lin2.parametrizations.weight.original0', 'cls.lin2.parametrizations.weight.original1', 'freq.lin0.parametrizations.weight.original0', 'freq.lin0.parametrizations.weight.original1', 'freq.lin1.parametrizations.weight.original0', 'freq.lin1.parametrizations.weight.original1', 'freq.lin2.parametrizations.weight.original0', 'freq.lin2.parametrizations.weight.original1', 'pos.lin0.parametrizations.weight.original0', 'pos.lin0.parametrizations.weight.original1', 'pos.lin1.parametrizations.weight.original0', 'pos.lin1.parametrizations.weight.original1', 'pos.lin2.parametrizations.weight.original0', 'pos.lin2.parametrizations.weight.original1']\n",
1056
- "- This IS expected if you are initializing CustomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1057
- "- This IS NOT expected if you are initializing CustomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1058
- "Some weights of CustomModel were not initialized from the model checkpoint at ./exp/fox_nbc/checkpoints\\best and are newly initialized: ['cls.lin0.weight_g', 'cls.lin0.weight_v', 'cls.lin1.weight_g', 'cls.lin1.weight_v', 'cls.lin2.weight_g', 'cls.lin2.weight_v', 'freq.lin0.weight_g', 'freq.lin0.weight_v', 'freq.lin1.weight_g', 'freq.lin1.weight_v', 'freq.lin2.weight_g', 'freq.lin2.weight_v', 'pos.lin0.weight_g', 'pos.lin0.weight_v', 'pos.lin1.weight_g', 'pos.lin1.weight_v', 'pos.lin2.weight_g', 'pos.lin2.weight_v']\n",
1059
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
1060
- " "
1061
- ]
1062
- },
1063
- {
1064
- "name": "stdout",
1065
- "output_type": "stream",
1066
- "text": [
1067
- "Accuracy: 0.8988\n",
1068
- " precision recall f1-score support\n",
1069
- "\n",
1070
- " 0 0.90 0.88 0.89 356\n",
1071
- " 1 0.90 0.91 0.91 405\n",
1072
- "\n",
1073
- " accuracy 0.90 761\n",
1074
- " macro avg 0.90 0.90 0.90 761\n",
1075
- "weighted avg 0.90 0.90 0.90 761\n",
1076
- "\n"
1077
- ]
1078
- },
1079
- {
1080
- "name": "stderr",
1081
- "output_type": "stream",
1082
- "text": [
1083
- "\r"
1084
- ]
1085
- }
1086
- ],
1087
- "execution_count": 21
1088
- },
1089
- {
1090
- "metadata": {},
1091
- "cell_type": "markdown",
1092
- "source": "# Part 3. Pushing the Model to the Hugging Face",
1093
- "id": "d2ffeb383ea00beb"
1094
- },
1095
- {
1096
- "metadata": {
1097
- "ExecuteTime": {
1098
- "end_time": "2024-12-16T18:50:47.965853Z",
1099
- "start_time": "2024-12-16T18:50:23.635567Z"
1100
- }
1101
- },
1102
- "cell_type": "code",
1103
- "source": "model.push_model(REPO_NAME)",
1104
- "id": "f55c22b0a1b2a66b",
1105
- "outputs": [
1106
- {
1107
- "data": {
1108
- "text/plain": [
1109
- "README.md: 0%| | 0.00/839 [00:00<?, ?B/s]"
1110
- ],
1111
- "application/vnd.jupyter.widget-view+json": {
1112
- "version_major": 2,
1113
- "version_minor": 0,
1114
- "model_id": "3258d736d65a4c36b524011271415c56"
1115
- }
1116
- },
1117
- "metadata": {},
1118
- "output_type": "display_data"
1119
- },
1120
- {
1121
- "name": "stderr",
1122
- "output_type": "stream",
1123
- "text": [
1124
- "C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\huggingface_hub\\file_download.py:139: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\swall\\.cache\\huggingface\\hub\\models--CISProject--News-Headline-Classifier-Notebook. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
1125
- "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
1126
- " warnings.warn(message)\n",
1127
- "Repo card metadata block was not found. Setting CardData to empty.\n"
1128
- ]
1129
- },
1130
- {
1131
- "data": {
1132
- "text/plain": [
1133
- "model.safetensors: 0%| | 0.00/518M [00:00<?, ?B/s]"
1134
- ],
1135
- "application/vnd.jupyter.widget-view+json": {
1136
- "version_major": 2,
1137
- "version_minor": 0,
1138
- "model_id": "bf9fd6651886433489d5059f9a83b831"
1139
- }
1140
- },
1141
- "metadata": {},
1142
- "output_type": "display_data"
1143
- }
1144
- ],
1145
- "execution_count": 22
1146
- },
1147
- {
1148
- "metadata": {},
1149
- "cell_type": "markdown",
1150
- "source": "### NOTE: You need to ensure that your Hugging Face token has both read and write access to your repository and Hugging Face organization.",
1151
- "id": "3826c0b6195a8fd5"
1152
- },
1153
- {
1154
- "metadata": {
1155
- "ExecuteTime": {
1156
- "end_time": "2024-12-16T18:51:38.723144Z",
1157
- "start_time": "2024-12-16T18:51:24.496422Z"
1158
- }
1159
- },
1160
- "cell_type": "code",
1161
- "source": [
1162
- "# Load model directly\n",
1163
- "from transformers import AutoModel, AutoConfig\n",
1164
- "config = AutoConfig.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\")\n",
1165
- "model = AutoModel.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\",config = config)"
1166
- ],
1167
- "id": "33a0ca269c24d700",
1168
- "outputs": [
1169
- {
1170
- "data": {
1171
- "text/plain": [
1172
- "config.json: 0%| | 0.00/1.08k [00:00<?, ?B/s]"
1173
- ],
1174
- "application/vnd.jupyter.widget-view+json": {
1175
- "version_major": 2,
1176
- "version_minor": 0,
1177
- "model_id": "ee3167049b5942acacc9eaab7cbb0a35"
1178
- }
1179
- },
1180
- "metadata": {},
1181
- "output_type": "display_data"
1182
- },
1183
- {
1184
- "data": {
1185
- "text/plain": [
1186
- "model.safetensors: 0%| | 0.00/518M [00:00<?, ?B/s]"
1187
- ],
1188
- "application/vnd.jupyter.widget-view+json": {
1189
- "version_major": 2,
1190
- "version_minor": 0,
1191
- "model_id": "456b7f100f9342c49fd9f08d2b24e1d8"
1192
- }
1193
- },
1194
- "metadata": {},
1195
- "output_type": "display_data"
1196
- },
1197
- {
1198
- "name": "stderr",
1199
- "output_type": "stream",
1200
- "text": [
1201
- "C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
1202
- " WeightNorm.apply(module, name, dim)\n",
1203
- "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
1204
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
1205
- "Some weights of the model checkpoint at CISProject/News-Headline-Classifier-Notebook were not used when initializing CustomModel: ['cls.lin0.parametrizations.weight.original0', 'cls.lin0.parametrizations.weight.original1', 'cls.lin1.parametrizations.weight.original0', 'cls.lin1.parametrizations.weight.original1', 'cls.lin2.parametrizations.weight.original0', 'cls.lin2.parametrizations.weight.original1', 'freq.lin0.parametrizations.weight.original0', 'freq.lin0.parametrizations.weight.original1', 'freq.lin1.parametrizations.weight.original0', 'freq.lin1.parametrizations.weight.original1', 'freq.lin2.parametrizations.weight.original0', 'freq.lin2.parametrizations.weight.original1', 'pos.lin0.parametrizations.weight.original0', 'pos.lin0.parametrizations.weight.original1', 'pos.lin1.parametrizations.weight.original0', 'pos.lin1.parametrizations.weight.original1', 'pos.lin2.parametrizations.weight.original0', 'pos.lin2.parametrizations.weight.original1']\n",
1206
- "- This IS expected if you are initializing CustomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1207
- "- This IS NOT expected if you are initializing CustomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1208
- "Some weights of CustomModel were not initialized from the model checkpoint at CISProject/News-Headline-Classifier-Notebook and are newly initialized: ['cls.lin0.weight_g', 'cls.lin0.weight_v', 'cls.lin1.weight_g', 'cls.lin1.weight_v', 'cls.lin2.weight_g', 'cls.lin2.weight_v', 'freq.lin0.weight_g', 'freq.lin0.weight_v', 'freq.lin1.weight_g', 'freq.lin1.weight_v', 'freq.lin2.weight_g', 'freq.lin2.weight_v', 'pos.lin0.weight_g', 'pos.lin0.weight_v', 'pos.lin1.weight_g', 'pos.lin1.weight_v', 'pos.lin2.weight_g', 'pos.lin2.weight_v']\n",
1209
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1210
- ]
1211
- }
1212
- ],
1213
- "execution_count": 23
1214
- },
1215
- {
1216
- "metadata": {
1217
- "ExecuteTime": {
1218
- "end_time": "2024-12-16T18:51:53.997442Z",
1219
- "start_time": "2024-12-16T18:51:40.978026Z"
1220
- }
1221
- },
1222
- "cell_type": "code",
1223
- "source": [
1224
- "from transformers import AutoConfig, AutoModel\n",
1225
- "from sklearn.metrics import accuracy_score, classification_report\n",
1226
- "\n",
1227
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1228
- "model.to(device)\n",
1229
- "\n",
1230
- "#criterion = torch.nn.BCEWithLogitsLoss()\n",
1231
- "\n",
1232
- "criterion = torch.nn.CrossEntropyLoss()\n",
1233
- "def evaluate_model(model, val_loader, criterion, device=\"cuda\"):\n",
1234
- " model.eval()\n",
1235
- " val_loss = 0.0\n",
1236
- " correct = 0\n",
1237
- " total = 0\n",
1238
- " all_preds = []\n",
1239
- " all_labels = []\n",
1240
- " with torch.no_grad():\n",
1241
- " for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
1242
- " freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
1243
- " seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
1244
- " pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
1245
- " labels = labels.to(device)\n",
1246
- "\n",
1247
- " preds = model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
1248
- " loss = criterion(preds, labels)\n",
1249
- " _, preds = torch.max(preds, dim=1)\n",
1250
- " # preds = (torch.sigmoid(preds) > 0.5).float()\n",
1251
- " val_loss += loss.item()\n",
1252
- " total += labels.size(0)\n",
1253
- " correct += (preds == labels).sum().item()\n",
1254
- " all_preds.extend(preds.cpu().numpy())\n",
1255
- " all_labels.extend(labels.cpu().numpy())\n",
1256
- "\n",
1257
- " return accuracy_score(all_labels, all_preds), classification_report(all_labels, all_preds)\n",
1258
- "\n",
1259
- "\n",
1260
- "accuracy, report = evaluate_model(model, test_loader, criterion)\n",
1261
- "print(f\"Accuracy: {accuracy:.4f}\")\n",
1262
- "print(report)\n"
1263
- ],
1264
- "id": "cc313b4396f87690",
1265
- "outputs": [
1266
- {
1267
- "name": "stderr",
1268
- "output_type": "stream",
1269
- "text": [
1270
- " "
1271
- ]
1272
- },
1273
- {
1274
- "name": "stdout",
1275
- "output_type": "stream",
1276
- "text": [
1277
- "Accuracy: 0.8988\n",
1278
- " precision recall f1-score support\n",
1279
- "\n",
1280
- " 0 0.90 0.88 0.89 356\n",
1281
- " 1 0.90 0.91 0.91 405\n",
1282
- "\n",
1283
- " accuracy 0.90 761\n",
1284
- " macro avg 0.90 0.90 0.90 761\n",
1285
- "weighted avg 0.90 0.90 0.90 761\n",
1286
- "\n"
1287
- ]
1288
- },
1289
- {
1290
- "name": "stderr",
1291
- "output_type": "stream",
1292
- "text": [
1293
- "\r"
1294
- ]
1295
- }
1296
- ],
1297
- "execution_count": 24
1298
- }
1299
- ],
1300
- "metadata": {
1301
- "kernelspec": {
1302
- "name": "python3",
1303
- "language": "python",
1304
- "display_name": "Python 3 (ipykernel)"
1305
- }
1306
- },
1307
- "nbformat": 5,
1308
- "nbformat_minor": 9
1309
- }