{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-12-06T19:54:24.990141Z",
     "start_time": "2024-12-06T19:53:17.183491Z"
    }
   },
   "source": [
    "!pip install geopy > delete.txt\n",
    "!pip install datasets > delete.txt\n",
    "!pip install torch torchvision datasets > delete.txt\n",
    "!pip install huggingface_hub > delete.txt\n",
    "!pip install pyhocon > delete.txt\n",
    "!pip install transformers > delete.txt\n",
    "!rm delete.txt"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'rm' is not recognized as an internal or external command,\n",
      "operable program or batch file.\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-06T19:56:26.136466Z",
     "start_time": "2024-12-06T19:54:38.679955Z"
    }
   },
   "cell_type": "code",
   "source": "!huggingface-cli login",
   "id": "b0a77c981c32a0c8",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "^C\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-06T19:57:30.983629Z",
     "start_time": "2024-12-06T19:57:29.451887Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset_train = load_dataset(\"CISProject/FOX_NBC\", split=\"train\")\n",
    "dataset_test = load_dataset(\"path/to/test\", split=\"test\")"
   ],
   "id": "a4aa3b759defc904",
   "outputs": [],
   "execution_count": 5
  },
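  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A quick sanity check on the loaded split. This is a sketch and assumes the dataset exposes the `title` and `news` columns that the preprocessing code below relies on."
   ],
   "id": "dataset_peek_md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Peek at the schema and the first example; `title` and `news` are assumed column names\n",
    "print(dataset_train)\n",
    "print(dataset_train[0])"
   ],
   "id": "dataset_peek_code"
  },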
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-06T19:58:41.568459Z",
     "start_time": "2024-12-06T19:58:41.445848Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from transformers import BertTokenizer\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "def positional_encoding(seq_len, d_model):\n",
    "    pos_enc = np.zeros((seq_len, d_model))\n",
    "    for pos in range(seq_len):\n",
    "        for i in range(0, d_model, 2):\n",
    "            pos_enc[pos, i] = np.sin(pos / (10000 ** ((2 * i) / d_model)))\n",
    "            if i + 1 < d_model:\n",
    "                pos_enc[pos, i + 1] = np.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))\n",
    "    return torch.tensor(pos_enc, dtype=torch.float)\n",
    "\n",
    "def preprocess_data(data, mode=\"train\", tfidf_vectorizer=None, max_tfidf_features=4096, max_seq_length=128, num_proc=4):\n",
    "    tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "    # Initialize TF-IDF vectorizer if not provided\n",
    "    if tfidf_vectorizer is None and mode == \"train\":\n",
    "        tfidf_vectorizer = TfidfVectorizer(max_features=max_tfidf_features)\n",
    "\n",
    "    # Fit TF-IDF only in train mode\n",
    "    if mode == \"train\":\n",
    "        tfidf_vectorizer.fit(data[\"title\"])\n",
    "        print(\"TF-IDF vectorizer fitted on training data.\")\n",
    "\n",
    "    def process_batch(batch):\n",
    "        headlines = batch[\"title\"]\n",
    "        agencies = batch[\"news\"]\n",
    "\n",
    "        # TF-IDF transformation (batch-wise)\n",
    "        if mode == \"train\" or tfidf_vectorizer is not None:\n",
    "            freq_inputs = tfidf_vectorizer.transform(headlines).toarray()\n",
    "        else:\n",
    "            raise ValueError(\"TF-IDF vectorizer must be provided in test mode.\")\n",
    "\n",
    "        # Tokenization (batch-wise)\n",
    "        tokenized = tokenizer(\n",
    "            headlines,\n",
    "            padding=\"max_length\",\n",
    "            truncation=True,\n",
    "            max_length=max_seq_length,\n",
    "            return_tensors=\"pt\"\n",
    "        )\n",
    "\n",
    "        # Stack input_ids and attention_mask along a new dimension\n",
    "        input_ids = tokenized[\"input_ids\"]\n",
    "        attention_mask = tokenized[\"attention_mask\"]\n",
    "\n",
    "        # Ensure consistent stacking: (batch_size, 2, seq_len)\n",
    "        seq_inputs = torch.stack([input_ids, attention_mask], dim=1)\n",
    "\n",
    "        # Positional encoding\n",
    "        pos_inputs = positional_encoding(max_seq_length, 512).unsqueeze(0).expand(len(headlines), -1, -1)\n",
    "\n",
    "        # Labels\n",
    "        labels = [1.0 if agency == \"fox\" else 0.0 for agency in agencies]\n",
    "\n",
    "        return {\n",
    "            \"freq_inputs\": torch.tensor(freq_inputs),\n",
    "            \"seq_inputs\": seq_inputs,\n",
    "            \"pos_inputs\": pos_inputs,\n",
    "            \"labels\": torch.tensor(labels),\n",
    "        }\n",
    "\n",
    "    # Use `map` with batching and parallelism\n",
    "    processed_data = data.map(\n",
    "        process_batch,\n",
    "        batched=True,\n",
    "        batch_size=32,\n",
    "        num_proc=num_proc\n",
    "    )\n",
    "\n",
    "    return processed_data, tfidf_vectorizer"
   ],
   "id": "ce6e6b982e22e9fe",
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mValueError\u001B[0m                                Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[12], line 4\u001B[0m\n\u001B[0;32m      2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m      3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtransformers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m BertTokenizer\n\u001B[1;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01msklearn\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfeature_extraction\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mtext\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m TfidfVectorizer\n\u001B[0;32m      6\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mpositional_encoding\u001B[39m(seq_len, d_model):\n\u001B[0;32m      7\u001B[0m     pos_enc \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros((seq_len, d_model))\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\__init__.py:84\u001B[0m\n\u001B[0;32m     70\u001B[0m     \u001B[38;5;66;03m# We are not importing the rest of scikit-learn during the build\u001B[39;00m\n\u001B[0;32m     71\u001B[0m     \u001B[38;5;66;03m# process, as it may not be compiled yet\u001B[39;00m\n\u001B[0;32m     72\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m     78\u001B[0m     \u001B[38;5;66;03m# later is linked to the OpenMP runtime to make it possible to introspect\u001B[39;00m\n\u001B[0;32m     79\u001B[0m     \u001B[38;5;66;03m# it and importing it first would fail if the OpenMP dll cannot be found.\u001B[39;00m\n\u001B[0;32m     80\u001B[0m     \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[0;32m     81\u001B[0m         __check_build,  \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m     82\u001B[0m         _distributor_init,  \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m     83\u001B[0m     )\n\u001B[1;32m---> 84\u001B[0m     \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbase\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m clone\n\u001B[0;32m     85\u001B[0m     \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_show_versions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m show_versions\n\u001B[0;32m     87\u001B[0m     __all__ \u001B[38;5;241m=\u001B[39m [\n\u001B[0;32m     88\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcalibration\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m     89\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcluster\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m    130\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mshow_versions\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m    131\u001B[0m     ]\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\base.py:19\u001B[0m\n\u001B[0;32m     17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config_context, get_config\n\u001B[0;32m     18\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexceptions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m InconsistentVersionWarning\n\u001B[1;32m---> 19\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_estimator_html_repr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _HTMLDocumentationLinkMixin, estimator_html_repr\n\u001B[0;32m     20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_metadata_requests\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _MetadataRequester, _routing_enabled\n\u001B[0;32m     21\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_param_validation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m validate_parameter_constraints\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\__init__.py:11\u001B[0m\n\u001B[0;32m      9\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _joblib, metadata_routing\n\u001B[0;32m     10\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_bunch\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Bunch\n\u001B[1;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_chunking\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m gen_batches, gen_even_slices\n\u001B[0;32m     12\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_estimator_html_repr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m estimator_html_repr\n\u001B[0;32m     14\u001B[0m \u001B[38;5;66;03m# Make _safe_indexing importable from here for backward compat as this particular\u001B[39;00m\n\u001B[0;32m     15\u001B[0m \u001B[38;5;66;03m# helper is considered semi-private and typically very useful for third-party\u001B[39;00m\n\u001B[0;32m     16\u001B[0m \u001B[38;5;66;03m# libraries that want to comply with scikit-learn's estimator API. In particular,\u001B[39;00m\n\u001B[0;32m     17\u001B[0m \u001B[38;5;66;03m# _safe_indexing was included in our public API documentation despite the leading\u001B[39;00m\n\u001B[0;32m     18\u001B[0m \u001B[38;5;66;03m# `_` in its name.\u001B[39;00m\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\_chunking.py:8\u001B[0m\n\u001B[0;32m      5\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m      7\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m get_config\n\u001B[1;32m----> 8\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_param_validation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Interval, validate_params\n\u001B[0;32m     11\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mchunk_generator\u001B[39m(gen, chunksize):\n\u001B[0;32m     12\u001B[0m \u001B[38;5;250m    \u001B[39m\u001B[38;5;124;03m\"\"\"Chunk generator, ``gen`` into lists of length ``chunksize``. The last\u001B[39;00m\n\u001B[0;32m     13\u001B[0m \u001B[38;5;124;03m    chunk may have a length less than ``chunksize``.\"\"\"\u001B[39;00m\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\_param_validation.py:11\u001B[0m\n\u001B[0;32m      8\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mnumbers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Integral, Real\n\u001B[0;32m     10\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mscipy\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msparse\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m csr_matrix, issparse\n\u001B[0;32m     13\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config_context, get_config\n\u001B[0;32m     14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mvalidation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _is_arraylike_not_scalar\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\scipy\\sparse\\__init__.py:297\u001B[0m\n\u001B[0;32m    295\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_csr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m    296\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_csc\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[1;32m--> 297\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_lil\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m    298\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_dok\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m    299\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_coo\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\scipy\\sparse\\_lil.py:17\u001B[0m\n\u001B[0;32m     14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_index\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m IndexMixin, INT_TYPES, _broadcast_arrays\n\u001B[0;32m     15\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_sputils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (getdtype, isshape, isscalarlike, upcast_scalar,\n\u001B[0;32m     16\u001B[0m                        check_shape, check_reshape_kwargs)\n\u001B[1;32m---> 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _csparsetools\n\u001B[0;32m     20\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01m_lil_base\u001B[39;00m(_spbase, IndexMixin):\n\u001B[0;32m     21\u001B[0m     _format \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlil\u001B[39m\u001B[38;5;124m'\u001B[39m\n",
      "File \u001B[1;32mscipy\\\\sparse\\\\_csparsetools.pyx:1\u001B[0m, in \u001B[0;36minit _csparsetools\u001B[1;34m()\u001B[0m\n",
      "\u001B[1;31mValueError\u001B[0m: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject"
     ]
    }
   ],
   "execution_count": 12
  },
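  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A minimal shape check for the encoding helper, assuming the `max_seq_length=128` / `d_model=512` values used above. Since sine and cosine stay in [-1, 1], this is a cheap way to catch formula regressions."
   ],
   "id": "pe_check_md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Sanity-check the positional encoding: shape and value range\n",
    "pe = positional_encoding(128, 512)\n",
    "assert pe.shape == (128, 512)\n",
    "assert pe.abs().max() <= 1.0  # sin/cos values stay in [-1, 1]\n",
    "print(pe.shape)"
   ],
   "id": "pe_check_code"
  },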
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "dataset_train, tfidf_vectorizer = preprocess_data(\n",
    "    data=dataset_train,\n",
    "    mode=\"train\",\n",
    "    max_tfidf_features=8192,\n",
    "    max_seq_length=128\n",
    ")\n",
    "\n",
    "dataset_test, _ = preprocess_data(\n",
    "    data=dataset_test,\n",
    "    mode=\"test\",\n",
    "    tfidf_vectorizer=tfidf_vectorizer,\n",
    "    max_tfidf_features=8192,\n",
    "    max_seq_length=128\n",
    ")"
   ],
   "id": "b605d3b4f5ff547a"
  },
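  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A quick check that preprocessing produced the expected feature shapes; the keys mirror the dictionary returned by `process_batch` above, and the sizes assume the `max_tfidf_features=8192` / `max_seq_length=128` arguments just used."
   ],
   "id": "prep_check_md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# One processed example: TF-IDF vector, stacked (input_ids, attention_mask), positional encoding\n",
    "ex = dataset_train[0]\n",
    "print(len(ex[\"freq_inputs\"]), len(ex[\"seq_inputs\"]), len(ex[\"pos_inputs\"]), ex[\"labels\"])"
   ],
   "id": "prep_check_code"
  },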
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load model directly\n",
    "from transformers import AutoModel\n",
    "model = AutoModel.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\")"
   ],
   "id": "b20d11caa1d25445"
  },
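  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A rough size check on the loaded model. This assumes `from_pretrained` returned a torch `nn.Module`, as it does for standard transformers models."
   ],
   "id": "model_size_md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Count trainable parameters as a quick sanity check\n",
    "print(f\"{sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters\")"
   ],
   "id": "model_size_code"
  },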
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-06T19:53:05.824524Z",
     "start_time": "2024-12-06T19:53:05.550141Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Define a collate function to handle the batched data\n",
    "def collate_fn(batch):\n",
    "    freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
    "    seq_inputs = torch.stack([torch.tensor(item[\"seq_inputs\"]) for item in batch])\n",
    "    pos_inputs = torch.stack([torch.tensor(item[\"pos_inputs\"]) for item in batch])\n",
    "    labels = torch.tensor([torch.tensor(item[\"labels\"]) for item in batch])\n",
    "    return {\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs}, labels\n",
    "\n",
    "train_loader = DataLoader(dataset_train, batch_size=config.train[\"batch_size\"], shuffle=True,collate_fn=collate_fn)\n",
    "test_loader = DataLoader(dataset_test, batch_size=config.train[\"batch_size\"], shuffle=False,collate_fn=collate_fn)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "criterion = torch.nn.BCEWithLogitsLoss()\n",
    "\n",
    "def evaluate_model(model, val_loader, criterion, device=\"cuda\"):\n",
    "    model.eval()\n",
    "    val_loss = 0.0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
    "            freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
    "            seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
    "            pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
    "            labels = labels[:,None].to(device)\n",
    "\n",
    "            preds = model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
    "            loss = criterion(preds, labels)\n",
    "\n",
    "            val_loss += loss.item()\n",
    "            total += labels.size(0)\n",
    "            correct += ((torch.sigmoid(preds) > 0.5).float() == labels).sum().item()\n",
    "\n",
    "    print(f\"Test Loss: {val_loss / total:.4f}\")\n",
    "    print(f\"Test Accuracy: {correct / total:.4f}\")\n",
    "\n",
    "\n",
    "evaluate_model(model, test_loader, criterion)\n",
    "# Save the final model in Hugging Face format\n",
    "\n"
   ],
   "id": "1d23cedfe1d79660",
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'torch'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mModuleNotFoundError\u001B[0m                       Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[1], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdata\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m DataLoader\n\u001B[0;32m      3\u001B[0m \u001B[38;5;66;03m# Define a collate function to handle the batched data\u001B[39;00m\n\u001B[0;32m      4\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mcollate_fn\u001B[39m(batch):\n",
      "\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'torch'"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "549f3e0a004e80ab"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}