Delete eval.py.ipynb
Browse files- eval.py.ipynb +0 -317
eval.py.ipynb
DELETED
@@ -1,317 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"id": "initial_id",
|
6 |
-
"metadata": {
|
7 |
-
"collapsed": true,
|
8 |
-
"ExecuteTime": {
|
9 |
-
"end_time": "2024-12-06T19:54:24.990141Z",
|
10 |
-
"start_time": "2024-12-06T19:53:17.183491Z"
|
11 |
-
}
|
12 |
-
},
|
13 |
-
"source": [
|
14 |
-
"!pip install geopy > delete.txt\n",
|
15 |
-
"!pip install datasets > delete.txt\n",
|
16 |
-
"!pip install torch torchvision datasets > delete.txt\n",
|
17 |
-
"!pip install huggingface_hub > delete.txt\n",
|
18 |
-
"!pip install pyhocon > delete.txt\n",
|
19 |
-
"!pip install transformers > delete.txt\n",
|
20 |
-
"!rm delete.txt"
|
21 |
-
],
|
22 |
-
"outputs": [
|
23 |
-
{
|
24 |
-
"name": "stderr",
|
25 |
-
"output_type": "stream",
|
26 |
-
"text": [
|
27 |
-
"'rm' is not recognized as an internal or external command,\n",
|
28 |
-
"operable program or batch file.\n"
|
29 |
-
]
|
30 |
-
}
|
31 |
-
],
|
32 |
-
"execution_count": 2
|
33 |
-
},
|
34 |
-
{
|
35 |
-
"metadata": {
|
36 |
-
"ExecuteTime": {
|
37 |
-
"end_time": "2024-12-06T19:56:26.136466Z",
|
38 |
-
"start_time": "2024-12-06T19:54:38.679955Z"
|
39 |
-
}
|
40 |
-
},
|
41 |
-
"cell_type": "code",
|
42 |
-
"source": "!huggingface-cli login",
|
43 |
-
"id": "b0a77c981c32a0c8",
|
44 |
-
"outputs": [
|
45 |
-
{
|
46 |
-
"name": "stdout",
|
47 |
-
"output_type": "stream",
|
48 |
-
"text": [
|
49 |
-
"^C\n"
|
50 |
-
]
|
51 |
-
}
|
52 |
-
],
|
53 |
-
"execution_count": 3
|
54 |
-
},
|
55 |
-
{
|
56 |
-
"metadata": {
|
57 |
-
"ExecuteTime": {
|
58 |
-
"end_time": "2024-12-06T19:57:30.983629Z",
|
59 |
-
"start_time": "2024-12-06T19:57:29.451887Z"
|
60 |
-
}
|
61 |
-
},
|
62 |
-
"cell_type": "code",
|
63 |
-
"source": [
|
64 |
-
"from datasets import load_dataset\n",
|
65 |
-
"\n",
|
66 |
-
"dataset_train = load_dataset(\"CISProject/FOX_NBC\", split=\"train\")\n",
|
67 |
-
"dataset_test = load_dataset(\"path/to/test\", split=\"test\")"
|
68 |
-
],
|
69 |
-
"id": "a4aa3b759defc904",
|
70 |
-
"outputs": [],
|
71 |
-
"execution_count": 5
|
72 |
-
},
|
73 |
-
{
|
74 |
-
"metadata": {
|
75 |
-
"ExecuteTime": {
|
76 |
-
"end_time": "2024-12-06T19:58:41.568459Z",
|
77 |
-
"start_time": "2024-12-06T19:58:41.445848Z"
|
78 |
-
}
|
79 |
-
},
|
80 |
-
"cell_type": "code",
|
81 |
-
"source": [
|
82 |
-
"import numpy as np\n",
|
83 |
-
"import torch\n",
|
84 |
-
"from transformers import BertTokenizer\n",
|
85 |
-
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
86 |
-
"\n",
|
87 |
-
"def positional_encoding(seq_len, d_model):\n",
|
88 |
-
" pos_enc = np.zeros((seq_len, d_model))\n",
|
89 |
-
" for pos in range(seq_len):\n",
|
90 |
-
" for i in range(0, d_model, 2):\n",
|
91 |
-
" pos_enc[pos, i] = np.sin(pos / (10000 ** ((2 * i) / d_model)))\n",
|
92 |
-
" if i + 1 < d_model:\n",
|
93 |
-
" pos_enc[pos, i + 1] = np.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))\n",
|
94 |
-
" return torch.tensor(pos_enc, dtype=torch.float)\n",
|
95 |
-
"\n",
|
96 |
-
"def preprocess_data(data, mode=\"train\", tfidf_vectorizer=None, max_tfidf_features=4096, max_seq_length=128, num_proc=4):\n",
|
97 |
-
" tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
|
98 |
-
"\n",
|
99 |
-
" # Initialize TF-IDF vectorizer if not provided\n",
|
100 |
-
" if tfidf_vectorizer is None and mode == \"train\":\n",
|
101 |
-
" tfidf_vectorizer = TfidfVectorizer(max_features=max_tfidf_features)\n",
|
102 |
-
"\n",
|
103 |
-
" # Fit TF-IDF only in train mode\n",
|
104 |
-
" if mode == \"train\":\n",
|
105 |
-
" tfidf_vectorizer.fit(data[\"title\"])\n",
|
106 |
-
" print(\"TF-IDF vectorizer fitted on training data.\")\n",
|
107 |
-
"\n",
|
108 |
-
" def process_batch(batch):\n",
|
109 |
-
" headlines = batch[\"title\"]\n",
|
110 |
-
" agencies = batch[\"news\"]\n",
|
111 |
-
"\n",
|
112 |
-
" # TF-IDF transformation (batch-wise)\n",
|
113 |
-
" if mode == \"train\" or tfidf_vectorizer is not None:\n",
|
114 |
-
" freq_inputs = tfidf_vectorizer.transform(headlines).toarray()\n",
|
115 |
-
" else:\n",
|
116 |
-
" raise ValueError(\"TF-IDF vectorizer must be provided in test mode.\")\n",
|
117 |
-
"\n",
|
118 |
-
" # Tokenization (batch-wise)\n",
|
119 |
-
" tokenized = tokenizer(\n",
|
120 |
-
" headlines,\n",
|
121 |
-
" padding=\"max_length\",\n",
|
122 |
-
" truncation=True,\n",
|
123 |
-
" max_length=max_seq_length,\n",
|
124 |
-
" return_tensors=\"pt\"\n",
|
125 |
-
" )\n",
|
126 |
-
"\n",
|
127 |
-
" # Stack input_ids and attention_mask along a new dimension\n",
|
128 |
-
" input_ids = tokenized[\"input_ids\"]\n",
|
129 |
-
" attention_mask = tokenized[\"attention_mask\"]\n",
|
130 |
-
"\n",
|
131 |
-
" # Ensure consistent stacking: (batch_size, 2, seq_len)\n",
|
132 |
-
" seq_inputs = torch.stack([input_ids, attention_mask], dim=1)\n",
|
133 |
-
"\n",
|
134 |
-
" # Positional encoding\n",
|
135 |
-
" pos_inputs = positional_encoding(max_seq_length, 512).unsqueeze(0).expand(len(headlines), -1, -1)\n",
|
136 |
-
"\n",
|
137 |
-
" # Labels\n",
|
138 |
-
" labels = [1.0 if agency == \"fox\" else 0.0 for agency in agencies]\n",
|
139 |
-
"\n",
|
140 |
-
" return {\n",
|
141 |
-
" \"freq_inputs\": torch.tensor(freq_inputs),\n",
|
142 |
-
" \"seq_inputs\": seq_inputs,\n",
|
143 |
-
" \"pos_inputs\": pos_inputs,\n",
|
144 |
-
" \"labels\": torch.tensor(labels),\n",
|
145 |
-
" }\n",
|
146 |
-
"\n",
|
147 |
-
" # Use `map` with batching and parallelism\n",
|
148 |
-
" processed_data = data.map(\n",
|
149 |
-
" process_batch,\n",
|
150 |
-
" batched=True,\n",
|
151 |
-
" batch_size=32,\n",
|
152 |
-
" num_proc=num_proc\n",
|
153 |
-
" )\n",
|
154 |
-
"\n",
|
155 |
-
" return processed_data, tfidf_vectorizer"
|
156 |
-
],
|
157 |
-
"id": "ce6e6b982e22e9fe",
|
158 |
-
"outputs": [
|
159 |
-
{
|
160 |
-
"ename": "ValueError",
|
161 |
-
"evalue": "numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject",
|
162 |
-
"output_type": "error",
|
163 |
-
"traceback": [
|
164 |
-
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
165 |
-
"\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)",
|
166 |
-
"Cell \u001B[1;32mIn[12], line 4\u001B[0m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtransformers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m BertTokenizer\n\u001B[1;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01msklearn\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfeature_extraction\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mtext\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m TfidfVectorizer\n\u001B[0;32m 6\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mpositional_encoding\u001B[39m(seq_len, d_model):\n\u001B[0;32m 7\u001B[0m pos_enc \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros((seq_len, d_model))\n",
|
167 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\__init__.py:84\u001B[0m\n\u001B[0;32m 70\u001B[0m \u001B[38;5;66;03m# We are not importing the rest of scikit-learn during the build\u001B[39;00m\n\u001B[0;32m 71\u001B[0m \u001B[38;5;66;03m# process, as it may not be compiled yet\u001B[39;00m\n\u001B[0;32m 72\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 78\u001B[0m \u001B[38;5;66;03m# later is linked to the OpenMP runtime to make it possible to introspect\u001B[39;00m\n\u001B[0;32m 79\u001B[0m \u001B[38;5;66;03m# it and importing it first would fail if the OpenMP dll cannot be found.\u001B[39;00m\n\u001B[0;32m 80\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[0;32m 81\u001B[0m __check_build, \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m 82\u001B[0m _distributor_init, \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m 83\u001B[0m )\n\u001B[1;32m---> 84\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbase\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m clone\n\u001B[0;32m 85\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_show_versions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m show_versions\n\u001B[0;32m 87\u001B[0m __all__ \u001B[38;5;241m=\u001B[39m [\n\u001B[0;32m 88\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcalibration\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 89\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcluster\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 130\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mshow_versions\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 131\u001B[0m ]\n",
|
168 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\base.py:19\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config_context, get_config\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexceptions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m InconsistentVersionWarning\n\u001B[1;32m---> 19\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_estimator_html_repr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _HTMLDocumentationLinkMixin, estimator_html_repr\n\u001B[0;32m 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_metadata_requests\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _MetadataRequester, _routing_enabled\n\u001B[0;32m 21\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_param_validation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m validate_parameter_constraints\n",
|
169 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\__init__.py:11\u001B[0m\n\u001B[0;32m 9\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _joblib, metadata_routing\n\u001B[0;32m 10\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_bunch\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Bunch\n\u001B[1;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_chunking\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m gen_batches, gen_even_slices\n\u001B[0;32m 12\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_estimator_html_repr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m estimator_html_repr\n\u001B[0;32m 14\u001B[0m \u001B[38;5;66;03m# Make _safe_indexing importable from here for backward compat as this particular\u001B[39;00m\n\u001B[0;32m 15\u001B[0m \u001B[38;5;66;03m# helper is considered semi-private and typically very useful for third-party\u001B[39;00m\n\u001B[0;32m 16\u001B[0m \u001B[38;5;66;03m# libraries that want to comply with scikit-learn's estimator API. In particular,\u001B[39;00m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;66;03m# _safe_indexing was included in our public API documentation despite the leading\u001B[39;00m\n\u001B[0;32m 18\u001B[0m \u001B[38;5;66;03m# `_` in its name.\u001B[39;00m\n",
|
170 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\_chunking.py:8\u001B[0m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 7\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m get_config\n\u001B[1;32m----> 8\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_param_validation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Interval, validate_params\n\u001B[0;32m 11\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mchunk_generator\u001B[39m(gen, chunksize):\n\u001B[0;32m 12\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"Chunk generator, ``gen`` into lists of length ``chunksize``. The last\u001B[39;00m\n\u001B[0;32m 13\u001B[0m \u001B[38;5;124;03m chunk may have a length less than ``chunksize``.\"\"\"\u001B[39;00m\n",
|
171 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\utils\\_param_validation.py:11\u001B[0m\n\u001B[0;32m 8\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mnumbers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Integral, Real\n\u001B[0;32m 10\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mscipy\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msparse\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m csr_matrix, issparse\n\u001B[0;32m 13\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_config\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config_context, get_config\n\u001B[0;32m 14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mvalidation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _is_arraylike_not_scalar\n",
|
172 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\scipy\\sparse\\__init__.py:297\u001B[0m\n\u001B[0;32m 295\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_csr\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m 296\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_csc\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[1;32m--> 297\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_lil\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m 298\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_dok\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n\u001B[0;32m 299\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_coo\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n",
|
173 |
-
"File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\scipy\\sparse\\_lil.py:17\u001B[0m\n\u001B[0;32m 14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_index\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m IndexMixin, INT_TYPES, _broadcast_arrays\n\u001B[0;32m 15\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_sputils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (getdtype, isshape, isscalarlike, upcast_scalar,\n\u001B[0;32m 16\u001B[0m check_shape, check_reshape_kwargs)\n\u001B[1;32m---> 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _csparsetools\n\u001B[0;32m 20\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01m_lil_base\u001B[39;00m(_spbase, IndexMixin):\n\u001B[0;32m 21\u001B[0m _format \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlil\u001B[39m\u001B[38;5;124m'\u001B[39m\n",
|
174 |
-
"File \u001B[1;32mscipy\\\\sparse\\\\_csparsetools.pyx:1\u001B[0m, in \u001B[0;36minit _csparsetools\u001B[1;34m()\u001B[0m\n",
|
175 |
-
"\u001B[1;31mValueError\u001B[0m: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject"
|
176 |
-
]
|
177 |
-
}
|
178 |
-
],
|
179 |
-
"execution_count": 12
|
180 |
-
},
|
181 |
-
{
|
182 |
-
"metadata": {},
|
183 |
-
"cell_type": "code",
|
184 |
-
"outputs": [],
|
185 |
-
"execution_count": null,
|
186 |
-
"source": [
|
187 |
-
"dataset_train, tfidf_vectorizer = preprocess_data(\n",
|
188 |
-
" data=dataset_train,\n",
|
189 |
-
" mode=\"train\",\n",
|
190 |
-
" max_tfidf_features=8192,\n",
|
191 |
-
" max_seq_length=128\n",
|
192 |
-
")\n",
|
193 |
-
"\n",
|
194 |
-
"dataset_test, _ = preprocess_data(\n",
|
195 |
-
" data=dataset_test,\n",
|
196 |
-
" mode=\"test\",\n",
|
197 |
-
" tfidf_vectorizer=tfidf_vectorizer,\n",
|
198 |
-
" max_tfidf_features=8192,\n",
|
199 |
-
" max_seq_length=128\n",
|
200 |
-
")"
|
201 |
-
],
|
202 |
-
"id": "b605d3b4f5ff547a"
|
203 |
-
},
|
204 |
-
{
|
205 |
-
"metadata": {},
|
206 |
-
"cell_type": "code",
|
207 |
-
"outputs": [],
|
208 |
-
"execution_count": null,
|
209 |
-
"source": [
|
210 |
-
"# Load model directly\n",
|
211 |
-
"from transformers import AutoModel\n",
|
212 |
-
"model = AutoModel.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\")"
|
213 |
-
],
|
214 |
-
"id": "b20d11caa1d25445"
|
215 |
-
},
|
216 |
-
{
|
217 |
-
"metadata": {
|
218 |
-
"ExecuteTime": {
|
219 |
-
"end_time": "2024-12-06T19:53:05.824524Z",
|
220 |
-
"start_time": "2024-12-06T19:53:05.550141Z"
|
221 |
-
}
|
222 |
-
},
|
223 |
-
"cell_type": "code",
|
224 |
-
"source": [
|
225 |
-
"from torch.utils.data import DataLoader\n",
|
226 |
-
"\n",
|
227 |
-
"# Define a collate function to handle the batched data\n",
|
228 |
-
"def collate_fn(batch):\n",
|
229 |
-
" freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
|
230 |
-
" seq_inputs = torch.stack([torch.tensor(item[\"seq_inputs\"]) for item in batch])\n",
|
231 |
-
" pos_inputs = torch.stack([torch.tensor(item[\"pos_inputs\"]) for item in batch])\n",
|
232 |
-
" labels = torch.tensor([torch.tensor(item[\"labels\"]) for item in batch])\n",
|
233 |
-
" return {\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs}, labels\n",
|
234 |
-
"\n",
|
235 |
-
"train_loader = DataLoader(dataset_train, batch_size=config.train[\"batch_size\"], shuffle=True,collate_fn=collate_fn)\n",
|
236 |
-
"test_loader = DataLoader(dataset_test, batch_size=config.train[\"batch_size\"], shuffle=False,collate_fn=collate_fn)\n",
|
237 |
-
"\n",
|
238 |
-
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
239 |
-
"model.to(device)\n",
|
240 |
-
"\n",
|
241 |
-
"criterion = torch.nn.BCEWithLogitsLoss()\n",
|
242 |
-
"\n",
|
243 |
-
"def evaluate_model(model, val_loader, criterion, device=\"cuda\"):\n",
|
244 |
-
" model.eval()\n",
|
245 |
-
" val_loss = 0.0\n",
|
246 |
-
" correct = 0\n",
|
247 |
-
" total = 0\n",
|
248 |
-
"\n",
|
249 |
-
" with torch.no_grad():\n",
|
250 |
-
" for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
|
251 |
-
" freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
|
252 |
-
" seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
|
253 |
-
" pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
|
254 |
-
" labels = labels[:,None].to(device)\n",
|
255 |
-
"\n",
|
256 |
-
" preds = model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
|
257 |
-
" loss = criterion(preds, labels)\n",
|
258 |
-
"\n",
|
259 |
-
" val_loss += loss.item()\n",
|
260 |
-
" total += labels.size(0)\n",
|
261 |
-
" correct += ((torch.sigmoid(preds) > 0.5).float() == labels).sum().item()\n",
|
262 |
-
"\n",
|
263 |
-
" print(f\"Test Loss: {val_loss / total:.4f}\")\n",
|
264 |
-
" print(f\"Test Accuracy: {correct / total:.4f}\")\n",
|
265 |
-
"\n",
|
266 |
-
"\n",
|
267 |
-
"evaluate_model(model, test_loader, criterion)\n",
|
268 |
-
"# Save the final model in Hugging Face format\n",
|
269 |
-
"\n"
|
270 |
-
],
|
271 |
-
"id": "1d23cedfe1d79660",
|
272 |
-
"outputs": [
|
273 |
-
{
|
274 |
-
"ename": "ModuleNotFoundError",
|
275 |
-
"evalue": "No module named 'torch'",
|
276 |
-
"output_type": "error",
|
277 |
-
"traceback": [
|
278 |
-
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
279 |
-
"\u001B[1;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
|
280 |
-
"Cell \u001B[1;32mIn[1], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdata\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m DataLoader\n\u001B[0;32m 3\u001B[0m \u001B[38;5;66;03m# Define a collate function to handle the batched data\u001B[39;00m\n\u001B[0;32m 4\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mcollate_fn\u001B[39m(batch):\n",
|
281 |
-
"\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'torch'"
|
282 |
-
]
|
283 |
-
}
|
284 |
-
],
|
285 |
-
"execution_count": 1
|
286 |
-
},
|
287 |
-
{
|
288 |
-
"metadata": {},
|
289 |
-
"cell_type": "code",
|
290 |
-
"outputs": [],
|
291 |
-
"execution_count": null,
|
292 |
-
"source": "",
|
293 |
-
"id": "549f3e0a004e80ab"
|
294 |
-
}
|
295 |
-
],
|
296 |
-
"metadata": {
|
297 |
-
"kernelspec": {
|
298 |
-
"display_name": "Python 3",
|
299 |
-
"language": "python",
|
300 |
-
"name": "python3"
|
301 |
-
},
|
302 |
-
"language_info": {
|
303 |
-
"codemirror_mode": {
|
304 |
-
"name": "ipython",
|
305 |
-
"version": 2
|
306 |
-
},
|
307 |
-
"file_extension": ".py",
|
308 |
-
"mimetype": "text/x-python",
|
309 |
-
"name": "python",
|
310 |
-
"nbconvert_exporter": "python",
|
311 |
-
"pygments_lexer": "ipython2",
|
312 |
-
"version": "2.7.6"
|
313 |
-
}
|
314 |
-
},
|
315 |
-
"nbformat": 4,
|
316 |
-
"nbformat_minor": 5
|
317 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|