TUEN-YUE committed
Commit ef1c2c2 · verified · 1 Parent(s): ec11c59

Upload eval_py.ipynb

Files changed (1)
  1. eval_py.ipynb +424 -46
eval_py.ipynb CHANGED
@@ -8,9 +8,14 @@
8
  "base_uri": "https://localhost:8080/"
9
  },
10
  "id": "initial_id",
11
- "outputId": "85dee483-9370-4e16-ecb1-1c2d545ee1fb"
12
  },
13
  "source": [
14
  "!pip install geopy > delete.txt\n",
15
  "!pip install datasets > delete.txt\n",
16
  "!pip install torch torchvision datasets > delete.txt\n",
@@ -18,10 +23,21 @@
18
  "!pip install pyhocon > delete.txt\n",
19
  "!pip install transformers > delete.txt\n",
20
  "!pip install gensim > delete.txt\n",
21
- "!rm delete.txt"
22
  ],
23
- "outputs": [],
24
- "execution_count": null
25
  },
26
  {
27
  "metadata": {
@@ -29,15 +45,17 @@
29
  "base_uri": "https://localhost:8080/"
30
  },
31
  "id": "b0a77c981c32a0c8",
32
- "outputId": "fe03df52-1418-4034-8124-3bf9030ed5d7"
33
  },
34
  "cell_type": "code",
35
- "source": [
36
- "!huggingface-cli login"
37
- ],
38
  "id": "b0a77c981c32a0c8",
39
  "outputs": [],
40
- "execution_count": null
41
  },
42
  {
43
  "metadata": {
@@ -105,8 +123,8 @@
105
  "id": "a4aa3b759defc904",
106
  "outputId": "b1868c23-e675-41db-aa26-5eed9de60d9f",
107
  "ExecuteTime": {
108
- "end_time": "2024-12-16T08:26:09.513376Z",
109
- "start_time": "2024-12-16T08:26:05.978557Z"
110
  }
111
  },
112
  "cell_type": "code",
@@ -120,7 +138,7 @@
120
  ],
121
  "id": "a4aa3b759defc904",
122
  "outputs": [],
123
- "execution_count": 1
124
  },
125
  {
126
  "metadata": {
@@ -144,8 +162,8 @@
144
  "id": "ce6e6b982e22e9fe",
145
  "outputId": "f38ef6b3-35ac-41dc-a8ae-f0dd28b1f84d",
146
  "ExecuteTime": {
147
- "end_time": "2024-12-16T08:26:54.306779Z",
148
- "start_time": "2024-12-16T08:26:54.298397Z"
149
  }
150
  },
151
  "cell_type": "code",
@@ -362,8 +380,8 @@
362
  "id": "b605d3b4f5ff547a",
363
  "outputId": "f365a98e-c181-4754-9fac-77aa1e8639db",
364
  "ExecuteTime": {
365
- "end_time": "2024-12-16T08:27:16.788714Z",
366
- "start_time": "2024-12-16T08:27:01.757035Z"
367
  }
368
  },
369
  "cell_type": "code",
@@ -396,10 +414,331 @@
396
  "text": [
397
  "vectorizer fitted on training data.\n"
398
  ]
399
  }
400
  ],
401
  "execution_count": 5
402
  },
403
  {
404
  "metadata": {
405
  "colab": {
@@ -422,8 +761,8 @@
422
  "id": "b20d11caa1d25445",
423
  "outputId": "986c82fd-014b-432a-8174-857b2b866cb8",
424
  "ExecuteTime": {
425
- "end_time": "2024-12-16T08:27:32.874705Z",
426
- "start_time": "2024-12-16T08:27:32.787248Z"
427
  }
428
  },
429
  "cell_type": "code",
@@ -435,33 +774,51 @@
435
  "id": "b20d11caa1d25445",
436
  "outputs": [
437
  {
438
- "ename": "ValueError",
439
- "evalue": "The checkpoint you are trying to load has model type `headlineclassifier` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.",
440
- "output_type": "error",
441
- "traceback": [
442
- "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
443
- "\u001B[1;31mKeyError\u001B[0m Traceback (most recent call last)",
444
- "File \u001B[1;32m~\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\transformers\\models\\auto\\configuration_auto.py:1038\u001B[0m, in \u001B[0;36mAutoConfig.from_pretrained\u001B[1;34m(cls, pretrained_model_name_or_path, **kwargs)\u001B[0m\n\u001B[0;32m 1037\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m-> 1038\u001B[0m config_class \u001B[38;5;241m=\u001B[39m CONFIG_MAPPING[config_dict[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmodel_type\u001B[39m\u001B[38;5;124m\"\u001B[39m]]\n\u001B[0;32m 1039\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n",
445
- "File \u001B[1;32m~\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\transformers\\models\\auto\\configuration_auto.py:740\u001B[0m, in \u001B[0;36m_LazyConfigMapping.__getitem__\u001B[1;34m(self, key)\u001B[0m\n\u001B[0;32m 739\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m key \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_mapping:\n\u001B[1;32m--> 740\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m(key)\n\u001B[0;32m 741\u001B[0m value \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_mapping[key]\n",
446
- "\u001B[1;31mKeyError\u001B[0m: 'headlineclassifier'",
447
- "\nDuring handling of the above exception, another exception occurred:\n",
448
- "\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)",
449
- "Cell \u001B[1;32mIn[15], line 2\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtransformers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m AutoModel, AutoConfig\n\u001B[1;32m----> 2\u001B[0m config \u001B[38;5;241m=\u001B[39m AutoConfig\u001B[38;5;241m.\u001B[39mfrom_pretrained(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mCISProject/News-Headline-Classifier-Notebook\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 3\u001B[0m model \u001B[38;5;241m=\u001B[39m AutoModel\u001B[38;5;241m.\u001B[39mfrom_pretrained(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mCISProject/News-Headline-Classifier-Notebook\u001B[39m\u001B[38;5;124m\"\u001B[39m,config \u001B[38;5;241m=\u001B[39m config)\n",
450
- "File \u001B[1;32m~\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\transformers\\models\\auto\\configuration_auto.py:1040\u001B[0m, in \u001B[0;36mAutoConfig.from_pretrained\u001B[1;34m(cls, pretrained_model_name_or_path, **kwargs)\u001B[0m\n\u001B[0;32m 1038\u001B[0m config_class \u001B[38;5;241m=\u001B[39m CONFIG_MAPPING[config_dict[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmodel_type\u001B[39m\u001B[38;5;124m\"\u001B[39m]]\n\u001B[0;32m 1039\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n\u001B[1;32m-> 1040\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m 1041\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mThe checkpoint you are trying to load has model type `\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mconfig_dict[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mmodel_type\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m` \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 1042\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mbut Transformers does not recognize this architecture. This could be because of an \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 1043\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124missue with the checkpoint, or because your version of Transformers is out of date.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 1044\u001B[0m )\n\u001B[0;32m 1045\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m config_class\u001B[38;5;241m.\u001B[39mfrom_dict(config_dict, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39munused_kwargs)\n\u001B[0;32m 1046\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 1047\u001B[0m \u001B[38;5;66;03m# Fallback: use pattern matching on the string.\u001B[39;00m\n\u001B[0;32m 1048\u001B[0m \u001B[38;5;66;03m# We go from longer names to shorter names to catch roberta before bert (for instance)\u001B[39;00m\n",
451
- "\u001B[1;31mValueError\u001B[0m: The checkpoint you are trying to load has model type `headlineclassifier` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date."
452
  ]
453
  }
454
  ],
455
- "execution_count": 15
456
  },
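The ValueError above is the problem this commit addresses: `AutoConfig.from_pretrained` cannot resolve the custom `headlineclassifier` model type until the matching config and model classes are registered with Transformers. The new cells later in this diff add that registration; a minimal sketch of the pattern, using the class names defined there:

```python
# Minimal sketch, assuming CustomConfig / CustomModel as defined in the updated cells:
# register the custom architecture before asking the Auto* classes to resolve it.
from transformers import AutoConfig, AutoModel

AutoConfig.register("headlineclassifier", CustomConfig)  # maps model_type -> config class
AutoModel.register(CustomConfig, CustomModel)            # maps config class -> model class

config = AutoConfig.from_pretrained("CISProject/News-Headline-Classifier-Notebook")
model = AutoModel.from_pretrained("CISProject/News-Headline-Classifier-Notebook", config=config)
```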
457
  {
458
  "metadata": {
459
- "id": "1d23cedfe1d79660"
460
  },
461
  "cell_type": "code",
462
  "source": [
463
  "from torch.utils.data import DataLoader\n",
464
  "from sklearn.metrics import accuracy_score, classification_report\n",
465
  "# Define a collate function to handle the batched data\n",
466
  "def collate_fn(batch):\n",
467
  " freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
@@ -510,18 +867,39 @@
510
  "print(report)"
511
  ],
512
  "id": "1d23cedfe1d79660",
513
- "outputs": [],
514
- "execution_count": null
515
- },
516
- {
517
- "metadata": {
518
- "id": "549f3e0a004e80ab"
519
- },
520
- "cell_type": "code",
521
- "source": [],
522
- "id": "549f3e0a004e80ab",
523
- "outputs": [],
524
- "execution_count": null
525
  }
526
  ],
527
  "metadata": {
 
8
  "base_uri": "https://localhost:8080/"
9
  },
10
  "id": "initial_id",
11
+ "outputId": "85dee483-9370-4e16-ecb1-1c2d545ee1fb",
12
+ "ExecuteTime": {
13
+ "end_time": "2024-12-24T16:26:10.771980Z",
14
+ "start_time": "2024-12-24T16:26:10.767703Z"
15
+ }
16
  },
17
  "source": [
18
+ "\n",
19
  "!pip install geopy > delete.txt\n",
20
  "!pip install datasets > delete.txt\n",
21
  "!pip install torch torchvision datasets > delete.txt\n",
 
23
  "!pip install pyhocon > delete.txt\n",
24
  "!pip install transformers > delete.txt\n",
25
  "!pip install gensim > delete.txt\n",
26
+ "!rm delete.txt\n"
27
  ],
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "text/plain": [
32
+ "'\\n!pip install geopy > delete.txt\\n!pip install datasets > delete.txt\\n!pip install torch torchvision datasets > delete.txt\\n!pip install huggingface_hub > delete.txt\\n!pip install pyhocon > delete.txt\\n!pip install transformers > delete.txt\\n!pip install gensim > delete.txt\\n!rm delete.txt\\n'"
33
+ ]
34
+ },
35
+ "execution_count": 1,
36
+ "metadata": {},
37
+ "output_type": "execute_result"
38
+ }
39
+ ],
40
+ "execution_count": 1
41
  },
42
  {
43
  "metadata": {
 
45
  "base_uri": "https://localhost:8080/"
46
  },
47
  "id": "b0a77c981c32a0c8",
48
+ "outputId": "fe03df52-1418-4034-8124-3bf9030ed5d7",
49
+ "ExecuteTime": {
50
+ "end_time": "2024-12-24T16:26:10.813215Z",
51
+ "start_time": "2024-12-24T16:26:10.810118Z"
52
+ }
53
  },
54
  "cell_type": "code",
55
+ "source": "!huggingface-cli login",
56
  "id": "b0a77c981c32a0c8",
57
  "outputs": [],
58
+ "execution_count": 2
59
  },
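The login cell shells out to `huggingface-cli login`. As an aside, and not what this notebook does, the same credential step can be handled in-notebook with `huggingface_hub`, which avoids the shell prompt:

```python
# Alternative sketch (not used by this notebook): interactive login via huggingface_hub.
from huggingface_hub import notebook_login

notebook_login()  # prompts for a Hugging Face access token inside the notebook
```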
60
  {
61
  "metadata": {
 
123
  "id": "a4aa3b759defc904",
124
  "outputId": "b1868c23-e675-41db-aa26-5eed9de60d9f",
125
  "ExecuteTime": {
126
+ "end_time": "2024-12-24T16:26:14.452521Z",
127
+ "start_time": "2024-12-24T16:26:10.866207Z"
128
  }
129
  },
130
  "cell_type": "code",
 
138
  ],
139
  "id": "a4aa3b759defc904",
140
  "outputs": [],
141
+ "execution_count": 3
142
  },
143
  {
144
  "metadata": {
 
162
  "id": "ce6e6b982e22e9fe",
163
  "outputId": "f38ef6b3-35ac-41dc-a8ae-f0dd28b1f84d",
164
  "ExecuteTime": {
165
+ "end_time": "2024-12-24T16:26:16.191425Z",
166
+ "start_time": "2024-12-24T16:26:14.463529Z"
167
  }
168
  },
169
  "cell_type": "code",
 
380
  "id": "b605d3b4f5ff547a",
381
  "outputId": "f365a98e-c181-4754-9fac-77aa1e8639db",
382
  "ExecuteTime": {
383
+ "end_time": "2024-12-24T16:27:18.269951Z",
384
+ "start_time": "2024-12-24T16:26:16.194928Z"
385
  }
386
  },
387
  "cell_type": "code",
 
414
  "text": [
415
  "vectorizer fitted on training data.\n"
416
  ]
417
+ },
418
+ {
419
+ "data": {
420
+ "text/plain": [
421
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
422
+ ],
423
+ "application/vnd.jupyter.widget-view+json": {
424
+ "version_major": 2,
425
+ "version_minor": 0,
426
+ "model_id": "0aec436603c54b82b962cb31750a5921"
427
+ }
428
+ },
429
+ "metadata": {},
430
+ "output_type": "display_data"
431
+ },
432
+ {
433
+ "data": {
434
+ "text/plain": [
435
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
436
+ ],
437
+ "application/vnd.jupyter.widget-view+json": {
438
+ "version_major": 2,
439
+ "version_minor": 0,
440
+ "model_id": "2061954a6f7e47e08803c4c54e852571"
441
+ }
442
+ },
443
+ "metadata": {},
444
+ "output_type": "display_data"
445
+ },
446
+ {
447
+ "data": {
448
+ "text/plain": [
449
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
450
+ ],
451
+ "application/vnd.jupyter.widget-view+json": {
452
+ "version_major": 2,
453
+ "version_minor": 0,
454
+ "model_id": "a3e3797234bb478c88fcf944fafc8570"
455
+ }
456
+ },
457
+ "metadata": {},
458
+ "output_type": "display_data"
459
+ },
460
+ {
461
+ "data": {
462
+ "text/plain": [
463
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
464
+ ],
465
+ "application/vnd.jupyter.widget-view+json": {
466
+ "version_major": 2,
467
+ "version_minor": 0,
468
+ "model_id": "2c8bc25d53984eee839a20da98f749d1"
469
+ }
470
+ },
471
+ "metadata": {},
472
+ "output_type": "display_data"
473
+ },
474
+ {
475
+ "data": {
476
+ "text/plain": [
477
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
478
+ ],
479
+ "application/vnd.jupyter.widget-view+json": {
480
+ "version_major": 2,
481
+ "version_minor": 0,
482
+ "model_id": "ab84a6ca73a945d58cc23643538c8a6c"
483
+ }
484
+ },
485
+ "metadata": {},
486
+ "output_type": "display_data"
487
+ },
488
+ {
489
+ "data": {
490
+ "text/plain": [
491
+ "Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
492
+ ],
493
+ "application/vnd.jupyter.widget-view+json": {
494
+ "version_major": 2,
495
+ "version_minor": 0,
496
+ "model_id": "4ebc99d9190c49bfae44fb7f9db88007"
497
+ }
498
+ },
499
+ "metadata": {},
500
+ "output_type": "display_data"
501
+ },
502
+ {
503
+ "data": {
504
+ "text/plain": [
505
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
506
+ ],
507
+ "application/vnd.jupyter.widget-view+json": {
508
+ "version_major": 2,
509
+ "version_minor": 0,
510
+ "model_id": "b1a4da224f31484fa8946982954e74ff"
511
+ }
512
+ },
513
+ "metadata": {},
514
+ "output_type": "display_data"
515
+ },
516
+ {
517
+ "data": {
518
+ "text/plain": [
519
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
520
+ ],
521
+ "application/vnd.jupyter.widget-view+json": {
522
+ "version_major": 2,
523
+ "version_minor": 0,
524
+ "model_id": "ab6cbe1f1b714017809ef32d010e452f"
525
+ }
526
+ },
527
+ "metadata": {},
528
+ "output_type": "display_data"
529
+ },
530
+ {
531
+ "data": {
532
+ "text/plain": [
533
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
534
+ ],
535
+ "application/vnd.jupyter.widget-view+json": {
536
+ "version_major": 2,
537
+ "version_minor": 0,
538
+ "model_id": "fcd18d89aa5043b39fe0143d4a5ac681"
539
+ }
540
+ },
541
+ "metadata": {},
542
+ "output_type": "display_data"
543
+ },
544
+ {
545
+ "data": {
546
+ "text/plain": [
547
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
548
+ ],
549
+ "application/vnd.jupyter.widget-view+json": {
550
+ "version_major": 2,
551
+ "version_minor": 0,
552
+ "model_id": "d1b2b18cedc94f338d4f5bf5dc4a5dec"
553
+ }
554
+ },
555
+ "metadata": {},
556
+ "output_type": "display_data"
557
+ },
558
+ {
559
+ "data": {
560
+ "text/plain": [
561
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
562
+ ],
563
+ "application/vnd.jupyter.widget-view+json": {
564
+ "version_major": 2,
565
+ "version_minor": 0,
566
+ "model_id": "03f17c1c68e248a6bf71a6bd7d7bbae3"
567
+ }
568
+ },
569
+ "metadata": {},
570
+ "output_type": "display_data"
571
+ },
572
+ {
573
+ "data": {
574
+ "text/plain": [
575
+ "Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
576
+ ],
577
+ "application/vnd.jupyter.widget-view+json": {
578
+ "version_major": 2,
579
+ "version_minor": 0,
580
+ "model_id": "1e4135f4a3974b149f50dd3dba35c02e"
581
+ }
582
+ },
583
+ "metadata": {},
584
+ "output_type": "display_data"
585
  }
586
  ],
587
  "execution_count": 5
588
  },
589
+ {
590
+ "metadata": {
591
+ "ExecuteTime": {
592
+ "end_time": "2024-12-24T16:28:32.064840Z",
593
+ "start_time": "2024-12-24T16:28:31.661013Z"
594
+ }
595
+ },
596
+ "cell_type": "code",
597
+ "source": [
598
+ "# TODO: import all packages necessary for your custom model\n",
599
+ "import pandas as pd\n",
600
+ "import os\n",
601
+ "from torch.utils.data import DataLoader\n",
602
+ "from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel\n",
603
+ "import torch\n",
604
+ "import torch.nn as nn\n",
605
+ "from transformers import RobertaModel, RobertaConfig,RobertaForSequenceClassification, BertModel\n",
606
+ "from model.network import Classifier\n",
607
+ "from model.frequential import FreqNetwork\n",
608
+ "from model.sequential import SeqNetwork\n",
609
+ "from model.positional import PosNetwork\n",
610
+ "\n",
611
+ "class CustomConfig(PretrainedConfig):\n",
612
+ " model_type = \"headlineclassifier\"\n",
613
+ "\n",
614
+ " def __init__(\n",
615
+ " self,\n",
616
+ " base_exp_dir=\"./exp/fox_nbc/\",\n",
617
+ " # dataset={\"data_dir\": \"./data/CASE_NAME/data.csv\", \"transform\": True},\n",
618
+ " train={\n",
619
+ " \"learning_rate\": 2e-5,\n",
620
+ " \"learning_rate_alpha\": 0.05,\n",
621
+ " \"end_iter\": 10,\n",
622
+ " \"batch_size\": 32,\n",
623
+ " \"warm_up_end\": 2,\n",
624
+ " \"anneal_end\": 5,\n",
625
+ " \"save_freq\": 1,\n",
626
+ " \"val_freq\": 1,\n",
627
+ " },\n",
628
+ " model={\n",
629
+ " \"freq\": {\n",
630
+ " \"tfidf_input_dim\": 8145,\n",
631
+ " \"tfidf_output_dim\": 128,\n",
632
+ " \"tfidf_hidden_dim\": 512,\n",
633
+ " \"n_layers\": 2,\n",
634
+ " \"skip_in\": [80],\n",
635
+ " \"weight_norm\": True,\n",
636
+ " },\n",
637
+ " \"pos\": {\n",
638
+ " \"input_dim\": 300,\n",
639
+ " \"output_dim\": 128,\n",
640
+ " \"hidden_dim\": 256,\n",
641
+ " \"n_layers\": 2,\n",
642
+ " \"skip_in\": [80],\n",
643
+ " \"weight_norm\": True,\n",
644
+ " },\n",
645
+ " \"cls\": {\n",
646
+ " \"combined_input\": 1024, #1024\n",
647
+ " \"combined_dim\": 128,\n",
648
+ " \"num_classes\": 2,\n",
649
+ " \"n_layers\": 2,\n",
650
+ " \"skip_in\": [80],\n",
651
+ " \"weight_norm\": True,\n",
652
+ " },\n",
653
+ " },\n",
654
+ " **kwargs,\n",
655
+ " ):\n",
656
+ " super().__init__(**kwargs)\n",
657
+ "\n",
658
+ " self.base_exp_dir = base_exp_dir\n",
659
+ " # self.dataset = dataset\n",
660
+ " self.train = train\n",
661
+ " self.model = model\n",
662
+ "\n",
663
+ "# TODO: define all parameters needed for your model, as well as calling the model itself\n",
664
+ "class CustomModel(PreTrainedModel):\n",
665
+ " config_class = CustomConfig\n",
666
+ "\n",
667
+ " def __init__(self, config):\n",
668
+ " super().__init__(config)\n",
669
+ " self.conf = config\n",
670
+ " self.freq = FreqNetwork(**self.conf.model[\"freq\"])\n",
671
+ " self.pos = PosNetwork(**self.conf.model[\"pos\"])\n",
672
+ " self.cls = Classifier(**self.conf.model[\"cls\"])\n",
673
+ " self.fc = nn.Linear(self.conf.model[\"cls\"][\"combined_input\"],2)\n",
674
+ " self.seq = RobertaModel.from_pretrained(\"roberta-base\")\n",
675
+ " # self.seq = BertModel.from_pretrained(\"bert-base-uncased\")\n",
676
+ " #for param in self.roberta.parameters():\n",
677
+ " # param.requires_grad = False\n",
678
+ " self.dropout = nn.Dropout(0.2)\n",
679
+ "\n",
680
+ " def forward(self, x):\n",
681
+ " freq_inputs = x[\"freq_inputs\"]\n",
682
+ " seq_inputs = x[\"seq_inputs\"]\n",
683
+ " pos_inputs = x[\"pos_inputs\"]\n",
684
+ " seq_feature = self.seq(\n",
685
+ " input_ids=seq_inputs[:,0,:],\n",
686
+ " attention_mask=seq_inputs[:,1,:]\n",
687
+ " ).pooler_output # last_hidden_state[:, 0, :]\n",
688
+ " freq_feature = self.freq(freq_inputs) # Shape: (batch_size, 128)\n",
689
+ "\n",
690
+ " pos_feature = self.pos(pos_inputs) #Shape: (batch_size, 128)\n",
691
+ " inputs = torch.cat((seq_feature, freq_feature, pos_feature), dim=1) # Shape: (batch_size, 384)\n",
692
+ " # inputs = torch.cat((seq_feature, freq_feature), dim=1) # Shape: (batch_size,256)\n",
693
+ " # inputs = seq_feature\n",
694
+ "\n",
695
+ " x = inputs\n",
696
+ " x = self.dropout(x)\n",
697
+ " outputs = self.fc(x)\n",
698
+ "\n",
699
+ " return outputs\n",
700
+ "\n",
701
+ " def save_model(self, save_path):\n",
702
+ " \"\"\"Save the model locally using the Hugging Face format.\"\"\"\n",
703
+ " self.save_pretrained(save_path)\n",
704
+ "\n",
705
+ " def push_model(self, repo_name):\n",
706
+ " \"\"\"Push the model to the Hugging Face Hub.\"\"\"\n",
707
+ " self.push_to_hub(repo_name)"
708
+ ],
709
+ "id": "9266d67887120863",
710
+ "outputs": [],
711
+ "execution_count": 7
712
+ },
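The cell above defines the custom architecture plus two thin helpers, `save_model` and `push_model`, that wrap `save_pretrained` and `push_to_hub`. A hypothetical usage sketch (the local path is an assumption; the repo id mirrors the one loaded elsewhere in this notebook):

```python
# Hypothetical usage of the helpers defined above.
config = CustomConfig()
model = CustomModel(config)

model.save_model("./exp/fox_nbc/hf_checkpoint")                   # save_pretrained under the hood
model.push_model("CISProject/News-Headline-Classifier-Notebook")  # push_to_hub under the hood
```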
713
+ {
714
+ "metadata": {
715
+ "ExecuteTime": {
716
+ "end_time": "2024-12-24T16:28:35.657792Z",
717
+ "start_time": "2024-12-24T16:28:35.392033Z"
718
+ }
719
+ },
720
+ "cell_type": "code",
721
+ "source": [
722
+ "AutoConfig.register(\"headlineclassifier\", CustomConfig)\n",
723
+ "AutoModel.register(CustomConfig, CustomModel)\n",
724
+ "config = CustomConfig()\n",
725
+ "model = CustomModel(config)"
726
+ ],
727
+ "id": "77b94c012f4fae3a",
728
+ "outputs": [
729
+ {
730
+ "name": "stderr",
731
+ "output_type": "stream",
732
+ "text": [
733
+ "C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
734
+ " WeightNorm.apply(module, name, dim)\n",
735
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
736
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
737
+ ]
738
+ }
739
+ ],
740
+ "execution_count": 8
741
+ },
742
  {
743
  "metadata": {
744
  "colab": {
 
761
  "id": "b20d11caa1d25445",
762
  "outputId": "986c82fd-014b-432a-8174-857b2b866cb8",
763
  "ExecuteTime": {
764
+ "end_time": "2024-12-24T16:28:50.195358Z",
765
+ "start_time": "2024-12-24T16:28:37.051697Z"
766
  }
767
  },
768
  "cell_type": "code",
 
774
  "id": "b20d11caa1d25445",
775
  "outputs": [
776
  {
777
+ "data": {
778
+ "text/plain": [
779
+ "model.safetensors: 0%| | 0.00/518M [00:00<?, ?B/s]"
780
+ ],
781
+ "application/vnd.jupyter.widget-view+json": {
782
+ "version_major": 2,
783
+ "version_minor": 0,
784
+ "model_id": "882ba9da828e4438bdcbc3cd60ce32a4"
785
+ }
786
+ },
787
+ "metadata": {},
788
+ "output_type": "display_data"
789
+ },
790
+ {
791
+ "name": "stderr",
792
+ "output_type": "stream",
793
+ "text": [
794
+ "C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\huggingface_hub\\file_download.py:139: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\swall\\.cache\\huggingface\\hub\\models--CISProject--News-Headline-Classifier-Notebook. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
795
+ "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
796
+ " warnings.warn(message)\n",
797
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
798
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
799
+ "Some weights of the model checkpoint at CISProject/News-Headline-Classifier-Notebook were not used when initializing CustomModel: ['cls.lin0.parametrizations.weight.original0', 'cls.lin0.parametrizations.weight.original1', 'cls.lin1.parametrizations.weight.original0', 'cls.lin1.parametrizations.weight.original1', 'cls.lin2.parametrizations.weight.original0', 'cls.lin2.parametrizations.weight.original1', 'freq.lin0.parametrizations.weight.original0', 'freq.lin0.parametrizations.weight.original1', 'freq.lin1.parametrizations.weight.original0', 'freq.lin1.parametrizations.weight.original1', 'freq.lin2.parametrizations.weight.original0', 'freq.lin2.parametrizations.weight.original1', 'pos.lin0.parametrizations.weight.original0', 'pos.lin0.parametrizations.weight.original1', 'pos.lin1.parametrizations.weight.original0', 'pos.lin1.parametrizations.weight.original1', 'pos.lin2.parametrizations.weight.original0', 'pos.lin2.parametrizations.weight.original1']\n",
800
+ "- This IS expected if you are initializing CustomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
801
+ "- This IS NOT expected if you are initializing CustomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
802
+ "Some weights of CustomModel were not initialized from the model checkpoint at CISProject/News-Headline-Classifier-Notebook and are newly initialized: ['cls.lin0.weight_g', 'cls.lin0.weight_v', 'cls.lin1.weight_g', 'cls.lin1.weight_v', 'cls.lin2.weight_g', 'cls.lin2.weight_v', 'freq.lin0.weight_g', 'freq.lin0.weight_v', 'freq.lin1.weight_g', 'freq.lin1.weight_v', 'freq.lin2.weight_g', 'freq.lin2.weight_v', 'pos.lin0.weight_g', 'pos.lin0.weight_v', 'pos.lin1.weight_g', 'pos.lin1.weight_v', 'pos.lin2.weight_g', 'pos.lin2.weight_v']\n",
803
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
804
  ]
805
  }
806
  ],
807
+ "execution_count": 9
808
  },
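The warnings in this cell's output come from two different weight-norm implementations in PyTorch: the checkpoint on the Hub stores weight-normalized layers under `parametrizations.weight.original0/1`, while the locally constructed modules use the legacy `weight_g`/`weight_v` names, so those tensors cannot be matched and are reported as newly initialized. A small illustration of the naming difference (illustrative only, not from the notebook):

```python
# The legacy helper and the parametrization-based helper register parameters under
# different names, which is why checkpoints saved with one cannot populate modules
# built with the other.
import torch.nn as nn
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrizations import weight_norm as weight_norm_param

legacy = weight_norm(nn.Linear(4, 4))
print(sorted(n for n, _ in legacy.named_parameters()))
# ['bias', 'weight_g', 'weight_v']

modern = weight_norm_param(nn.Linear(4, 4))
print(sorted(n for n, _ in modern.named_parameters()))
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']
```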
809
  {
810
  "metadata": {
811
+ "id": "1d23cedfe1d79660",
812
+ "ExecuteTime": {
813
+ "end_time": "2024-12-24T16:29:36.873566Z",
814
+ "start_time": "2024-12-24T16:29:23.549424Z"
815
+ }
816
  },
817
  "cell_type": "code",
818
  "source": [
819
  "from torch.utils.data import DataLoader\n",
820
  "from sklearn.metrics import accuracy_score, classification_report\n",
821
+ "from tqdm import tqdm\n",
822
  "# Define a collate function to handle the batched data\n",
823
  "def collate_fn(batch):\n",
824
  " freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
 
867
  "print(report)"
868
  ],
869
  "id": "1d23cedfe1d79660",
870
+ "outputs": [
871
+ {
872
+ "name": "stderr",
873
+ "output_type": "stream",
874
+ "text": [
875
+ " "
876
+ ]
877
+ },
878
+ {
879
+ "name": "stdout",
880
+ "output_type": "stream",
881
+ "text": [
882
+ "Accuracy: 0.8988\n",
883
+ " precision recall f1-score support\n",
884
+ "\n",
885
+ " 0 0.90 0.88 0.89 356\n",
886
+ " 1 0.90 0.91 0.91 405\n",
887
+ "\n",
888
+ " accuracy 0.90 761\n",
889
+ " macro avg 0.90 0.90 0.90 761\n",
890
+ "weighted avg 0.90 0.90 0.90 761\n",
891
+ "\n"
892
+ ]
893
+ },
894
+ {
895
+ "name": "stderr",
896
+ "output_type": "stream",
897
+ "text": [
898
+ "\r"
899
+ ]
900
+ }
901
+ ],
902
+ "execution_count": 12
903
  }
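Only the head and tail of the evaluation cell are visible in this diff; the body in between is collapsed. As a rough, hypothetical sketch of how a `collate_fn` like the one defined above is typically wired into an evaluation loop (split name, batch size, and the `labels` field are assumptions, not taken from the hidden cell body):

```python
# Hypothetical evaluation wiring; the notebook's actual loop is collapsed above.
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, classification_report

test_loader = DataLoader(dataset["test"], batch_size=32, shuffle=False, collate_fn=collate_fn)

model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        inputs = {k: batch[k] for k in ("freq_inputs", "seq_inputs", "pos_inputs")}
        logits = model(inputs)                      # CustomModel.forward expects this dict
        all_preds.extend(logits.argmax(dim=1).tolist())
        all_labels.extend(batch["labels"].tolist())

print(f"Accuracy: {accuracy_score(all_labels, all_preds):.4f}")
print(classification_report(all_labels, all_preds))
```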
904
  ],
905
  "metadata": {