Upload eval_py.ipynb
Browse files- eval_py.ipynb +424 -46
eval_py.ipynb
CHANGED
@@ -8,9 +8,14 @@
|
|
8 |
"base_uri": "https://localhost:8080/"
|
9 |
},
|
10 |
"id": "initial_id",
|
11 |
-
"outputId": "85dee483-9370-4e16-ecb1-1c2d545ee1fb"
|
|
|
|
|
|
|
|
|
12 |
},
|
13 |
"source": [
|
|
|
14 |
"!pip install geopy > delete.txt\n",
|
15 |
"!pip install datasets > delete.txt\n",
|
16 |
"!pip install torch torchvision datasets > delete.txt\n",
|
@@ -18,10 +23,21 @@
|
|
18 |
"!pip install pyhocon > delete.txt\n",
|
19 |
"!pip install transformers > delete.txt\n",
|
20 |
"!pip install gensim > delete.txt\n",
|
21 |
-
"!rm delete.txt"
|
22 |
],
|
23 |
-
"outputs": [
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
},
|
26 |
{
|
27 |
"metadata": {
|
@@ -29,15 +45,17 @@
|
|
29 |
"base_uri": "https://localhost:8080/"
|
30 |
},
|
31 |
"id": "b0a77c981c32a0c8",
|
32 |
-
"outputId": "fe03df52-1418-4034-8124-3bf9030ed5d7"
|
|
|
|
|
|
|
|
|
33 |
},
|
34 |
"cell_type": "code",
|
35 |
-
"source":
|
36 |
-
"!huggingface-cli login"
|
37 |
-
],
|
38 |
"id": "b0a77c981c32a0c8",
|
39 |
"outputs": [],
|
40 |
-
"execution_count":
|
41 |
},
|
42 |
{
|
43 |
"metadata": {
|
@@ -105,8 +123,8 @@
|
|
105 |
"id": "a4aa3b759defc904",
|
106 |
"outputId": "b1868c23-e675-41db-aa26-5eed9de60d9f",
|
107 |
"ExecuteTime": {
|
108 |
-
"end_time": "2024-12-
|
109 |
-
"start_time": "2024-12-
|
110 |
}
|
111 |
},
|
112 |
"cell_type": "code",
|
@@ -120,7 +138,7 @@
|
|
120 |
],
|
121 |
"id": "a4aa3b759defc904",
|
122 |
"outputs": [],
|
123 |
-
"execution_count":
|
124 |
},
|
125 |
{
|
126 |
"metadata": {
|
@@ -144,8 +162,8 @@
|
|
144 |
"id": "ce6e6b982e22e9fe",
|
145 |
"outputId": "f38ef6b3-35ac-41dc-a8ae-f0dd28b1f84d",
|
146 |
"ExecuteTime": {
|
147 |
-
"end_time": "2024-12-
|
148 |
-
"start_time": "2024-12-
|
149 |
}
|
150 |
},
|
151 |
"cell_type": "code",
|
@@ -362,8 +380,8 @@
|
|
362 |
"id": "b605d3b4f5ff547a",
|
363 |
"outputId": "f365a98e-c181-4754-9fac-77aa1e8639db",
|
364 |
"ExecuteTime": {
|
365 |
-
"end_time": "2024-12-
|
366 |
-
"start_time": "2024-12-
|
367 |
}
|
368 |
},
|
369 |
"cell_type": "code",
|
@@ -396,10 +414,331 @@
|
|
396 |
"text": [
|
397 |
"vectorizer fitted on training data.\n"
|
398 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
}
|
400 |
],
|
401 |
"execution_count": 5
|
402 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
{
|
404 |
"metadata": {
|
405 |
"colab": {
|
@@ -422,8 +761,8 @@
|
|
422 |
"id": "b20d11caa1d25445",
|
423 |
"outputId": "986c82fd-014b-432a-8174-857b2b866cb8",
|
424 |
"ExecuteTime": {
|
425 |
-
"end_time": "2024-12-
|
426 |
-
"start_time": "2024-12-
|
427 |
}
|
428 |
},
|
429 |
"cell_type": "code",
|
@@ -435,33 +774,51 @@
|
|
435 |
"id": "b20d11caa1d25445",
|
436 |
"outputs": [
|
437 |
{
|
438 |
-
"
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
"
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
452 |
]
|
453 |
}
|
454 |
],
|
455 |
-
"execution_count":
|
456 |
},
|
457 |
{
|
458 |
"metadata": {
|
459 |
-
"id": "1d23cedfe1d79660"
|
|
|
|
|
|
|
|
|
460 |
},
|
461 |
"cell_type": "code",
|
462 |
"source": [
|
463 |
"from torch.utils.data import DataLoader\n",
|
464 |
"from sklearn.metrics import accuracy_score, classification_report\n",
|
|
|
465 |
"# Define a collate function to handle the batched data\n",
|
466 |
"def collate_fn(batch):\n",
|
467 |
" freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
|
@@ -510,18 +867,39 @@
|
|
510 |
"print(report)"
|
511 |
],
|
512 |
"id": "1d23cedfe1d79660",
|
513 |
-
"outputs": [
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
}
|
526 |
],
|
527 |
"metadata": {
|
|
|
8 |
"base_uri": "https://localhost:8080/"
|
9 |
},
|
10 |
"id": "initial_id",
|
11 |
+
"outputId": "85dee483-9370-4e16-ecb1-1c2d545ee1fb",
|
12 |
+
"ExecuteTime": {
|
13 |
+
"end_time": "2024-12-24T16:26:10.771980Z",
|
14 |
+
"start_time": "2024-12-24T16:26:10.767703Z"
|
15 |
+
}
|
16 |
},
|
17 |
"source": [
|
18 |
+
"\n",
|
19 |
"!pip install geopy > delete.txt\n",
|
20 |
"!pip install datasets > delete.txt\n",
|
21 |
"!pip install torch torchvision datasets > delete.txt\n",
|
|
|
23 |
"!pip install pyhocon > delete.txt\n",
|
24 |
"!pip install transformers > delete.txt\n",
|
25 |
"!pip install gensim > delete.txt\n",
|
26 |
+
"!rm delete.txt\n"
|
27 |
],
|
28 |
+
"outputs": [
|
29 |
+
{
|
30 |
+
"data": {
|
31 |
+
"text/plain": [
|
32 |
+
"'\\n!pip install geopy > delete.txt\\n!pip install datasets > delete.txt\\n!pip install torch torchvision datasets > delete.txt\\n!pip install huggingface_hub > delete.txt\\n!pip install pyhocon > delete.txt\\n!pip install transformers > delete.txt\\n!pip install gensim > delete.txt\\n!rm delete.txt\\n'"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
"execution_count": 1,
|
36 |
+
"metadata": {},
|
37 |
+
"output_type": "execute_result"
|
38 |
+
}
|
39 |
+
],
|
40 |
+
"execution_count": 1
|
41 |
},
|
42 |
{
|
43 |
"metadata": {
|
|
|
45 |
"base_uri": "https://localhost:8080/"
|
46 |
},
|
47 |
"id": "b0a77c981c32a0c8",
|
48 |
+
"outputId": "fe03df52-1418-4034-8124-3bf9030ed5d7",
|
49 |
+
"ExecuteTime": {
|
50 |
+
"end_time": "2024-12-24T16:26:10.813215Z",
|
51 |
+
"start_time": "2024-12-24T16:26:10.810118Z"
|
52 |
+
}
|
53 |
},
|
54 |
"cell_type": "code",
|
55 |
+
"source": "!huggingface-cli login",
|
|
|
|
|
56 |
"id": "b0a77c981c32a0c8",
|
57 |
"outputs": [],
|
58 |
+
"execution_count": 2
|
59 |
},
|
60 |
{
|
61 |
"metadata": {
|
|
|
123 |
"id": "a4aa3b759defc904",
|
124 |
"outputId": "b1868c23-e675-41db-aa26-5eed9de60d9f",
|
125 |
"ExecuteTime": {
|
126 |
+
"end_time": "2024-12-24T16:26:14.452521Z",
|
127 |
+
"start_time": "2024-12-24T16:26:10.866207Z"
|
128 |
}
|
129 |
},
|
130 |
"cell_type": "code",
|
|
|
138 |
],
|
139 |
"id": "a4aa3b759defc904",
|
140 |
"outputs": [],
|
141 |
+
"execution_count": 3
|
142 |
},
|
143 |
{
|
144 |
"metadata": {
|
|
|
162 |
"id": "ce6e6b982e22e9fe",
|
163 |
"outputId": "f38ef6b3-35ac-41dc-a8ae-f0dd28b1f84d",
|
164 |
"ExecuteTime": {
|
165 |
+
"end_time": "2024-12-24T16:26:16.191425Z",
|
166 |
+
"start_time": "2024-12-24T16:26:14.463529Z"
|
167 |
}
|
168 |
},
|
169 |
"cell_type": "code",
|
|
|
380 |
"id": "b605d3b4f5ff547a",
|
381 |
"outputId": "f365a98e-c181-4754-9fac-77aa1e8639db",
|
382 |
"ExecuteTime": {
|
383 |
+
"end_time": "2024-12-24T16:27:18.269951Z",
|
384 |
+
"start_time": "2024-12-24T16:26:16.194928Z"
|
385 |
}
|
386 |
},
|
387 |
"cell_type": "code",
|
|
|
414 |
"text": [
|
415 |
"vectorizer fitted on training data.\n"
|
416 |
]
|
417 |
+
},
|
418 |
+
{
|
419 |
+
"data": {
|
420 |
+
"text/plain": [
|
421 |
+
"Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
|
422 |
+
],
|
423 |
+
"application/vnd.jupyter.widget-view+json": {
|
424 |
+
"version_major": 2,
|
425 |
+
"version_minor": 0,
|
426 |
+
"model_id": "0aec436603c54b82b962cb31750a5921"
|
427 |
+
}
|
428 |
+
},
|
429 |
+
"metadata": {},
|
430 |
+
"output_type": "display_data"
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"data": {
|
434 |
+
"text/plain": [
|
435 |
+
"Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
|
436 |
+
],
|
437 |
+
"application/vnd.jupyter.widget-view+json": {
|
438 |
+
"version_major": 2,
|
439 |
+
"version_minor": 0,
|
440 |
+
"model_id": "2061954a6f7e47e08803c4c54e852571"
|
441 |
+
}
|
442 |
+
},
|
443 |
+
"metadata": {},
|
444 |
+
"output_type": "display_data"
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"data": {
|
448 |
+
"text/plain": [
|
449 |
+
"Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
|
450 |
+
],
|
451 |
+
"application/vnd.jupyter.widget-view+json": {
|
452 |
+
"version_major": 2,
|
453 |
+
"version_minor": 0,
|
454 |
+
"model_id": "a3e3797234bb478c88fcf944fafc8570"
|
455 |
+
}
|
456 |
+
},
|
457 |
+
"metadata": {},
|
458 |
+
"output_type": "display_data"
|
459 |
+
},
|
460 |
+
{
|
461 |
+
"data": {
|
462 |
+
"text/plain": [
|
463 |
+
"Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
|
464 |
+
],
|
465 |
+
"application/vnd.jupyter.widget-view+json": {
|
466 |
+
"version_major": 2,
|
467 |
+
"version_minor": 0,
|
468 |
+
"model_id": "2c8bc25d53984eee839a20da98f749d1"
|
469 |
+
}
|
470 |
+
},
|
471 |
+
"metadata": {},
|
472 |
+
"output_type": "display_data"
|
473 |
+
},
|
474 |
+
{
|
475 |
+
"data": {
|
476 |
+
"text/plain": [
|
477 |
+
"Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
|
478 |
+
],
|
479 |
+
"application/vnd.jupyter.widget-view+json": {
|
480 |
+
"version_major": 2,
|
481 |
+
"version_minor": 0,
|
482 |
+
"model_id": "ab84a6ca73a945d58cc23643538c8a6c"
|
483 |
+
}
|
484 |
+
},
|
485 |
+
"metadata": {},
|
486 |
+
"output_type": "display_data"
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"data": {
|
490 |
+
"text/plain": [
|
491 |
+
"Map (num_proc=4): 0%| | 0/3044 [00:00<?, ? examples/s]"
|
492 |
+
],
|
493 |
+
"application/vnd.jupyter.widget-view+json": {
|
494 |
+
"version_major": 2,
|
495 |
+
"version_minor": 0,
|
496 |
+
"model_id": "4ebc99d9190c49bfae44fb7f9db88007"
|
497 |
+
}
|
498 |
+
},
|
499 |
+
"metadata": {},
|
500 |
+
"output_type": "display_data"
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"data": {
|
504 |
+
"text/plain": [
|
505 |
+
"Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
|
506 |
+
],
|
507 |
+
"application/vnd.jupyter.widget-view+json": {
|
508 |
+
"version_major": 2,
|
509 |
+
"version_minor": 0,
|
510 |
+
"model_id": "b1a4da224f31484fa8946982954e74ff"
|
511 |
+
}
|
512 |
+
},
|
513 |
+
"metadata": {},
|
514 |
+
"output_type": "display_data"
|
515 |
+
},
|
516 |
+
{
|
517 |
+
"data": {
|
518 |
+
"text/plain": [
|
519 |
+
"Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
|
520 |
+
],
|
521 |
+
"application/vnd.jupyter.widget-view+json": {
|
522 |
+
"version_major": 2,
|
523 |
+
"version_minor": 0,
|
524 |
+
"model_id": "ab6cbe1f1b714017809ef32d010e452f"
|
525 |
+
}
|
526 |
+
},
|
527 |
+
"metadata": {},
|
528 |
+
"output_type": "display_data"
|
529 |
+
},
|
530 |
+
{
|
531 |
+
"data": {
|
532 |
+
"text/plain": [
|
533 |
+
"Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
|
534 |
+
],
|
535 |
+
"application/vnd.jupyter.widget-view+json": {
|
536 |
+
"version_major": 2,
|
537 |
+
"version_minor": 0,
|
538 |
+
"model_id": "fcd18d89aa5043b39fe0143d4a5ac681"
|
539 |
+
}
|
540 |
+
},
|
541 |
+
"metadata": {},
|
542 |
+
"output_type": "display_data"
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"data": {
|
546 |
+
"text/plain": [
|
547 |
+
"Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
|
548 |
+
],
|
549 |
+
"application/vnd.jupyter.widget-view+json": {
|
550 |
+
"version_major": 2,
|
551 |
+
"version_minor": 0,
|
552 |
+
"model_id": "d1b2b18cedc94f338d4f5bf5dc4a5dec"
|
553 |
+
}
|
554 |
+
},
|
555 |
+
"metadata": {},
|
556 |
+
"output_type": "display_data"
|
557 |
+
},
|
558 |
+
{
|
559 |
+
"data": {
|
560 |
+
"text/plain": [
|
561 |
+
"Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
|
562 |
+
],
|
563 |
+
"application/vnd.jupyter.widget-view+json": {
|
564 |
+
"version_major": 2,
|
565 |
+
"version_minor": 0,
|
566 |
+
"model_id": "03f17c1c68e248a6bf71a6bd7d7bbae3"
|
567 |
+
}
|
568 |
+
},
|
569 |
+
"metadata": {},
|
570 |
+
"output_type": "display_data"
|
571 |
+
},
|
572 |
+
{
|
573 |
+
"data": {
|
574 |
+
"text/plain": [
|
575 |
+
"Map (num_proc=4): 0%| | 0/761 [00:00<?, ? examples/s]"
|
576 |
+
],
|
577 |
+
"application/vnd.jupyter.widget-view+json": {
|
578 |
+
"version_major": 2,
|
579 |
+
"version_minor": 0,
|
580 |
+
"model_id": "1e4135f4a3974b149f50dd3dba35c02e"
|
581 |
+
}
|
582 |
+
},
|
583 |
+
"metadata": {},
|
584 |
+
"output_type": "display_data"
|
585 |
}
|
586 |
],
|
587 |
"execution_count": 5
|
588 |
},
|
589 |
+
{
|
590 |
+
"metadata": {
|
591 |
+
"ExecuteTime": {
|
592 |
+
"end_time": "2024-12-24T16:28:32.064840Z",
|
593 |
+
"start_time": "2024-12-24T16:28:31.661013Z"
|
594 |
+
}
|
595 |
+
},
|
596 |
+
"cell_type": "code",
|
597 |
+
"source": [
|
598 |
+
"# TODO: import all packages necessary for your custom model\n",
|
599 |
+
"import pandas as pd\n",
|
600 |
+
"import os\n",
|
601 |
+
"from torch.utils.data import DataLoader\n",
|
602 |
+
"from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel\n",
|
603 |
+
"import torch\n",
|
604 |
+
"import torch.nn as nn\n",
|
605 |
+
"from transformers import RobertaModel, RobertaConfig,RobertaForSequenceClassification, BertModel\n",
|
606 |
+
"from model.network import Classifier\n",
|
607 |
+
"from model.frequential import FreqNetwork\n",
|
608 |
+
"from model.sequential import SeqNetwork\n",
|
609 |
+
"from model.positional import PosNetwork\n",
|
610 |
+
"\n",
|
611 |
+
"class CustomConfig(PretrainedConfig):\n",
|
612 |
+
" model_type = \"headlineclassifier\"\n",
|
613 |
+
"\n",
|
614 |
+
" def __init__(\n",
|
615 |
+
" self,\n",
|
616 |
+
" base_exp_dir=\"./exp/fox_nbc/\",\n",
|
617 |
+
" # dataset={\"data_dir\": \"./data/CASE_NAME/data.csv\", \"transform\": True},\n",
|
618 |
+
" train={\n",
|
619 |
+
" \"learning_rate\": 2e-5,\n",
|
620 |
+
" \"learning_rate_alpha\": 0.05,\n",
|
621 |
+
" \"end_iter\": 10,\n",
|
622 |
+
" \"batch_size\": 32,\n",
|
623 |
+
" \"warm_up_end\": 2,\n",
|
624 |
+
" \"anneal_end\": 5,\n",
|
625 |
+
" \"save_freq\": 1,\n",
|
626 |
+
" \"val_freq\": 1,\n",
|
627 |
+
" },\n",
|
628 |
+
" model={\n",
|
629 |
+
" \"freq\": {\n",
|
630 |
+
" \"tfidf_input_dim\": 8145,\n",
|
631 |
+
" \"tfidf_output_dim\": 128,\n",
|
632 |
+
" \"tfidf_hidden_dim\": 512,\n",
|
633 |
+
" \"n_layers\": 2,\n",
|
634 |
+
" \"skip_in\": [80],\n",
|
635 |
+
" \"weight_norm\": True,\n",
|
636 |
+
" },\n",
|
637 |
+
" \"pos\": {\n",
|
638 |
+
" \"input_dim\": 300,\n",
|
639 |
+
" \"output_dim\": 128,\n",
|
640 |
+
" \"hidden_dim\": 256,\n",
|
641 |
+
" \"n_layers\": 2,\n",
|
642 |
+
" \"skip_in\": [80],\n",
|
643 |
+
" \"weight_norm\": True,\n",
|
644 |
+
" },\n",
|
645 |
+
" \"cls\": {\n",
|
646 |
+
" \"combined_input\": 1024, #1024\n",
|
647 |
+
" \"combined_dim\": 128,\n",
|
648 |
+
" \"num_classes\": 2,\n",
|
649 |
+
" \"n_layers\": 2,\n",
|
650 |
+
" \"skip_in\": [80],\n",
|
651 |
+
" \"weight_norm\": True,\n",
|
652 |
+
" },\n",
|
653 |
+
" },\n",
|
654 |
+
" **kwargs,\n",
|
655 |
+
" ):\n",
|
656 |
+
" super().__init__(**kwargs)\n",
|
657 |
+
"\n",
|
658 |
+
" self.base_exp_dir = base_exp_dir\n",
|
659 |
+
" # self.dataset = dataset\n",
|
660 |
+
" self.train = train\n",
|
661 |
+
" self.model = model\n",
|
662 |
+
"\n",
|
663 |
+
"# TODO: define all parameters needed for your model, as well as calling the model itself\n",
|
664 |
+
"class CustomModel(PreTrainedModel):\n",
|
665 |
+
" config_class = CustomConfig\n",
|
666 |
+
"\n",
|
667 |
+
" def __init__(self, config):\n",
|
668 |
+
" super().__init__(config)\n",
|
669 |
+
" self.conf = config\n",
|
670 |
+
" self.freq = FreqNetwork(**self.conf.model[\"freq\"])\n",
|
671 |
+
" self.pos = PosNetwork(**self.conf.model[\"pos\"])\n",
|
672 |
+
" self.cls = Classifier(**self.conf.model[\"cls\"])\n",
|
673 |
+
" self.fc = nn.Linear(self.conf.model[\"cls\"][\"combined_input\"],2)\n",
|
674 |
+
" self.seq = RobertaModel.from_pretrained(\"roberta-base\")\n",
|
675 |
+
" # self.seq = BertModel.from_pretrained(\"bert-base-uncased\")\n",
|
676 |
+
" #for param in self.roberta.parameters():\n",
|
677 |
+
" # param.requires_grad = False\n",
|
678 |
+
" self.dropout = nn.Dropout(0.2)\n",
|
679 |
+
"\n",
|
680 |
+
" def forward(self, x):\n",
|
681 |
+
" freq_inputs = x[\"freq_inputs\"]\n",
|
682 |
+
" seq_inputs = x[\"seq_inputs\"]\n",
|
683 |
+
" pos_inputs = x[\"pos_inputs\"]\n",
|
684 |
+
" seq_feature = self.seq(\n",
|
685 |
+
" input_ids=seq_inputs[:,0,:],\n",
|
686 |
+
" attention_mask=seq_inputs[:,1,:]\n",
|
687 |
+
" ).pooler_output # last_hidden_state[:, 0, :]\n",
|
688 |
+
" freq_feature = self.freq(freq_inputs) # Shape: (batch_size, 128)\n",
|
689 |
+
"\n",
|
690 |
+
" pos_feature = self.pos(pos_inputs) #Shape: (batch_size, 128)\n",
|
691 |
+
" inputs = torch.cat((seq_feature, freq_feature, pos_feature), dim=1) # Shape: (batch_size, 384)\n",
|
692 |
+
" # inputs = torch.cat((seq_feature, freq_feature), dim=1) # Shape: (batch_size,256)\n",
|
693 |
+
" # inputs = seq_feature\n",
|
694 |
+
"\n",
|
695 |
+
" x = inputs\n",
|
696 |
+
" x = self.dropout(x)\n",
|
697 |
+
" outputs = self.fc(x)\n",
|
698 |
+
"\n",
|
699 |
+
" return outputs\n",
|
700 |
+
"\n",
|
701 |
+
" def save_model(self, save_path):\n",
|
702 |
+
" \"\"\"Save the model locally using the Hugging Face format.\"\"\"\n",
|
703 |
+
" self.save_pretrained(save_path)\n",
|
704 |
+
"\n",
|
705 |
+
" def push_model(self, repo_name):\n",
|
706 |
+
" \"\"\"Push the model to the Hugging Face Hub.\"\"\"\n",
|
707 |
+
" self.push_to_hub(repo_name)"
|
708 |
+
],
|
709 |
+
"id": "9266d67887120863",
|
710 |
+
"outputs": [],
|
711 |
+
"execution_count": 7
|
712 |
+
},
|
713 |
+
{
|
714 |
+
"metadata": {
|
715 |
+
"ExecuteTime": {
|
716 |
+
"end_time": "2024-12-24T16:28:35.657792Z",
|
717 |
+
"start_time": "2024-12-24T16:28:35.392033Z"
|
718 |
+
}
|
719 |
+
},
|
720 |
+
"cell_type": "code",
|
721 |
+
"source": [
|
722 |
+
"AutoConfig.register(\"headlineclassifier\", CustomConfig)\n",
|
723 |
+
"AutoModel.register(CustomConfig, CustomModel)\n",
|
724 |
+
"config = CustomConfig()\n",
|
725 |
+
"model = CustomModel(config)"
|
726 |
+
],
|
727 |
+
"id": "77b94c012f4fae3a",
|
728 |
+
"outputs": [
|
729 |
+
{
|
730 |
+
"name": "stderr",
|
731 |
+
"output_type": "stream",
|
732 |
+
"text": [
|
733 |
+
"C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
|
734 |
+
" WeightNorm.apply(module, name, dim)\n",
|
735 |
+
"Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
|
736 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
737 |
+
]
|
738 |
+
}
|
739 |
+
],
|
740 |
+
"execution_count": 8
|
741 |
+
},
|
742 |
{
|
743 |
"metadata": {
|
744 |
"colab": {
|
|
|
761 |
"id": "b20d11caa1d25445",
|
762 |
"outputId": "986c82fd-014b-432a-8174-857b2b866cb8",
|
763 |
"ExecuteTime": {
|
764 |
+
"end_time": "2024-12-24T16:28:50.195358Z",
|
765 |
+
"start_time": "2024-12-24T16:28:37.051697Z"
|
766 |
}
|
767 |
},
|
768 |
"cell_type": "code",
|
|
|
774 |
"id": "b20d11caa1d25445",
|
775 |
"outputs": [
|
776 |
{
|
777 |
+
"data": {
|
778 |
+
"text/plain": [
|
779 |
+
"model.safetensors: 0%| | 0.00/518M [00:00<?, ?B/s]"
|
780 |
+
],
|
781 |
+
"application/vnd.jupyter.widget-view+json": {
|
782 |
+
"version_major": 2,
|
783 |
+
"version_minor": 0,
|
784 |
+
"model_id": "882ba9da828e4438bdcbc3cd60ce32a4"
|
785 |
+
}
|
786 |
+
},
|
787 |
+
"metadata": {},
|
788 |
+
"output_type": "display_data"
|
789 |
+
},
|
790 |
+
{
|
791 |
+
"name": "stderr",
|
792 |
+
"output_type": "stream",
|
793 |
+
"text": [
|
794 |
+
"C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\huggingface_hub\\file_download.py:139: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\swall\\.cache\\huggingface\\hub\\models--CISProject--News-Headline-Classifier-Notebook. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
|
795 |
+
"To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
|
796 |
+
" warnings.warn(message)\n",
|
797 |
+
"Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
|
798 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
799 |
+
"Some weights of the model checkpoint at CISProject/News-Headline-Classifier-Notebook were not used when initializing CustomModel: ['cls.lin0.parametrizations.weight.original0', 'cls.lin0.parametrizations.weight.original1', 'cls.lin1.parametrizations.weight.original0', 'cls.lin1.parametrizations.weight.original1', 'cls.lin2.parametrizations.weight.original0', 'cls.lin2.parametrizations.weight.original1', 'freq.lin0.parametrizations.weight.original0', 'freq.lin0.parametrizations.weight.original1', 'freq.lin1.parametrizations.weight.original0', 'freq.lin1.parametrizations.weight.original1', 'freq.lin2.parametrizations.weight.original0', 'freq.lin2.parametrizations.weight.original1', 'pos.lin0.parametrizations.weight.original0', 'pos.lin0.parametrizations.weight.original1', 'pos.lin1.parametrizations.weight.original0', 'pos.lin1.parametrizations.weight.original1', 'pos.lin2.parametrizations.weight.original0', 'pos.lin2.parametrizations.weight.original1']\n",
|
800 |
+
"- This IS expected if you are initializing CustomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
801 |
+
"- This IS NOT expected if you are initializing CustomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
802 |
+
"Some weights of CustomModel were not initialized from the model checkpoint at CISProject/News-Headline-Classifier-Notebook and are newly initialized: ['cls.lin0.weight_g', 'cls.lin0.weight_v', 'cls.lin1.weight_g', 'cls.lin1.weight_v', 'cls.lin2.weight_g', 'cls.lin2.weight_v', 'freq.lin0.weight_g', 'freq.lin0.weight_v', 'freq.lin1.weight_g', 'freq.lin1.weight_v', 'freq.lin2.weight_g', 'freq.lin2.weight_v', 'pos.lin0.weight_g', 'pos.lin0.weight_v', 'pos.lin1.weight_g', 'pos.lin1.weight_v', 'pos.lin2.weight_g', 'pos.lin2.weight_v']\n",
|
803 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
804 |
]
|
805 |
}
|
806 |
],
|
807 |
+
"execution_count": 9
|
808 |
},
|
809 |
{
|
810 |
"metadata": {
|
811 |
+
"id": "1d23cedfe1d79660",
|
812 |
+
"ExecuteTime": {
|
813 |
+
"end_time": "2024-12-24T16:29:36.873566Z",
|
814 |
+
"start_time": "2024-12-24T16:29:23.549424Z"
|
815 |
+
}
|
816 |
},
|
817 |
"cell_type": "code",
|
818 |
"source": [
|
819 |
"from torch.utils.data import DataLoader\n",
|
820 |
"from sklearn.metrics import accuracy_score, classification_report\n",
|
821 |
+
"from tqdm import tqdm\n",
|
822 |
"# Define a collate function to handle the batched data\n",
|
823 |
"def collate_fn(batch):\n",
|
824 |
" freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
|
|
|
867 |
"print(report)"
|
868 |
],
|
869 |
"id": "1d23cedfe1d79660",
|
870 |
+
"outputs": [
|
871 |
+
{
|
872 |
+
"name": "stderr",
|
873 |
+
"output_type": "stream",
|
874 |
+
"text": [
|
875 |
+
" "
|
876 |
+
]
|
877 |
+
},
|
878 |
+
{
|
879 |
+
"name": "stdout",
|
880 |
+
"output_type": "stream",
|
881 |
+
"text": [
|
882 |
+
"Accuracy: 0.8988\n",
|
883 |
+
" precision recall f1-score support\n",
|
884 |
+
"\n",
|
885 |
+
" 0 0.90 0.88 0.89 356\n",
|
886 |
+
" 1 0.90 0.91 0.91 405\n",
|
887 |
+
"\n",
|
888 |
+
" accuracy 0.90 761\n",
|
889 |
+
" macro avg 0.90 0.90 0.90 761\n",
|
890 |
+
"weighted avg 0.90 0.90 0.90 761\n",
|
891 |
+
"\n"
|
892 |
+
]
|
893 |
+
},
|
894 |
+
{
|
895 |
+
"name": "stderr",
|
896 |
+
"output_type": "stream",
|
897 |
+
"text": [
|
898 |
+
"\r"
|
899 |
+
]
|
900 |
+
}
|
901 |
+
],
|
902 |
+
"execution_count": 12
|
903 |
}
|
904 |
],
|
905 |
"metadata": {
|