Add translate_dataset.ipynb

#1
Files changed (1)
  1. translate_dataset.ipynb +326 -0
translate_dataset.ipynb ADDED
@@ -0,0 +1,326 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2025-04-09T09:04:50.582374Z",
+ "iopub.status.busy": "2025-04-09T09:04:50.581446Z",
+ "iopub.status.idle": "2025-04-09T09:04:54.831276Z",
+ "shell.execute_reply": "2025-04-09T09:04:54.829937Z",
+ "shell.execute_reply.started": "2025-04-09T09:04:50.582330Z"
+ },
+ "id": "POBbLwluCMeK",
+ "outputId": "9589beb5-86c8-4b44-d9bd-cc3316c838c9"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install kagglehub\n",
+ "%pip install sacremoses"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-09T09:04:54.834196Z",
+ "iopub.status.busy": "2025-04-09T09:04:54.833289Z",
+ "iopub.status.idle": "2025-04-09T09:04:58.835896Z",
+ "shell.execute_reply": "2025-04-09T09:04:58.834641Z",
+ "shell.execute_reply.started": "2025-04-09T09:04:54.834135Z"
+ },
+ "id": "BwJ36n6vZUB2",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import signal\n",
+ "from pathlib import Path\n",
+ "\n",
+ "from transformers import pipeline\n",
+ "from tqdm import tqdm\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import kagglehub"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-09T09:04:58.838507Z",
+ "iopub.status.busy": "2025-04-09T09:04:58.837160Z",
+ "iopub.status.idle": "2025-04-09T09:04:58.856737Z",
+ "shell.execute_reply": "2025-04-09T09:04:58.855801Z",
+ "shell.execute_reply.started": "2025-04-09T09:04:58.838466Z"
+ },
+ "id": "cOIT5Hu5FdT2"
+ },
+ "outputs": [],
+ "source": [
+ "class GracefulExiter:\n",
+ "    # catches SIGINT/SIGTERM (e.g. a keyboard interrupt) so the current\n",
+ "    # chunk can finish and progress can be saved before exiting\n",
+ "    def __init__(self):\n",
+ "        self.should_exit = False\n",
+ "        signal.signal(signal.SIGINT, self.exit_gracefully)\n",
+ "        signal.signal(signal.SIGTERM, self.exit_gracefully)\n",
+ "\n",
+ "    def exit_gracefully(self, signum, frame):\n",
+ "        print(\n",
+ "            \"\\nReceived interrupt signal. Finishing current work and saving progress...\"\n",
+ "        )\n",
+ "        self.should_exit = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-09T09:04:58.859897Z",
+ "iopub.status.busy": "2025-04-09T09:04:58.858860Z",
+ "iopub.status.idle": "2025-04-09T09:04:58.886712Z",
+ "shell.execute_reply": "2025-04-09T09:04:58.885792Z",
+ "shell.execute_reply.started": "2025-04-09T09:04:58.859858Z"
+ },
+ "id": "Fg9c5cFZZyoG"
+ },
+ "outputs": [],
+ "source": [
+ "def get_dataset():\n",
+ "    # download the latest version of the arXiv metadata dataset\n",
+ "    path = kagglehub.dataset_download(\"Cornell-University/arxiv\")\n",
+ "\n",
+ "    print(\"Path to dataset files:\", path)\n",
+ "\n",
+ "    file_name = os.listdir(path)[0]\n",
+ "    path_to_dataset = Path(path) / file_name\n",
+ "    data = pd.read_json(path_to_dataset, lines=True)\n",
+ "\n",
+ "    # keep only the first listed category, reduced to its top-level prefix\n",
+ "    # (e.g. \"cs.CL stat.ML\" -> \"cs\")\n",
+ "    data[\"categories\"] = [category.split()[0] for category in data[\"categories\"]]\n",
+ "    data[\"categories\"] = [category.split(\".\")[0] for category in data[\"categories\"]]\n",
+ "\n",
+ "    # interleave records round-robin by category, so any prefix of the\n",
+ "    # resulting dataframe is roughly balanced across categories\n",
+ "    counts = data.groupby(by=\"categories\")[\"title\"].count().sort_index()\n",
+ "    unique_categories = counts.index.to_list()\n",
+ "\n",
+ "    groups_same_category = {\n",
+ "        category: data[data[\"categories\"] == category] for category in unique_categories\n",
+ "    }\n",
+ "\n",
+ "    max_group_size = counts.max()\n",
+ "\n",
+ "    new_df = []\n",
+ "\n",
+ "    for i in range(max_group_size):\n",
+ "        for category in unique_categories:\n",
+ "            if i < len(groups_same_category[category]):\n",
+ "                new_df.append(groups_same_category[category].iloc[i])\n",
+ "\n",
+ "    result_df = pd.DataFrame(new_df).reset_index(drop=True)\n",
+ "    return result_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-09T09:04:58.889441Z",
+ "iopub.status.busy": "2025-04-09T09:04:58.887873Z",
+ "iopub.status.idle": "2025-04-09T09:04:58.910755Z",
+ "shell.execute_reply": "2025-04-09T09:04:58.909796Z",
+ "shell.execute_reply.started": "2025-04-09T09:04:58.889390Z"
+ },
+ "id": "RqdjPXAk1dyg",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def translate_dataset(\n",
+ "    starting_from=0,\n",
+ "    count=1000,\n",
+ "    batch_size=16,\n",
+ "    save_interval=64,\n",
+ "    dataset=None,\n",
+ "    use_google_drive=False,\n",
+ "):\n",
+ "    # if a dataset is given, the function uses it;\n",
+ "    # otherwise it downloads the dataset itself\n",
+ "\n",
+ "    # use_google_drive saves the files to your Google Drive on Colab,\n",
+ "    # in case the session ends before you can download everything\n",
+ "    if use_google_drive:\n",
+ "        from google.colab import drive\n",
+ "\n",
+ "        drive.mount(\"/content/drive\")\n",
+ "        target_folder = Path(\"/content/drive/MyDrive/arxiv_translations\")\n",
+ "    else:\n",
+ "        target_folder = Path(\"dataset_parts\")\n",
+ "    target_folder.mkdir(parents=True, exist_ok=True)\n",
+ "\n",
+ "    # to catch keyboard interrupts\n",
+ "    exiter = GracefulExiter()\n",
+ "\n",
+ "    result_df = dataset.copy() if dataset is not None else get_dataset()\n",
+ "\n",
+ "    # download the model\n",
+ "    translator = pipeline(\n",
+ "        \"translation_en_to_ru\",\n",
+ "        model=\"Helsinki-NLP/opus-mt-en-ru\",\n",
+ "        device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
+ "        torch_dtype=\"auto\",\n",
+ "    )\n",
+ "\n",
+ "    def clean_text(text, max_length=512):\n",
+ "        if pd.isna(text) or str(text).strip() == \"\":\n",
+ "            return \"[EMPTY]\"\n",
+ "        return str(text).strip()[:max_length]\n",
+ "\n",
+ "    def translate_batch(texts, batch_size=batch_size, max_length=512):\n",
+ "        results = []\n",
+ "        texts = [clean_text(text, max_length) for text in texts]\n",
+ "        try:\n",
+ "            for out in tqdm(\n",
+ "                translator(texts, max_length=max_length, batch_size=batch_size),\n",
+ "                total=len(texts),\n",
+ "                desc=\"Translating...\",\n",
+ "            ):\n",
+ "                # the pipeline returns dicts of the form {\"translation_text\": ...}\n",
+ "                results.append(out[\"translation_text\"])\n",
+ "        except Exception as e:\n",
+ "            print(f\"Error: {e}\")\n",
+ "        return results\n",
+ "\n",
+ "    # take the requested interval\n",
+ "    part_df = result_df.iloc[starting_from : starting_from + count]\n",
+ "\n",
+ "    russian_data = pd.DataFrame(columns=[\"authors\", \"title\", \"abstract\", \"categories\"])\n",
+ "\n",
+ "    previous_temp_file = None\n",
+ "\n",
+ "    for chunk_start in range(0, count, save_interval):\n",
+ "        if exiter.should_exit:\n",
+ "            break\n",
+ "\n",
+ "        chunk_end = min(chunk_start + save_interval, count)\n",
+ "        print(f\"Processing records {chunk_start} to {chunk_end}...\")\n",
+ "\n",
+ "        chunk_df = part_df.iloc[chunk_start:chunk_end]\n",
+ "\n",
+ "        translated_chunk = {\n",
+ "            \"authors\": translate_batch(chunk_df[\"authors\"].tolist()),\n",
+ "            \"title\": translate_batch(chunk_df[\"title\"].tolist()),\n",
+ "            \"abstract\": translate_batch(chunk_df[\"abstract\"].tolist()),\n",
+ "            \"categories\": chunk_df[\"categories\"].tolist(),\n",
+ "        }\n",
+ "        if exiter.should_exit:\n",
+ "            print(\"Interrupt detected. Saving partial results...\")\n",
+ "            break\n",
+ "        chunk_df_translated = pd.DataFrame(translated_chunk)\n",
+ "        russian_data = pd.concat([russian_data, chunk_df_translated], ignore_index=True)\n",
+ "\n",
+ "        # save temporary results\n",
+ "        temp_filename = (\n",
+ "            target_folder / f\"{starting_from}_{starting_from + chunk_end}_temp.csv\"\n",
+ "        )\n",
+ "        russian_data.to_csv(temp_filename, index=False)\n",
+ "        print(f\"Saved temporary results to {temp_filename}\")\n",
+ "\n",
+ "        # remove the previous temporary file\n",
+ "        if previous_temp_file is not None and previous_temp_file.exists():\n",
+ "            previous_temp_file.unlink()\n",
+ "            print(f\"Removed previous temporary file: {previous_temp_file}\")\n",
+ "\n",
+ "        previous_temp_file = temp_filename\n",
+ "\n",
+ "    if exiter.should_exit:\n",
+ "        # keyboard interrupt: the saved range reflects what was actually translated\n",
+ "        final_filename = (\n",
+ "            target_folder\n",
+ "            / f\"{starting_from}_{starting_from + len(russian_data)}_partial.csv\"\n",
+ "        )\n",
+ "        print(f\"\\nProcess interrupted. Saving partial results to {final_filename}\")\n",
+ "    else:\n",
+ "        final_filename = (\n",
+ "            target_folder / f\"{starting_from}_{starting_from + count}_final.csv\"\n",
+ "        )\n",
+ "        print(f\"\\nProcessing completed. Saving final results to {final_filename}\")\n",
+ "\n",
+ "    russian_data.to_csv(final_filename, index=False)\n",
+ "\n",
+ "    # remove temporary files\n",
+ "    if not exiter.should_exit:\n",
+ "        for temp_file in target_folder.glob(\"*_temp.csv\"):\n",
+ "            temp_file.unlink()\n",
+ "        print(\"Temporary files removed.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-09T09:04:58.913113Z",
+ "iopub.status.busy": "2025-04-09T09:04:58.911808Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "df = get_dataset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "mlO-3KoY8uT6",
+ "outputId": "bb555bc7-6ad4-43ef-d096-06ef01b07525",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "translate_dataset(\n",
+ "    starting_from=0, count=50_000, dataset=df, batch_size=128, save_interval=512\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "DataSphere Kernel",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
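
Note on resuming: the checkpoints are range-based, so an interrupted run can be picked up by pointing `starting_from` at the upper bound of the saved `*_partial.csv`. A minimal sketch, assuming the run stopped after saving `0_12800_partial.csv` (the `12_800` offset is hypothetical; read it off your own partial file's name):

    # resume translation from where the interrupted run left off
    translate_dataset(
        starting_from=12_800,          # upper bound of the saved partial range
        count=50_000 - 12_800,         # remaining records of the original target
        dataset=df,
        batch_size=128,
        save_interval=512,
    )

The earlier partial CSV and the new output can then be concatenated, since the filename ranges are contiguous.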