Upload train+test.ipynb
Browse files- train+test.ipynb +25 -143
train+test.ipynb
CHANGED
@@ -13,8 +13,6 @@
|
|
13 |
{
|
14 |
"metadata": {},
|
15 |
"cell_type": "code",
|
16 |
-
"outputs": [],
|
17 |
-
"execution_count": null,
|
18 |
"source": [
|
19 |
"!pip install geopy > delete.txt\n",
|
20 |
"!pip install datasets > delete.txt\n",
|
@@ -25,7 +23,9 @@
|
|
25 |
"!pip install gensim > delete.txt\n",
|
26 |
"!rm delete.txt"
|
27 |
],
|
28 |
-
"id": "5a596f2639253772"
|
|
|
|
|
29 |
},
|
30 |
{
|
31 |
"metadata": {},
|
@@ -37,17 +37,12 @@
|
|
37 |
"id": "432a756039e6399"
|
38 |
},
|
39 |
{
|
40 |
-
"metadata": {
|
41 |
-
"ExecuteTime": {
|
42 |
-
"end_time": "2024-12-16T19:48:43.216631Z",
|
43 |
-
"start_time": "2024-12-16T19:48:43.214630Z"
|
44 |
-
}
|
45 |
-
},
|
46 |
"cell_type": "code",
|
47 |
"source": "!huggingface-cli login",
|
48 |
"id": "2e73da09a7c6171e",
|
49 |
"outputs": [],
|
50 |
-
"execution_count":
|
51 |
},
|
52 |
{
|
53 |
"metadata": {},
|
@@ -68,12 +63,7 @@
|
|
68 |
"id": "b8920847b7cc378d"
|
69 |
},
|
70 |
{
|
71 |
-
"metadata": {
|
72 |
-
"ExecuteTime": {
|
73 |
-
"end_time": "2024-12-16T19:48:45.272372Z",
|
74 |
-
"start_time": "2024-12-16T19:48:43.220140Z"
|
75 |
-
}
|
76 |
-
},
|
77 |
"cell_type": "code",
|
78 |
"source": [
|
79 |
"from datasets import load_dataset\n",
|
@@ -84,15 +74,10 @@
|
|
84 |
],
|
85 |
"id": "877c90c978d62b7d",
|
86 |
"outputs": [],
|
87 |
-
"execution_count":
|
88 |
},
|
89 |
{
|
90 |
-
"metadata": {
|
91 |
-
"ExecuteTime": {
|
92 |
-
"end_time": "2024-12-16T19:48:45.287939Z",
|
93 |
-
"start_time": "2024-12-16T19:48:45.278748Z"
|
94 |
-
}
|
95 |
-
},
|
96 |
"cell_type": "code",
|
97 |
"source": [
|
98 |
"import numpy as np\n",
|
@@ -228,15 +213,10 @@
|
|
228 |
],
|
229 |
"id": "dc2ba675ce880d6d",
|
230 |
"outputs": [],
|
231 |
-
"execution_count":
|
232 |
},
|
233 |
{
|
234 |
-
"metadata": {
|
235 |
-
"ExecuteTime": {
|
236 |
-
"end_time": "2024-12-16T19:49:01.529651Z",
|
237 |
-
"start_time": "2024-12-16T19:48:45.294290Z"
|
238 |
-
}
|
239 |
-
},
|
240 |
"cell_type": "code",
|
241 |
"source": [
|
242 |
"from gensim.models import KeyedVectors\n",
|
@@ -260,47 +240,19 @@
|
|
260 |
")"
|
261 |
],
|
262 |
"id": "158b99950fb22d1",
|
263 |
-
"outputs": [
|
264 |
-
|
265 |
-
"name": "stdout",
|
266 |
-
"output_type": "stream",
|
267 |
-
"text": [
|
268 |
-
"vectorizer fitted on training data.\n"
|
269 |
-
]
|
270 |
-
}
|
271 |
-
],
|
272 |
-
"execution_count": 47
|
273 |
},
|
274 |
{
|
275 |
-
"metadata": {
|
276 |
-
"ExecuteTime": {
|
277 |
-
"end_time": "2024-12-16T19:49:01.538067Z",
|
278 |
-
"start_time": "2024-12-16T19:49:01.535063Z"
|
279 |
-
}
|
280 |
-
},
|
281 |
"cell_type": "code",
|
282 |
"source": [
|
283 |
"print(dataset_train)\n",
|
284 |
"print(dataset_test)"
|
285 |
],
|
286 |
"id": "edd80d33175c96a0",
|
287 |
-
"outputs": [
|
288 |
-
|
289 |
-
"name": "stdout",
|
290 |
-
"output_type": "stream",
|
291 |
-
"text": [
|
292 |
-
"Dataset({\n",
|
293 |
-
" features: ['title', 'outlet', 'index', 'url', 'labels', 'clean_title', 'freq_inputs', 'input_ids', 'attention_mask', 'tokens', 'pos_inputs', 'seq_inputs'],\n",
|
294 |
-
" num_rows: 3044\n",
|
295 |
-
"})\n",
|
296 |
-
"Dataset({\n",
|
297 |
-
" features: ['title', 'outlet', 'index', 'url', 'labels', 'clean_title', 'freq_inputs', 'input_ids', 'attention_mask', 'tokens', 'pos_inputs', 'seq_inputs'],\n",
|
298 |
-
" num_rows: 761\n",
|
299 |
-
"})\n"
|
300 |
-
]
|
301 |
-
}
|
302 |
-
],
|
303 |
-
"execution_count": 48
|
304 |
},
|
305 |
{
|
306 |
"metadata": {},
|
@@ -321,12 +273,7 @@
|
|
321 |
"id": "f0eae08a025b6ed9"
|
322 |
},
|
323 |
{
|
324 |
-
"metadata": {
|
325 |
-
"ExecuteTime": {
|
326 |
-
"end_time": "2024-12-16T19:49:01.554769Z",
|
327 |
-
"start_time": "2024-12-16T19:49:01.543575Z"
|
328 |
-
}
|
329 |
-
},
|
330 |
"cell_type": "code",
|
331 |
"source": [
|
332 |
"# TODO: import all packages necessary for your custom model\n",
|
@@ -419,8 +366,6 @@
|
|
419 |
" input_ids=seq_inputs[:,0,:],\n",
|
420 |
" attention_mask=seq_inputs[:,1,:]\n",
|
421 |
" ).pooler_output # last_hidden_state[:, 0, :]\n",
|
422 |
-
" lstm_out, (h_n, c_n) = self.lstm(seq_feature)\n",
|
423 |
-
" seq_feature = h_n[-1] # Use the last hidden state\n",
|
424 |
" freq_feature = self.freq(freq_inputs) # Shape: (batch_size, 128)\n",
|
425 |
"\n",
|
426 |
" pos_feature = self.pos(pos_inputs) #Shape: (batch_size, 128)\n",
|
@@ -444,15 +389,10 @@
|
|
444 |
],
|
445 |
"id": "21f079d0c52d7d",
|
446 |
"outputs": [],
|
447 |
-
"execution_count":
|
448 |
},
|
449 |
{
|
450 |
-
"metadata": {
|
451 |
-
"ExecuteTime": {
|
452 |
-
"end_time": "2024-12-16T19:49:01.791918Z",
|
453 |
-
"start_time": "2024-12-16T19:49:01.561338Z"
|
454 |
-
}
|
455 |
-
},
|
456 |
"cell_type": "code",
|
457 |
"source": [
|
458 |
"from huggingface_hub import hf_hub_download\n",
|
@@ -465,27 +405,11 @@
|
|
465 |
"REPO_NAME = \"CISProject/News-Headline-Classifier-Notebook\" # TODO: PROVIDE A STRING TO YOUR REPO ON HUGGINGFACE"
|
466 |
],
|
467 |
"id": "b6ba3f96d3ce21",
|
468 |
-
"outputs": [
|
469 |
-
|
470 |
-
"name": "stderr",
|
471 |
-
"output_type": "stream",
|
472 |
-
"text": [
|
473 |
-
"C:\\Users\\swall\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
|
474 |
-
" WeightNorm.apply(module, name, dim)\n",
|
475 |
-
"Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
|
476 |
-
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
477 |
-
]
|
478 |
-
}
|
479 |
-
],
|
480 |
-
"execution_count": 50
|
481 |
},
|
482 |
{
|
483 |
-
"metadata": {
|
484 |
-
"ExecuteTime": {
|
485 |
-
"end_time": "2024-12-16T19:49:01.808079Z",
|
486 |
-
"start_time": "2024-12-16T19:49:01.798760Z"
|
487 |
-
}
|
488 |
-
},
|
489 |
"cell_type": "code",
|
490 |
"source": [
|
491 |
"import torch\n",
|
@@ -623,15 +547,10 @@
|
|
623 |
],
|
624 |
"id": "7be377251b81a25d",
|
625 |
"outputs": [],
|
626 |
-
"execution_count":
|
627 |
},
|
628 |
{
|
629 |
-
"metadata": {
|
630 |
-
"ExecuteTime": {
|
631 |
-
"end_time": "2024-12-16T19:49:03.149673Z",
|
632 |
-
"start_time": "2024-12-16T19:49:01.812943Z"
|
633 |
-
}
|
634 |
-
},
|
635 |
"cell_type": "code",
|
636 |
"source": [
|
637 |
"from torch.utils.data import DataLoader\n",
|
@@ -656,45 +575,8 @@
|
|
656 |
"print(f\"Final model saved at {final_save_path}\")\n"
|
657 |
],
|
658 |
"id": "dd1749c306f148eb",
|
659 |
-
"outputs": [
|
660 |
-
|
661 |
-
"name": "stderr",
|
662 |
-
"output_type": "stream",
|
663 |
-
"text": [
|
664 |
-
"Epoch 1/10: 0%| | 0/96 [00:00<?, ?it/s]"
|
665 |
-
]
|
666 |
-
},
|
667 |
-
{
|
668 |
-
"name": "stdout",
|
669 |
-
"output_type": "stream",
|
670 |
-
"text": [
|
671 |
-
"torch.Size([1, 768]) torch.Size([32, 128]) torch.Size([32, 128])\n"
|
672 |
-
]
|
673 |
-
},
|
674 |
-
{
|
675 |
-
"name": "stderr",
|
676 |
-
"output_type": "stream",
|
677 |
-
"text": [
|
678 |
-
"\n"
|
679 |
-
]
|
680 |
-
},
|
681 |
-
{
|
682 |
-
"ename": "RuntimeError",
|
683 |
-
"evalue": "Sizes of tensors must match except in dimension 1. Expected size 1 but got size 32 for tensor number 1 in the list.",
|
684 |
-
"output_type": "error",
|
685 |
-
"traceback": [
|
686 |
-
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
687 |
-
"\u001B[1;31mRuntimeError\u001B[0m Traceback (most recent call last)",
|
688 |
-
"Cell \u001B[1;32mIn[52], line 16\u001B[0m\n\u001B[0;32m 13\u001B[0m trainer \u001B[38;5;241m=\u001B[39m Trainer(model, train_loader, test_loader, config)\n\u001B[0;32m 15\u001B[0m \u001B[38;5;66;03m# Train the model\u001B[39;00m\n\u001B[1;32m---> 16\u001B[0m trainer\u001B[38;5;241m.\u001B[39mtrain()\n\u001B[0;32m 17\u001B[0m \u001B[38;5;66;03m# Save the final model in Hugging Face format\u001B[39;00m\n\u001B[0;32m 18\u001B[0m final_save_path \u001B[38;5;241m=\u001B[39m os\u001B[38;5;241m.\u001B[39mpath\u001B[38;5;241m.\u001B[39mjoin(config\u001B[38;5;241m.\u001B[39mbase_exp_dir, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcheckpoints\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n",
|
689 |
-
"Cell \u001B[1;32mIn[51], line 69\u001B[0m, in \u001B[0;36mTrainer.train\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 66\u001B[0m y_train \u001B[38;5;241m=\u001B[39m labels\u001B[38;5;241m.\u001B[39mto(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdevice)\n\u001B[0;32m 68\u001B[0m \u001B[38;5;66;03m# Forward pass\u001B[39;00m\n\u001B[1;32m---> 69\u001B[0m preds \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel({\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mfreq_inputs\u001B[39m\u001B[38;5;124m\"\u001B[39m: freq_inputs, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mseq_inputs\u001B[39m\u001B[38;5;124m\"\u001B[39m: seq_inputs, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpos_inputs\u001B[39m\u001B[38;5;124m\"\u001B[39m: pos_inputs})\n\u001B[0;32m 70\u001B[0m loss \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcriterion(preds, y_train)\n\u001B[0;32m 72\u001B[0m \u001B[38;5;66;03m# preds = (torch.sigmoid(preds) > 0.5).int()\u001B[39;00m\n\u001B[0;32m 73\u001B[0m \u001B[38;5;66;03m# Backward pass\u001B[39;00m\n",
|
690 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1734\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1735\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1736\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
|
691 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\newsCLS\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1742\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1743\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1744\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1745\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1746\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1747\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1749\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 1750\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
|
692 |
-
"Cell \u001B[1;32mIn[49], line 99\u001B[0m, in \u001B[0;36mCustomModel.forward\u001B[1;34m(self, x)\u001B[0m\n\u001B[0;32m 97\u001B[0m pos_feature \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mpos(pos_inputs) \u001B[38;5;66;03m#Shape: (batch_size, 128)\u001B[39;00m\n\u001B[0;32m 98\u001B[0m \u001B[38;5;28mprint\u001B[39m(seq_feature\u001B[38;5;241m.\u001B[39mshape,pos_feature\u001B[38;5;241m.\u001B[39mshape,freq_feature\u001B[38;5;241m.\u001B[39mshape)\n\u001B[1;32m---> 99\u001B[0m inputs \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mcat((seq_feature, freq_feature, pos_feature), dim\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m) \u001B[38;5;66;03m# Shape: (batch_size, 384)\u001B[39;00m\n\u001B[0;32m 100\u001B[0m \u001B[38;5;66;03m# inputs = torch.cat((seq_feature, freq_feature), dim=1) # Shape: (batch_size,256)\u001B[39;00m\n\u001B[0;32m 101\u001B[0m \u001B[38;5;66;03m# inputs = seq_feature\u001B[39;00m\n\u001B[0;32m 103\u001B[0m x \u001B[38;5;241m=\u001B[39m inputs\n",
|
693 |
-
"\u001B[1;31mRuntimeError\u001B[0m: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 32 for tensor number 1 in the list."
|
694 |
-
]
|
695 |
-
}
|
696 |
-
],
|
697 |
-
"execution_count": 52
|
698 |
},
|
699 |
{
|
700 |
"metadata": {},
|
|
|
13 |
{
|
14 |
"metadata": {},
|
15 |
"cell_type": "code",
|
|
|
|
|
16 |
"source": [
|
17 |
"!pip install geopy > delete.txt\n",
|
18 |
"!pip install datasets > delete.txt\n",
|
|
|
23 |
"!pip install gensim > delete.txt\n",
|
24 |
"!rm delete.txt"
|
25 |
],
|
26 |
+
"id": "5a596f2639253772",
|
27 |
+
"outputs": [],
|
28 |
+
"execution_count": null
|
29 |
},
|
30 |
{
|
31 |
"metadata": {},
|
|
|
37 |
"id": "432a756039e6399"
|
38 |
},
|
39 |
{
|
40 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
41 |
"cell_type": "code",
|
42 |
"source": "!huggingface-cli login",
|
43 |
"id": "2e73da09a7c6171e",
|
44 |
"outputs": [],
|
45 |
+
"execution_count": null
|
46 |
},
|
47 |
{
|
48 |
"metadata": {},
|
|
|
63 |
"id": "b8920847b7cc378d"
|
64 |
},
|
65 |
{
|
66 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
67 |
"cell_type": "code",
|
68 |
"source": [
|
69 |
"from datasets import load_dataset\n",
|
|
|
74 |
],
|
75 |
"id": "877c90c978d62b7d",
|
76 |
"outputs": [],
|
77 |
+
"execution_count": null
|
78 |
},
|
79 |
{
|
80 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
81 |
"cell_type": "code",
|
82 |
"source": [
|
83 |
"import numpy as np\n",
|
|
|
213 |
],
|
214 |
"id": "dc2ba675ce880d6d",
|
215 |
"outputs": [],
|
216 |
+
"execution_count": null
|
217 |
},
|
218 |
{
|
219 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
220 |
"cell_type": "code",
|
221 |
"source": [
|
222 |
"from gensim.models import KeyedVectors\n",
|
|
|
240 |
")"
|
241 |
],
|
242 |
"id": "158b99950fb22d1",
|
243 |
+
"outputs": [],
|
244 |
+
"execution_count": null
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
},
|
246 |
{
|
247 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
248 |
"cell_type": "code",
|
249 |
"source": [
|
250 |
"print(dataset_train)\n",
|
251 |
"print(dataset_test)"
|
252 |
],
|
253 |
"id": "edd80d33175c96a0",
|
254 |
+
"outputs": [],
|
255 |
+
"execution_count": null
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
},
|
257 |
{
|
258 |
"metadata": {},
|
|
|
273 |
"id": "f0eae08a025b6ed9"
|
274 |
},
|
275 |
{
|
276 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
277 |
"cell_type": "code",
|
278 |
"source": [
|
279 |
"# TODO: import all packages necessary for your custom model\n",
|
|
|
366 |
" input_ids=seq_inputs[:,0,:],\n",
|
367 |
" attention_mask=seq_inputs[:,1,:]\n",
|
368 |
" ).pooler_output # last_hidden_state[:, 0, :]\n",
|
|
|
|
|
369 |
" freq_feature = self.freq(freq_inputs) # Shape: (batch_size, 128)\n",
|
370 |
"\n",
|
371 |
" pos_feature = self.pos(pos_inputs) #Shape: (batch_size, 128)\n",
|
|
|
389 |
],
|
390 |
"id": "21f079d0c52d7d",
|
391 |
"outputs": [],
|
392 |
+
"execution_count": null
|
393 |
},
|
394 |
{
|
395 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
396 |
"cell_type": "code",
|
397 |
"source": [
|
398 |
"from huggingface_hub import hf_hub_download\n",
|
|
|
405 |
"REPO_NAME = \"CISProject/News-Headline-Classifier-Notebook\" # TODO: PROVIDE A STRING TO YOUR REPO ON HUGGINGFACE"
|
406 |
],
|
407 |
"id": "b6ba3f96d3ce21",
|
408 |
+
"outputs": [],
|
409 |
+
"execution_count": null
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
},
|
411 |
{
|
412 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
413 |
"cell_type": "code",
|
414 |
"source": [
|
415 |
"import torch\n",
|
|
|
547 |
],
|
548 |
"id": "7be377251b81a25d",
|
549 |
"outputs": [],
|
550 |
+
"execution_count": null
|
551 |
},
|
552 |
{
|
553 |
+
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
554 |
"cell_type": "code",
|
555 |
"source": [
|
556 |
"from torch.utils.data import DataLoader\n",
|
|
|
575 |
"print(f\"Final model saved at {final_save_path}\")\n"
|
576 |
],
|
577 |
"id": "dd1749c306f148eb",
|
578 |
+
"outputs": [],
|
579 |
+
"execution_count": null
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
580 |
},
|
581 |
{
|
582 |
"metadata": {},
|