TUEN-YUE committed
Commit 4a8afb6 · verified · 1 Parent(s): b50c998

Delete eval.py.ipynb

Files changed (1)
  1. eval.py.ipynb +0 -317
eval.py.ipynb DELETED
@@ -1,317 +0,0 @@
- {
-   "cells": [
-     {
-       "cell_type": "code",
-       "id": "initial_id",
-       "metadata": {
-         "collapsed": true,
-         "ExecuteTime": {
-           "end_time": "2024-12-06T19:54:24.990141Z",
-           "start_time": "2024-12-06T19:53:17.183491Z"
-         }
-       },
-       "source": [
-         "!pip install geopy > delete.txt\n",
-         "!pip install datasets > delete.txt\n",
-         "!pip install torch torchvision datasets > delete.txt\n",
-         "!pip install huggingface_hub > delete.txt\n",
-         "!pip install pyhocon > delete.txt\n",
-         "!pip install transformers > delete.txt\n",
-         "!rm delete.txt"
-       ],
-       "outputs": [
-         {
-           "name": "stderr",
-           "output_type": "stream",
-           "text": [
-             "'rm' is not recognized as an internal or external command,\n",
-             "operable program or batch file.\n"
-           ]
-         }
-       ],
-       "execution_count": 2
-     },
-     {
-       "metadata": {
-         "ExecuteTime": {
-           "end_time": "2024-12-06T19:56:26.136466Z",
-           "start_time": "2024-12-06T19:54:38.679955Z"
-         }
-       },
-       "cell_type": "code",
-       "source": "!huggingface-cli login",
-       "id": "b0a77c981c32a0c8",
-       "outputs": [
-         {
-           "name": "stdout",
-           "output_type": "stream",
-           "text": [
-             "^C\n"
-           ]
-         }
-       ],
-       "execution_count": 3
-     },
-     {
-       "metadata": {
-         "ExecuteTime": {
-           "end_time": "2024-12-06T19:57:30.983629Z",
-           "start_time": "2024-12-06T19:57:29.451887Z"
-         }
-       },
-       "cell_type": "code",
-       "source": [
-         "from datasets import load_dataset\n",
-         "\n",
-         "dataset_train = load_dataset(\"CISProject/FOX_NBC\", split=\"train\")\n",
-         "dataset_test = load_dataset(\"path/to/test\", split=\"test\")"
-       ],
-       "id": "a4aa3b759defc904",
-       "outputs": [],
-       "execution_count": 5
-     },
-     {
-       "metadata": {
-         "ExecuteTime": {
-           "end_time": "2024-12-06T19:58:41.568459Z",
-           "start_time": "2024-12-06T19:58:41.445848Z"
-         }
-       },
-       "cell_type": "code",
-       "source": [
-         "import numpy as np\n",
-         "import torch\n",
-         "from transformers import BertTokenizer\n",
-         "from sklearn.feature_extraction.text import TfidfVectorizer\n",
-         "\n",
-         "def positional_encoding(seq_len, d_model):\n",
-         "    pos_enc = np.zeros((seq_len, d_model))\n",
-         "    for pos in range(seq_len):\n",
-         "        for i in range(0, d_model, 2):\n",
-         "            pos_enc[pos, i] = np.sin(pos / (10000 ** ((2 * i) / d_model)))\n",
-         "            if i + 1 < d_model:\n",
-         "                pos_enc[pos, i + 1] = np.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))\n",
-         "    return torch.tensor(pos_enc, dtype=torch.float)\n",
-         "\n",
-         "def preprocess_data(data, mode=\"train\", tfidf_vectorizer=None, max_tfidf_features=4096, max_seq_length=128, num_proc=4):\n",
-         "    tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
-         "\n",
-         "    # Initialize TF-IDF vectorizer if not provided\n",
-         "    if tfidf_vectorizer is None and mode == \"train\":\n",
-         "        tfidf_vectorizer = TfidfVectorizer(max_features=max_tfidf_features)\n",
-         "\n",
-         "    # Fit TF-IDF only in train mode\n",
-         "    if mode == \"train\":\n",
-         "        tfidf_vectorizer.fit(data[\"title\"])\n",
-         "        print(\"TF-IDF vectorizer fitted on training data.\")\n",
-         "\n",
-         "    def process_batch(batch):\n",
-         "        headlines = batch[\"title\"]\n",
-         "        agencies = batch[\"news\"]\n",
-         "\n",
-         "        # TF-IDF transformation (batch-wise)\n",
-         "        if mode == \"train\" or tfidf_vectorizer is not None:\n",
-         "            freq_inputs = tfidf_vectorizer.transform(headlines).toarray()\n",
-         "        else:\n",
-         "            raise ValueError(\"TF-IDF vectorizer must be provided in test mode.\")\n",
-         "\n",
-         "        # Tokenization (batch-wise)\n",
-         "        tokenized = tokenizer(\n",
-         "            headlines,\n",
-         "            padding=\"max_length\",\n",
-         "            truncation=True,\n",
-         "            max_length=max_seq_length,\n",
-         "            return_tensors=\"pt\"\n",
-         "        )\n",
-         "\n",
-         "        # Stack input_ids and attention_mask along a new dimension\n",
-         "        input_ids = tokenized[\"input_ids\"]\n",
-         "        attention_mask = tokenized[\"attention_mask\"]\n",
-         "\n",
-         "        # Ensure consistent stacking: (batch_size, 2, seq_len)\n",
-         "        seq_inputs = torch.stack([input_ids, attention_mask], dim=1)\n",
-         "\n",
-         "        # Positional encoding\n",
-         "        pos_inputs = positional_encoding(max_seq_length, 512).unsqueeze(0).expand(len(headlines), -1, -1)\n",
-         "\n",
-         "        # Labels\n",
-         "        labels = [1.0 if agency == \"fox\" else 0.0 for agency in agencies]\n",
-         "\n",
-         "        return {\n",
-         "            \"freq_inputs\": torch.tensor(freq_inputs),\n",
-         "            \"seq_inputs\": seq_inputs,\n",
-         "            \"pos_inputs\": pos_inputs,\n",
-         "            \"labels\": torch.tensor(labels),\n",
-         "        }\n",
-         "\n",
-         "    # Use `map` with batching and parallelism\n",
-         "    processed_data = data.map(\n",
-         "        process_batch,\n",
-         "        batched=True,\n",
-         "        batch_size=32,\n",
-         "        num_proc=num_proc\n",
-         "    )\n",
-         "\n",
-         "    return processed_data, tfidf_vectorizer"
-       ],
-       "id": "ce6e6b982e22e9fe",
-       "outputs": [
-         {
-           "ename": "ValueError",
-           "evalue": "numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject",
-           "output_type": "error",
-           "traceback": [
- "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
- "\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)",
- "Cell \u001B[1;32mIn[12], line 4\u001B[0m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtransformers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m BertTokenizer\n\u001B[1;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01msklearn\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfeature_extraction\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mtext\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m TfidfVectorizer\n\u001B[0;32m 6\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mpositional_encoding\u001B[39m(seq_len, d_model):\n\u001B[0;32m 7\u001B[0m pos_enc \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros((seq_len, d_model))\n",
- "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\__init__.py:84\u001B[0m\n\u001B[0;32m 70\u001B[0m \u001B[38;5;66;03m# We are not importing the rest of scikit-learn during the build\u001B[39;00m\n\u001B[0;32m 71\u001B[0m \u001B[38;5;66;03m# process, as it may not be compiled yet\u001B[39;00m\n\u001B[0;32m 72\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 78\u001B[0m \u001B[38;5;66;03m# later is linked to the OpenMP runtime to make it possible to introspect\u001B[39;00m\n\u001B[0;32m 79\u001B[0m \u001B[38;5;66;03m# it and importing it first would fail if the OpenMP dll cannot be found.\u001B[39;00m\n\u001B[0;32m 80\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[0;32m 81\u001B[0m __check_build, \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m 82\u001B[0m _distributor_init, \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m 83\u001B[0m )\n\u001B[1;32m---> 84\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbase\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m clone\n\u001B[0;32m 85\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_show_versions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m show_versions\n\u001B[0;32m 87\u001B[0m __all__ \u001B[38;5;241m=\u001B[39m [\n\u001B[0;32m 88\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcalibration\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 89\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcluster\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 130\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mshow_versions\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 131\u001B[0m ]\n",
- "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\base.py:19\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config_context, get_config\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexceptions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m InconsistentVersionWarning\n\u001B[1;32m---> 19\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_estimator_html_repr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _HTMLDocumentationLinkMixin, estimator_html_repr\n\u001B[0;32m 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_metadata_requests\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _MetadataRequester, _routing_enabled\n\u001B[0;32m 21\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_param_validation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m validate_parameter_constraints\n",
- "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\__init__.py:11\u001B[0m\n\u001B[0;32m 9\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _joblib, metadata_routing\n\u001B[0;32m 10\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_bunch\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Bunch\n\u001B[1;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_chunking\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m gen_batches, gen_even_slices\n\u001B[0;32m 12\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_estimator_html_repr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m estimator_html_repr\n\u001B[0;32m 14\u001B[0m \u001B[38;5;66;03m# Make _safe_indexing importable from here for backward compat as this particular\u001B[39;00m\n\u001B[0;32m 15\u001B[0m \u001B[38;5;66;03m# helper is considered semi-private and typically very useful for third-party\u001B[39;00m\n\u001B[0;32m 16\u001B[0m \u001B[38;5;66;03m# libraries that want to comply with scikit-learn's estimator API. In particular,\u001B[39;00m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;66;03m# _safe_indexing was included in our public API documentation despite the leading\u001B[39;00m\n\u001B[0;32m 18\u001B[0m \u001B[38;5;66;03m# `_` in its name.\u001B[39;00m\n",
- "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\_chunking.py:8\u001B[0m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 7\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m get_config\n\u001B[1;32m----> 8\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_param_validation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Interval, validate_params\n\u001B[0;32m 11\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mchunk_generator\u001B[39m(gen, chunksize):\n\u001B[0;32m 12\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"Chunk generator, ``gen`` into lists of length ``chunksize``. The last\u001B[39;00m\n\u001B[0;32m 13\u001B[0m \u001B[38;5;124;03m chunk may have a length less than ``chunksize``.\"\"\"\u001B[39;00m\n",
- "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\_param_validation.py:11\u001B[0m\n\u001B[0;32m 8\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mnumbers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Integral, Real\n\u001B[0;32m 10\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mscipy\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msparse\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m csr_matrix, issparse\n\u001B[0;32m 13\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config_context, get_config\n\u001B[0;32m 14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mvalidation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _is_arraylike_not_scalar\n",
- "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\scipy\\sparse\\__init__.py:297\u001B[0m\n\u001B[0;32m 295\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_csr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m 296\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_csc\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[1;32m--> 297\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_lil\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m 298\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_dok\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m 299\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_coo\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n",
- "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\scipy\\sparse\\_lil.py:17\u001B[0m\n\u001B[0;32m 14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_index\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m IndexMixin, INT_TYPES, _broadcast_arrays\n\u001B[0;32m 15\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_sputils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (getdtype, isshape, isscalarlike, upcast_scalar,\n\u001B[0;32m 16\u001B[0m check_shape, check_reshape_kwargs)\n\u001B[1;32m---> 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _csparsetools\n\u001B[0;32m 20\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01m_lil_base\u001B[39;00m(_spbase, IndexMixin):\n\u001B[0;32m 21\u001B[0m _format \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlil\u001B[39m\u001B[38;5;124m'\u001B[39m\n",
- "File \u001B[1;32mscipy\\\\sparse\\\\_csparsetools.pyx:1\u001B[0m, in \u001B[0;36minit _csparsetools\u001B[1;34m()\u001B[0m\n",
- "\u001B[1;31mValueError\u001B[0m: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject"
-           ]
-         }
-       ],
-       "execution_count": 12
-     },
-     {
-       "metadata": {},
-       "cell_type": "code",
-       "outputs": [],
-       "execution_count": null,
-       "source": [
-         "dataset_train, tfidf_vectorizer = preprocess_data(\n",
-         "    data=dataset_train,\n",
-         "    mode=\"train\",\n",
-         "    max_tfidf_features=8192,\n",
-         "    max_seq_length=128\n",
-         ")\n",
-         "\n",
-         "dataset_test, _ = preprocess_data(\n",
-         "    data=dataset_test,\n",
-         "    mode=\"test\",\n",
-         "    tfidf_vectorizer=tfidf_vectorizer,\n",
-         "    max_tfidf_features=8192,\n",
-         "    max_seq_length=128\n",
-         ")"
-       ],
-       "id": "b605d3b4f5ff547a"
-     },
-     {
-       "metadata": {},
-       "cell_type": "code",
-       "outputs": [],
-       "execution_count": null,
-       "source": [
-         "# Load model directly\n",
-         "from transformers import AutoModel\n",
-         "model = AutoModel.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\")"
-       ],
-       "id": "b20d11caa1d25445"
-     },
-     {
-       "metadata": {
-         "ExecuteTime": {
-           "end_time": "2024-12-06T19:53:05.824524Z",
-           "start_time": "2024-12-06T19:53:05.550141Z"
-         }
-       },
-       "cell_type": "code",
-       "source": [
-         "from torch.utils.data import DataLoader\n",
-         "import torch\n",
-         "from tqdm import tqdm\n",
-         "\n",
-         "# Define a collate function to handle the batched data\n",
-         "def collate_fn(batch):\n",
-         "    freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
-         "    seq_inputs = torch.stack([torch.tensor(item[\"seq_inputs\"]) for item in batch])\n",
-         "    pos_inputs = torch.stack([torch.tensor(item[\"pos_inputs\"]) for item in batch])\n",
-         "    labels = torch.tensor([item[\"labels\"] for item in batch])\n",
-         "    return {\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs}, labels\n",
-         "\n",
-         "batch_size = 32  # NOTE: `config` was never defined in this notebook; set the batch size explicitly\n",
-         "train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)\n",
-         "test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)\n",
-         "\n",
-         "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-         "model.to(device)\n",
-         "\n",
-         "criterion = torch.nn.BCEWithLogitsLoss()\n",
-         "\n",
-         "def evaluate_model(model, val_loader, criterion, device=\"cuda\"):\n",
-         "    model.eval()\n",
-         "    val_loss = 0.0\n",
-         "    correct = 0\n",
-         "    total = 0\n",
-         "\n",
-         "    with torch.no_grad():\n",
-         "        for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
-         "            freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
-         "            seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
-         "            pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
-         "            labels = labels[:, None].to(device)\n",
-         "\n",
-         "            preds = model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
-         "            loss = criterion(preds, labels)\n",
-         "\n",
-         "            val_loss += loss.item()\n",
-         "            total += labels.size(0)\n",
-         "            correct += ((torch.sigmoid(preds) > 0.5).float() == labels).sum().item()\n",
-         "\n",
-         "    print(f\"Test Loss: {val_loss / total:.4f}\")\n",
-         "    print(f\"Test Accuracy: {correct / total:.4f}\")\n",
-         "\n",
-         "\n",
-         "evaluate_model(model, test_loader, criterion, device=device)\n",
-         "# Save the final model in Hugging Face format\n",
-         "\n"
-       ],
-       "id": "1d23cedfe1d79660",
-       "outputs": [
-         {
-           "ename": "ModuleNotFoundError",
-           "evalue": "No module named 'torch'",
-           "output_type": "error",
-           "traceback": [
- "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
- "\u001B[1;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
- "Cell \u001B[1;32mIn[1], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdata\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m DataLoader\n\u001B[0;32m 3\u001B[0m \u001B[38;5;66;03m# Define a collate function to handle the batched data\u001B[39;00m\n\u001B[0;32m 4\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mcollate_fn\u001B[39m(batch):\n",
- "\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'torch'"
-           ]
-         }
-       ],
-       "execution_count": 1
-     },
-     {
-       "metadata": {},
-       "cell_type": "code",
-       "outputs": [],
-       "execution_count": null,
-       "source": "",
-       "id": "549f3e0a004e80ab"
-     }
-   ],
-   "metadata": {
-     "kernelspec": {
-       "display_name": "Python 3",
-       "language": "python",
-       "name": "python3"
-     },
-     "language_info": {
-       "codemirror_mode": {
-         "name": "ipython",
-         "version": 2
-       },
-       "file_extension": ".py",
-       "mimetype": "text/x-python",
-       "name": "python",
-       "nbconvert_exporter": "python",
-       "pygments_lexer": "ipython2",
-       "version": "2.7.6"
-     }
-   },
-   "nbformat": 4,
-   "nbformat_minor": 5
- }