Valeriy Sinyukov committed on
Commit da67e9c · 1 Parent(s): 82ec9f7

Add ipynb for test

Files changed (1)
  1. category_classification/test.ipynb +225 -0
category_classification/test.ipynb ADDED
@@ -0,0 +1,225 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import math\n",
+    "from pathlib import Path\n",
+    "\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from datasets import Dataset\n",
+    "from sklearn.metrics import f1_score, accuracy_score, log_loss\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from models.models import language_to_models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "en = \"en\"\n",
+    "ru = \"ru\"\n",
+    "datasets_dir = Path(\"datasets\")\n",
+    "test_filename = \"arxiv_test\"\n",
+    "test_dataset_filename = {\n",
+    "    en: datasets_dir / en / test_filename,\n",
+    "    ru: datasets_dir / ru / test_filename,\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_datasets = {}\n",
+    "for lang in (en, ru):\n",
+    "    csv_file = str(test_dataset_filename[lang]) + \".csv\"\n",
+    "    json_file = str(test_dataset_filename[lang]) + \".json\"\n",
+    "    if Path(csv_file).exists():\n",
+    "        test_datasets[lang] = pd.read_csv(csv_file)\n",
+    "    else:\n",
+    "        test_datasets[lang] = pd.read_json(json_file, lines=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_results_filename = Path(\"test_results.json\")\n",
+    "if test_results_filename.exists():\n",
+    "    with open(test_results_filename, \"r\") as f:\n",
+    "        test_results = json.load(f)\n",
+    "else:\n",
+    "    test_results = {}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def pred_to_1d(pred):\n",
+    "    return pred.idxmax(axis=1)\n",
+    "\n",
+    "\n",
+    "def true_to_nd(true, columns):\n",
+    "    columns = list(columns)\n",
+    "    true_arr = np.zeros((len(true), len(columns)))\n",
+    "    column_numbers = true.apply(lambda label: columns.index(label)).to_numpy()\n",
+    "    one_inds = np.column_stack((np.arange(len(true)), column_numbers))\n",
+    "    true_arr[one_inds[:, 0], one_inds[:, 1]] = 1\n",
+    "    true = pd.DataFrame(true_arr, columns=columns)\n",
+    "    return true\n",
+    "\n",
+    "\n",
+    "def accuracy(pred, true):\n",
+    "    return accuracy_score(true, pred_to_1d(pred))\n",
+    "\n",
+    "\n",
+    "def f1(pred, true):\n",
+    "    return f1_score(true, pred_to_1d(pred), average=\"macro\")\n",
+    "\n",
+    "\n",
+    "def cross_entropy(pred, true):\n",
+    "    pred = pd.DataFrame(\n",
+    "        pred.to_numpy() / pred.sum(axis=1).to_numpy()[:, None], columns=pred.columns\n",
+    "    )\n",
+    "    return log_loss(true_to_nd(true, pred.columns), pred)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metrics = {\"Macro F1\": f1, \"Accuracy\": accuracy, \"Cross-entropy loss\": cross_entropy}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "predications_dir = Path(\"pred\")\n",
+    "predications_dir.mkdir(exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def canonicalize_label(label):\n",
+    "    if \".\" in label:\n",
+    "        return label[: label.index(\".\")]\n",
+    "    return label\n",
+    "\n",
+    "\n",
+    "def predict(model_name, model, dataset: pd.DataFrame, batch_size=32, first: int = 3000):\n",
+    "    label = \"category\"\n",
+    "    all_labels = list(dataset[label].unique())\n",
+    "    if first is not None:\n",
+    "        dataset = dataset[:first]\n",
+    "    true = dataset[label]\n",
+    "    prediction_file_path = predications_dir / (model_name + \".csv\")\n",
+    "    dataset_size = len(dataset)\n",
+    "    if not prediction_file_path.exists():\n",
+    "        preds = []\n",
+    "        for i in tqdm(\n",
+    "            range(0, dataset_size + batch_size, batch_size),\n",
+    "            desc=f\"Predicting using {model_name}\",\n",
+    "            total=math.ceil(dataset_size / batch_size),\n",
+    "            unit=\"batch\",\n",
+    "        ):\n",
+    "            data = dataset.iloc[i : i + batch_size]\n",
+    "            if data.empty:\n",
+    "                break\n",
+    "            data = Dataset.from_pandas(data)\n",
+    "            batch_pred = model(data)\n",
+    "            batch_pred_canonicalised = []\n",
+    "            for paper_pred in batch_pred:\n",
+    "                labels_dict = {}\n",
+    "                for label_score in paper_pred:\n",
+    "                    label = canonicalize_label(label_score[\"label\"])\n",
+    "                    if label not in all_labels:\n",
+    "                        return None, None\n",
+    "                    labels_dict[label] = label_score[\"score\"]\n",
+    "                batch_pred_canonicalised.append(labels_dict)\n",
+    "            preds.extend(batch_pred_canonicalised)\n",
+    "    else:\n",
+    "        preds = pd.read_csv(prediction_file_path, index_col=0)\n",
+    "    preds = pd.DataFrame(preds).fillna(0)\n",
+    "    for label in all_labels:\n",
+    "        if label not in preds.columns:\n",
+    "            preds[label] = 0\n",
+    "    preds = preds.reindex(sorted(preds.columns), axis=1)\n",
+    "    if not prediction_file_path.exists():\n",
+    "        preds.to_csv(prediction_file_path)\n",
+    "    return preds, true\n",
+    "\n",
+    "\n",
+    "for lang, name_get_model in language_to_models.items():\n",
+    "    lang_results = test_results.setdefault(lang, {})\n",
+    "    for metric_name, metric in metrics.items():\n",
+    "        metrics_results = lang_results.setdefault(metric_name, {})\n",
+    "        for model_name, get_model in name_get_model.items():\n",
+    "            model_name = model_name.replace(\"/\", \".\")\n",
+    "            if model_name not in metrics_results:\n",
+    "                test_size = 3000 if en == lang else 500\n",
+    "                pred, true = predict(model_name, get_model(), test_datasets[lang], first=test_size)\n",
+    "                if pred is None:\n",
+    "                    print(f\"{model_name} does not produce labels that we can estimate\")\n",
+    "                    continue\n",
+    "                metrics_results[model_name] = metric(pred, true)\n",
+    "                print(f\"{metric_name} for {model_name} = {metrics_results[model_name]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(test_results_filename, \"w\") as f:\n",
+    "    json.dump(test_results, f)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}