NadavShaked committed on
Commit
91da6cc
•
1 Parent(s): c55ba1a

Upload 7 files

Files changed (7)
  1. handler.py +126 -0
  2. main.py +596 -0
  3. src/models.py +74 -0
  4. src/models_utils.py +561 -0
  5. src/plot_helpers.py +58 -0
  6. src/running_params.py +3 -0
  7. src/utiles_data.py +737 -0
handler.py ADDED
@@ -0,0 +1,126 @@
+from typing import Dict, List, Any
+from transformers import AutoConfig, AutoTokenizer
+from src.models import DNikudModel, ModelConfig
+from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
+from src.utiles_data import Nikud, NikudDataset
+from src.models_utils import predict_single, predict
+import torch
+import os
+from tqdm import tqdm
+
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
+        dir_model_config = os.path.join("models", "config.yml")
+        self.config = ModelConfig.load_from_file(dir_model_config)
+        self.model = DNikudModel(
+            self.config,
+            len(Nikud.label_2_id["nikud"]),
+            len(Nikud.label_2_id["dagesh"]),
+            len(Nikud.label_2_id["sin"]),
+            device=self.DEVICE,
+        ).to(self.DEVICE)
+        state_dict_model = self.model.state_dict()
+        state_dict_model.update(torch.load("./models/Dnikud_best_model.pth"))
+        self.model.load_state_dict(state_dict_model)
+        self.max_length = MAX_LENGTH_SEN
+
+    def back_2_text(self, labels, text):
+        nikud = Nikud()
+        new_line = ""
+
+        for indx_char, c in enumerate(text):
+            new_line += (
+                c
+                + nikud.id_2_char(labels[indx_char][1][1], "dagesh")
+                + nikud.id_2_char(labels[indx_char][1][2], "sin")
+                + nikud.id_2_char(labels[indx_char][1][0], "nikud")
+            )
+            print(indx_char, c)
+        print(labels)
+        return new_line
+
+    def prepare_data(self, data, name="train"):
+        print("Data = ", data)
+        dataset = []
+        for index, (sentence, label) in tqdm(
+            enumerate(data), desc=f"Prepare data {name}"
+        ):
+            encoded_sequence = self.tokenizer.encode_plus(
+                sentence,
+                add_special_tokens=True,
+                max_length=self.max_length,
+                padding="max_length",
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors="pt",
+            )
+            label_lists = [
+                [letter.nikud, letter.dagesh, letter.sin] for letter in label
+            ]
+            label = torch.tensor(
+                [
+                    [
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                    ]
+                ]
+                + label_lists[: (self.max_length - 1)]
+                + [
+                    [
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                    ]
+                    for i in range(self.max_length - len(label) - 1)
+                ]
+            )
+
+            dataset.append(
+                (
+                    encoded_sequence["input_ids"][0],
+                    encoded_sequence["attention_mask"][0],
+                    label,
+                )
+            )
+
+        self.prepered_data = dataset
+
+    def predict_single_text(
+        self,
+        text,
+    ):
+        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=MAX_LENGTH_SEN)
+        data, orig_data = dataset.read_single_text(text)
+        print("data", data, len(data))
+        dataset.prepare_data(name="inference")
+        mtb_prediction_dl = torch.utils.data.DataLoader(
+            dataset.prepered_data, batch_size=BATCH_SIZE
+        )
+        # print("dataset", dataset, len(dataset))
+        # data = self.tokenizer(text, return_tensors="pt")
+        all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
+        text_data_with_labels = dataset.back_2_text(labels=all_labels)
+        # all_labels = predict_single(self.model, dataset, self.DEVICE)
+        return text_data_with_labels
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        data args:
+        """
+
+        # get inputs
+        inputs = data.pop("text", data)
+
+        # run normal prediction
+        prediction = self.predict_single_text(inputs)
+
+        # result = []
+        # for pred in prediction:
+        #     result.append(self.back_2_text(pred, inputs))
+        # result = self.back_2_text(prediction, inputs)
+        return prediction
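
For reference, a minimal local smoke test of the handler (a sketch, not part of the commit: it assumes models/config.yml and models/Dnikud_best_model.pth are present in the working directory, and the sample text is arbitrary):

    from handler import EndpointHandler

    # __call__ pops the "text" key and returns the diacritized string
    handler = EndpointHandler()
    diacritized = handler({"text": "שלום עולם"})
    print(diacritized)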
main.py ADDED
@@ -0,0 +1,596 @@
+# general
+import argparse
+import os
+import sys
+from datetime import datetime
+import logging
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+
+# ML
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoTokenizer
+
+# DL
+from src.models import DNikudModel, ModelConfig
+from src.models_utils import training, evaluate, predict
+from src.plot_helpers import (
+    generate_plot_by_nikud_dagesh_sin_dict,
+    generate_word_and_letter_accuracy_plot,
+)
+from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
+from src.utiles_data import (
+    NikudDataset,
+    Nikud,
+    create_missing_folders,
+    extract_text_to_compare_nakdimon,
+)
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+assert DEVICE == "cuda"
+
+
+def get_logger(
+    log_level, name_func, date_time=datetime.now().strftime("%d_%m_%y__%H_%M")
+):
+    log_location = os.path.join(
+        os.path.join(Path(__file__).parent, "logging"),
+        f"log_model_{name_func}_{date_time}",
+    )
+    create_missing_folders(log_location)
+
+    log_format = "%(asctime)s %(levelname)-8s Thread_%(thread)-6d ::: %(funcName)s(%(lineno)d) ::: %(message)s"
+    logger = logging.getLogger("algo")
+    logger.setLevel(getattr(logging, log_level))
+    cnsl_log_formatter = logging.Formatter(log_format)
+    cnsl_handler = logging.StreamHandler()
+    cnsl_handler.setFormatter(cnsl_log_formatter)
+    cnsl_handler.setLevel(log_level)
+    logger.addHandler(cnsl_handler)
+
+    create_missing_folders(log_location)
+
+    file_location = os.path.join(log_location, "Diacritization_Model_DEBUG.log")
+    file_log_formatter = logging.Formatter(log_format)
+
+    SINGLE_LOG_SIZE = 2 * 1024 * 1024  # in Bytes
+    MAX_LOG_FILES = 20
+    file_handler = RotatingFileHandler(
+        file_location, mode="a", maxBytes=SINGLE_LOG_SIZE, backupCount=MAX_LOG_FILES
+    )
+    file_handler.setFormatter(file_log_formatter)
+    file_handler.setLevel(log_level)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+def evaluate_text(
+    path,
+    dnikud_model,
+    tokenizer_tavbert,
+    logger,
+    plots_folder=None,
+    batch_size=BATCH_SIZE,
+):
+    path_name = os.path.basename(path)
+
+    msg = f"evaluate text: {path_name} on D-nikud Model"
+    logger.debug(msg)
+
+    if os.path.isfile(path):
+        dataset = NikudDataset(
+            tokenizer_tavbert, file=path, logger=logger, max_length=MAX_LENGTH_SEN
+        )
+    elif os.path.isdir(path):
+        dataset = NikudDataset(
+            tokenizer_tavbert, folder=path, logger=logger, max_length=MAX_LENGTH_SEN
+        )
+    else:
+        raise Exception("input path doesn't exist")
+
+    dataset.prepare_data(name="evaluate")
+    mtb_dl = torch.utils.data.DataLoader(dataset.prepered_data, batch_size=batch_size)
+
+    word_level_correct, letter_level_correct_dev = evaluate(
+        dnikud_model, mtb_dl, plots_folder, device=DEVICE
+    )
+
+    msg = (
+        f"Dnikud Model\n{path_name} evaluate\nLetter level accuracy:{letter_level_correct_dev}\n"
+        f"Word level accuracy: {word_level_correct}"
+    )
+    logger.debug(msg)
+
+
+def predict_text(
+    text_file,
+    tokenizer_tavbert,
+    output_file,
+    logger,
+    dnikud_model,
+    compare_nakdimon=False,
+):
+    dataset = NikudDataset(
+        tokenizer_tavbert, file=text_file, logger=logger, max_length=MAX_LENGTH_SEN
+    )
+
+    dataset.prepare_data(name="prediction")
+    mtb_prediction_dl = torch.utils.data.DataLoader(
+        dataset.prepered_data, batch_size=BATCH_SIZE
+    )
+    all_labels = predict(dnikud_model, mtb_prediction_dl, DEVICE)
+    text_data_with_labels = dataset.back_2_text(labels=all_labels)
+
+    if output_file is None:
+        for line in text_data_with_labels:
+            print(line)
+    else:
+        with open(output_file, "w", encoding="utf-8") as f:
+            if compare_nakdimon:
+                f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
+            else:
+                f.write(text_data_with_labels)
+
+
+def predict_folder(
+    folder,
+    output_folder,
+    logger,
+    tokenizer_tavbert,
+    dnikud_model,
+    compare_nakdimon=False,
+):
+    create_missing_folders(output_folder)
+
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
+
+        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
+            output_file = os.path.join(output_folder, filename)
+            predict_text(
+                file_path,
+                output_file=output_file,
+                logger=logger,
+                tokenizer_tavbert=tokenizer_tavbert,
+                dnikud_model=dnikud_model,
+                compare_nakdimon=compare_nakdimon,
+            )
+        elif (
+            os.path.isdir(file_path) and filename != ".git" and filename != "README.md"
+        ):
+            sub_folder = file_path
+            sub_folder_output = os.path.join(output_folder, filename)
+            predict_folder(
+                sub_folder,
+                sub_folder_output,
+                logger,
+                tokenizer_tavbert,
+                dnikud_model,
+                compare_nakdimon=compare_nakdimon,
+            )
+
+
+def update_compare_folder(folder, output_folder):
+    create_missing_folders(output_folder)
+
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
+
+        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
+            output_file = os.path.join(output_folder, filename)
+            with open(file_path, "r", encoding="utf-8") as f:
+                text_data_with_labels = f.read()
+            with open(output_file, "w", encoding="utf-8") as f:
+                f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
+        elif os.path.isdir(file_path) and filename != ".git":
+            sub_folder = file_path
+            sub_folder_output = os.path.join(output_folder, filename)
+            update_compare_folder(sub_folder, sub_folder_output)
+
+
+def check_files_excepted(folder):
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
+
+        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
+            try:
+                x = NikudDataset(None, file=file_path)
+            except Exception:
+                print(f"failed in file: {filename}")
+        elif os.path.isdir(file_path) and filename != ".git":
+            check_files_excepted(file_path)
+
+
+def do_predict(
+    input_path, output_path, tokenizer_tavbert, logger, dnikud_model, compare_nakdimon
+):
+    if os.path.isdir(input_path):
+        predict_folder(
+            input_path,
+            output_path,
+            logger,
+            tokenizer_tavbert,
+            dnikud_model,
+            compare_nakdimon=compare_nakdimon,
+        )
+    elif os.path.isfile(input_path):
+        predict_text(
+            input_path,
+            output_file=output_path,
+            logger=logger,
+            tokenizer_tavbert=tokenizer_tavbert,
+            dnikud_model=dnikud_model,
+            compare_nakdimon=compare_nakdimon,
+        )
+    else:
+        raise Exception("Input file does not exist")
+
+
+def evaluate_folder(folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder):
+    msg = f"evaluate sub folder: {folder_path}"
+    logger.info(msg)
+
+    evaluate_text(
+        folder_path,
+        dnikud_model=dnikud_model,
+        tokenizer_tavbert=tokenizer_tavbert,
+        logger=logger,
+        plots_folder=plots_folder,
+        batch_size=BATCH_SIZE,
+    )
+
+    msg = f"\n***************************************\n"
+    logger.info(msg)
+
+    for sub_folder_name in os.listdir(folder_path):
+        sub_folder_path = os.path.join(folder_path, sub_folder_name)
+
+        if (
+            not os.path.isdir(sub_folder_path)
+            or sub_folder_path == ".git"
+            or "not_use" in sub_folder_path
+            or "NakdanResults" in sub_folder_path
+        ):
+            continue
+
+        evaluate_folder(
+            sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
+        )
+
+
+def do_evaluate(
+    input_path,
+    logger,
+    dnikud_model,
+    tokenizer_tavbert,
+    plots_folder,
+    eval_sub_folders=False,
+):
+    msg = f"evaluate all_data: {input_path}"
+    logger.info(msg)
+
+    evaluate_text(
+        input_path,
+        dnikud_model=dnikud_model,
+        tokenizer_tavbert=tokenizer_tavbert,
+        logger=logger,
+        plots_folder=plots_folder,
+        batch_size=BATCH_SIZE,
+    )
+
+    msg = f"\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n"
+    logger.info(msg)
+
+    if eval_sub_folders:
+        for sub_folder_name in os.listdir(input_path):
+            sub_folder_path = os.path.join(input_path, sub_folder_name)
+
+            if (
+                not os.path.isdir(sub_folder_path)
+                or sub_folder_path == ".git"
+                or "not_use" in sub_folder_path
+                or "NakdanResults" in sub_folder_path
+            ):
+                continue
+
+            evaluate_folder(
+                sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
+            )
+
+
+def do_train(
+    logger,
+    plots_folder,
+    dir_model_config,
+    tokenizer_tavbert,
+    dnikud_model,
+    output_trained_model_dir,
+    data_folder,
+    n_epochs,
+    checkpoints_frequency,
+    learning_rate,
+    batch_size,
+):
+    msg = "Loading data..."
+    logger.debug(msg)
+
+    dataset_train = NikudDataset(
+        tokenizer_tavbert,
+        folder=os.path.join(data_folder, "train"),
+        logger=logger,
+        max_length=MAX_LENGTH_SEN,
+        is_train=True,
+    )
+    dataset_dev = NikudDataset(
+        tokenizer=tokenizer_tavbert,
+        folder=os.path.join(data_folder, "dev"),
+        logger=logger,
+        max_length=dataset_train.max_length,
+        is_train=True,
+    )
+    dataset_test = NikudDataset(
+        tokenizer=tokenizer_tavbert,
+        folder=os.path.join(data_folder, "test"),
+        logger=logger,
+        max_length=dataset_train.max_length,
+        is_train=True,
+    )
+
+    dataset_train.show_data_labels(plots_folder=plots_folder)
+
+    msg = f"Max length of data: {dataset_train.max_length}"
+    logger.debug(msg)
+
+    msg = (
+        f"Num rows in train data: {len(dataset_train.data)}, "
+        f"Num rows in dev data: {len(dataset_dev.data)}, "
+        f"Num rows in test data: {len(dataset_test.data)}"
+    )
+    logger.debug(msg)
+
+    msg = "Loading tokenizer and prepare data..."
+    logger.debug(msg)
+
+    dataset_train.prepare_data(name="train")
+    dataset_dev.prepare_data(name="dev")
+    dataset_test.prepare_data(name="test")
+
+    mtb_train_dl = torch.utils.data.DataLoader(
+        dataset_train.prepered_data, batch_size=batch_size
+    )
+    mtb_dev_dl = torch.utils.data.DataLoader(
+        dataset_dev.prepered_data, batch_size=batch_size
+    )
+
+    if not os.path.isfile(dir_model_config):
+        our_model_config = ModelConfig(dataset_train.max_length)
+        our_model_config.save_to_file(dir_model_config)
+
+    optimizer = torch.optim.Adam(dnikud_model.parameters(), lr=learning_rate)
+
+    msg = "training..."
+    logger.debug(msg)
+
+    criterion_nikud = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
+        DEVICE
+    )
+    criterion_dagesh = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
+        DEVICE
+    )
+    criterion_sin = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(DEVICE)
+
+    training_params = {
+        "n_epochs": n_epochs,
+        "checkpoints_frequency": checkpoints_frequency,
+    }
+    (
+        best_model_details,
+        best_accuracy,
+        epochs_loss_train_values,
+        steps_loss_train_values,
+        loss_dev_values,
+        accuracy_dev_values,
+    ) = training(
+        dnikud_model,
+        mtb_train_dl,
+        mtb_dev_dl,
+        criterion_nikud,
+        criterion_dagesh,
+        criterion_sin,
+        training_params,
+        logger,
+        output_trained_model_dir,
+        optimizer,
+        device=DEVICE,
+    )
+
+    generate_plot_by_nikud_dagesh_sin_dict(
+        epochs_loss_train_values, "Train epochs loss", "Loss", plots_folder
+    )
+    generate_plot_by_nikud_dagesh_sin_dict(
+        steps_loss_train_values, "Train steps loss", "Loss", plots_folder
+    )
+    generate_plot_by_nikud_dagesh_sin_dict(
+        loss_dev_values, "Dev epochs loss", "Loss", plots_folder
+    )
+    generate_plot_by_nikud_dagesh_sin_dict(
+        accuracy_dev_values, "Dev accuracy", "Accuracy", plots_folder
+    )
+    generate_word_and_letter_accuracy_plot(
+        accuracy_dev_values, "Accuracy", plots_folder
+    )
+
+    msg = "Done"
+    logger.info(msg)
+
+
+if __name__ == "__main__":
+    tokenizer_tavbert = AutoTokenizer.from_pretrained("tau/tavbert-he")
+
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description="""Predict D-nikud""",
+    )
+    parser.add_argument(
+        "-l",
+        "--log",
+        dest="log_level",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        default="DEBUG",
+        help="Set the logging level",
+    )
+    parser.add_argument(
+        "-m",
+        "--output_model_dir",
+        type=str,
+        default="models",
+        help="save directory for model",
+    )
+    subparsers = parser.add_subparsers(
+        help="sub-command help", dest="command", required=True
+    )
+
+    parser_predict = subparsers.add_parser("predict", help="diacritize text files")
+    parser_predict.add_argument("input_path", help="input file or folder")
+    parser_predict.add_argument("output_path", help="output file")
+    parser_predict.add_argument(
+        "-ptmp",
+        "--pretrain_model_path",
+        type=str,
+        default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
+        help="pre-trained model path - use only if you want to use trained model weights",
+    )
+    parser_predict.add_argument(
+        "-c",
+        "--compare",
+        dest="compare_nakdimon",
+        default=False,
+        help="predict text for comparing with Nakdimon",
+    )
+    parser_predict.set_defaults(func=do_predict)
+
+    parser_evaluate = subparsers.add_parser("evaluate", help="evaluate D-nikud")
+    parser_evaluate.add_argument("input_path", help="input file or folder")
+    parser_evaluate.add_argument(
+        "-ptmp",
+        "--pretrain_model_path",
+        type=str,
+        default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
+        help="pre-trained model path - use only if you want to use trained model weights",
+    )
+    parser_evaluate.add_argument(
+        "-df",
+        "--plots_folder",
+        dest="plots_folder",
+        default=os.path.join(Path(__file__).parent, "plots"),
+        help="set the plots folder",
+    )
+    parser_evaluate.add_argument(
+        "-es",
+        "--eval_sub_folders",
+        dest="eval_sub_folders",
+        default=False,
+        help="accuracy calculation includes the evaluation of sub-folders "
+        "within the input_path folder, providing independent assessments "
+        "for each subfolder.",
+    )
+    parser_evaluate.set_defaults(func=do_evaluate)
+
+    # train --n_epochs 20
+
+    parser_train = subparsers.add_parser("train", help="train D-nikud")
+    parser_train.add_argument(
+        "-ptmp",
+        "--pretrain_model_path",
+        type=str,
+        default=None,
+        help="pre-trained model path - use only if you want to use trained model weights",
+    )
+    parser_train.add_argument(
+        "--learning_rate", type=float, default=0.001, help="Learning rate"
+    )
+    parser_train.add_argument("--batch_size", type=int, default=32, help="batch_size")
+    parser_train.add_argument(
+        "--n_epochs", type=int, default=10, help="number of epochs"
+    )
+    parser_train.add_argument(
+        "--data_folder",
+        dest="data_folder",
+        default=os.path.join(Path(__file__).parent, "data"),
+        help="Set the data folder",
+    )
+    parser_train.add_argument(
+        "--checkpoints_frequency",
+        type=int,
+        default=1,
+        help="checkpoints frequency for saving the model",
+    )
+    parser_train.add_argument(
+        "-df",
+        "--plots_folder",
+        dest="plots_folder",
+        default=os.path.join(Path(__file__).parent, "plots"),
+        help="Set the plots folder",
+    )
+    parser_train.set_defaults(func=do_train)
+
+    args = parser.parse_args()
+    kwargs = vars(args).copy()
+    date_time = datetime.now().strftime("%d_%m_%y__%H_%M")
+    logger = get_logger(kwargs["log_level"], args.command, date_time)
+
+    del kwargs["log_level"]
+
+    kwargs["tokenizer_tavbert"] = tokenizer_tavbert
+    kwargs["logger"] = logger
+
+    msg = "Loading model..."
+    logger.debug(msg)
+
+    if args.command in ["evaluate", "predict"] or (
+        args.command == "train" and args.pretrain_model_path is not None
+    ):
+        dir_model_config = os.path.join("models", "config.yml")
+        config = ModelConfig.load_from_file(dir_model_config)
+
+        dnikud_model = DNikudModel(
+            config,
+            len(Nikud.label_2_id["nikud"]),
+            len(Nikud.label_2_id["dagesh"]),
+            len(Nikud.label_2_id["sin"]),
+            device=DEVICE,
+        ).to(DEVICE)
+        state_dict_model = dnikud_model.state_dict()
+        state_dict_model.update(torch.load(args.pretrain_model_path))
+        dnikud_model.load_state_dict(state_dict_model)
+    else:
+        base_model_name = "tau/tavbert-he"
+        config = AutoConfig.from_pretrained(base_model_name)
+        dnikud_model = DNikudModel(
+            config,
+            len(Nikud.label_2_id["nikud"]),
+            len(Nikud.label_2_id["dagesh"]),
+            len(Nikud.label_2_id["sin"]),
+            pretrain_model=base_model_name,
+            device=DEVICE,
+        ).to(DEVICE)
+
+    if args.command == "train":
+        output_trained_model_dir = os.path.join(
+            kwargs["output_model_dir"], "latest", f"output_models_{date_time}"
+        )
+        create_missing_folders(output_trained_model_dir)
+        dir_model_config = os.path.join(kwargs["output_model_dir"], "config.yml")
+        kwargs["dir_model_config"] = dir_model_config
+        kwargs["output_trained_model_dir"] = output_trained_model_dir
+
+    del kwargs["pretrain_model_path"]
+    del kwargs["output_model_dir"]
+    kwargs["dnikud_model"] = dnikud_model
+
+    del kwargs["command"]
+    del kwargs["func"]
+    args.func(**kwargs)
+
+    sys.exit(0)
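
For orientation, example invocations of the three sub-commands (illustrative paths and values, not part of the commit; the flags match the argparse definitions above):

    python main.py predict input.txt output.txt
    python main.py evaluate data/test -df plots
    python main.py train --n_epochs 20 --batch_size 32 --learning_rate 0.001

Note that the module asserts DEVICE == "cuda" at import time, so these runs presume a CUDA-capable machine.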
src/models.py ADDED
@@ -0,0 +1,74 @@
+# general
+import subprocess
+import yaml
+
+# ML
+import torch.nn as nn
+from transformers import AutoConfig, RobertaForMaskedLM, PretrainedConfig
+
+
+class DNikudModel(nn.Module):
+    def __init__(self, config, nikud_size, dagesh_size, sin_size, pretrain_model=None, device='cpu'):
+        super(DNikudModel, self).__init__()
+
+        if pretrain_model is not None:
+            model_base = RobertaForMaskedLM.from_pretrained(pretrain_model).to(device)
+        else:
+            model_base = RobertaForMaskedLM(config=config).to(device)
+
+        self.model = model_base.roberta
+        for name, param in self.model.named_parameters():
+            param.requires_grad = False
+
+        self.lstm1 = nn.LSTM(config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
+        self.lstm2 = nn.LSTM(2 * config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
+        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
+        self.out_n = nn.Linear(config.hidden_size, nikud_size)
+        self.out_d = nn.Linear(config.hidden_size, dagesh_size)
+        self.out_s = nn.Linear(config.hidden_size, sin_size)
+
+    def forward(self, input_ids, attention_mask):
+        last_hidden_state = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
+        lstm1, _ = self.lstm1(last_hidden_state)
+        lstm2, _ = self.lstm2(lstm1)
+        dense = self.dense(lstm2)
+
+        nikud = self.out_n(dense)
+        dagesh = self.out_d(dense)
+        sin = self.out_s(dense)
+
+        return nikud, dagesh, sin
+
+
+def get_git_commit_hash():
+    try:
+        commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
+        return commit_hash
+    except subprocess.CalledProcessError:
+        # This will be raised if you're not in a Git repository
+        print("Not inside a Git repository!")
+        return None
+
+
+class ModelConfig(PretrainedConfig):
+    def __init__(self, max_length=None, dict=None):
+        super(ModelConfig, self).__init__()
+        if dict is None:
+            self.__dict__.update(AutoConfig.from_pretrained("tau/tavbert-he").__dict__)
+            self.max_length = max_length
+            self._commit_hash = get_git_commit_hash()
+        else:
+            self.__dict__.update(dict)
+
+    def print(self):
+        print(self.__dict__)
+
+    def save_to_file(self, file_path):
+        with open(file_path, "w") as yaml_file:
+            yaml.dump(self.__dict__, yaml_file, default_flow_style=False)
+
+    @classmethod
+    def load_from_file(cls, file_path):
+        with open(file_path, "r") as yaml_file:
+            config_dict = yaml.safe_load(yaml_file)
+        return cls(dict=config_dict)
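
A minimal shape check for DNikudModel (a sketch under assumptions: the tau/tavbert-he config is reachable and the label sizes come from Nikud in src.utiles_data; the dummy batch is arbitrary):

    import torch
    from transformers import AutoConfig
    from src.models import DNikudModel
    from src.utiles_data import Nikud

    config = AutoConfig.from_pretrained("tau/tavbert-he")
    model = DNikudModel(
        config,
        len(Nikud.label_2_id["nikud"]),
        len(Nikud.label_2_id["dagesh"]),
        len(Nikud.label_2_id["sin"]),
    )
    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # dummy batch: 2 sequences of 16 tokens
    attention_mask = torch.ones_like(input_ids)
    nikud, dagesh, sin = model(input_ids, attention_mask)
    # each head emits per-token logits of shape (batch, seq_len, num_classes)
    print(nikud.shape, dagesh.shape, sin.shape)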
src/models_utils.py ADDED
@@ -0,0 +1,561 @@
+# general
+import json
+import os
+
+# ML
+import numpy as np
+import pandas as pd
+import torch
+
+# visual
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import confusion_matrix
+from tqdm import tqdm
+
+from src.running_params import DEBUG_MODE
+from src.utiles_data import Nikud, create_missing_folders
+
+CLASSES_LIST = ["nikud", "dagesh", "sin"]
+
+
+def calc_num_correct_words(input, letter_correct_mask):
+    SPACE_TOKEN = 104
+    START_SENTENCE_TOKEN = 1
+    END_SENTENCE_TOKEN = 2
+
+    correct_words_count = 0
+    words_count = 0
+    for index in range(input.shape[0]):
+        input[index][np.where(input[index] == SPACE_TOKEN)[0]] = 0
+        input[index][np.where(input[index] == START_SENTENCE_TOKEN)[0]] = 0
+        input[index][np.where(input[index] == END_SENTENCE_TOKEN)[0]] = 0
+        words_end_index = np.concatenate(
+            (np.array([-1]), np.where(input[index] == 0)[0])
+        )
+        is_correct_words_array = [
+            bool(
+                letter_correct_mask[index][
+                    list(range((words_end_index[s] + 1), words_end_index[s + 1]))
+                ].all()
+            )
+            for s in range(len(words_end_index) - 1)
+            if words_end_index[s + 1] - (words_end_index[s] + 1) > 1
+        ]
+        correct_words_count += np.array(is_correct_words_array).sum()
+        words_count += len(is_correct_words_array)
+
+    return correct_words_count, words_count
+
+
+def predict(model, data_loader, device="cpu"):
+    model.to(device)
+
+    all_labels = None
+    with torch.no_grad():
+        for index_data, data in enumerate(data_loader):
+            (inputs, attention_mask, labels_demo) = data
+            inputs = inputs.to(device)
+            attention_mask = attention_mask.to(device)
+            labels_demo = labels_demo.to(device)
+
+            mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
+            mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
+            mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
+
+            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+            pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
+                inputs.shape[0], inputs.shape[1], 1
+            )
+            pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
+                inputs.shape[0], inputs.shape[1], 1
+            )
+            pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
+                inputs.shape[0], inputs.shape[1], 1
+            )
+
+            pred_nikud[mask_cant_be_nikud] = -1
+            pred_dagesh[mask_cant_be_dagesh] = -1
+            pred_sin[mask_cant_be_sin] = -1
+
+            pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
+
+            if all_labels is None:
+                all_labels = pred_labels
+            else:
+                all_labels = np.concatenate((all_labels, pred_labels), axis=0)
+
+    return all_labels
+
+
+def predict_single(model, data, device="cpu"):
+    # model.to(device)
+
+    all_labels = None
+    with torch.no_grad():
+        (inputs, attention_mask, labels_demo) = data
+        inputs = inputs.to(device)
+        attention_mask = attention_mask.to(device)
+        labels_demo = labels_demo.to(device)
+
+        mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
+        mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
+        mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
+
+        nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+        print("model output: ", nikud_probs, dagesh_probs, sin_probs)
+
+        pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
+            inputs.shape[0], inputs.shape[1], 1
+        )
+        pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
+            inputs.shape[0], inputs.shape[1], 1
+        )
+        pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
+            inputs.shape[0], inputs.shape[1], 1
+        )
+
+        pred_nikud[mask_cant_be_nikud] = -1
+        pred_dagesh[mask_cant_be_dagesh] = -1
+        pred_sin[mask_cant_be_sin] = -1
+        # print(pred_nikud, pred_dagesh, pred_sin)
+        pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
+        print(pred_labels)
+        if all_labels is None:
+            all_labels = pred_labels
+        else:
+            all_labels = np.concatenate((all_labels, pred_labels), axis=0)
+
+    return all_labels
+
+
+def training(
+    model,
+    train_loader,
+    dev_loader,
+    criterion_nikud,
+    criterion_dagesh,
+    criterion_sin,
+    training_params,
+    logger,
+    output_model_path,
+    optimizer,
+    device="cpu",
+):
+    max_length = None
+    best_accuracy = 0.0
+
+    logger.info(f"start training with training_params: {training_params}")
+    model = model.to(device)
+
+    criteria = {
+        "nikud": criterion_nikud.to(device),
+        "dagesh": criterion_dagesh.to(device),
+        "sin": criterion_sin.to(device),
+    }
+
+    output_checkpoints_path = os.path.join(output_model_path, "checkpoints")
+    create_missing_folders(output_checkpoints_path)
+
+    train_steps_loss_values = {"nikud": [], "dagesh": [], "sin": []}
+    train_epochs_loss_values = {"nikud": [], "dagesh": [], "sin": []}
+    dev_loss_values = {"nikud": [], "dagesh": [], "sin": []}
+    dev_accuracy_values = {
+        "nikud": [],
+        "dagesh": [],
+        "sin": [],
+        "all_nikud_letter": [],
+        "all_nikud_word": [],
+    }
+
+    for epoch in tqdm(range(training_params["n_epochs"]), desc="Training"):
+        model.train()
+        train_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+        relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+
+        for index_data, data in enumerate(train_loader):
+            (inputs, attention_mask, labels) = data
+
+            if max_length is None:
+                max_length = labels.shape[1]
+
+            inputs = inputs.to(device)
+            attention_mask = attention_mask.to(device)
+            labels = labels.to(device)
+
+            optimizer.zero_grad()
+            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+            for i, (probs, class_name) in enumerate(
+                zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
+            ):
+                reshaped_tensor = (
+                    torch.transpose(probs, 1, 2)
+                    .contiguous()
+                    .view(probs.shape[0], probs.shape[2], probs.shape[1])
+                )
+                loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(device)
+
+                num_relevant = (labels[:, :, i] != -1).sum()
+                train_loss[class_name] += loss.item() * num_relevant
+                relevant_count[class_name] += num_relevant
+
+                loss.backward(retain_graph=True)
+
+            for i, class_name in enumerate(CLASSES_LIST):
+                train_steps_loss_values[class_name].append(
+                    float(train_loss[class_name] / relevant_count[class_name])
+                )
+
+            optimizer.step()
+            if (index_data + 1) % 100 == 0:
+                msg = f"epoch: {epoch} , index_data: {index_data + 1}\n"
+                for i, class_name in enumerate(CLASSES_LIST):
+                    msg += f"mean loss train {class_name}: {float(train_loss[class_name] / relevant_count[class_name])}, "
+
+                logger.debug(msg[:-2])
+
+        for i, class_name in enumerate(CLASSES_LIST):
+            train_epochs_loss_values[class_name].append(
+                float(train_loss[class_name] / relevant_count[class_name])
+            )
+
+        for class_name in train_loss.keys():
+            train_loss[class_name] /= relevant_count[class_name]
+
+        msg = f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
+        for i, class_name in enumerate(CLASSES_LIST):
+            msg += f"mean loss train {class_name}: {train_loss[class_name]}, "
+        logger.debug(msg[:-2])
+
+        model.eval()
+        dev_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+        dev_accuracy = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+        relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+        correct_preds = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+        un_masks = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+        predictions = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+        labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+
+        all_nikud_types_correct_preds_letter = 0.0
+
+        letter_count = 0.0
+        correct_words_count = 0.0
+        word_count = 0.0
+        with torch.no_grad():
+            for index_data, data in enumerate(dev_loader):
+                (inputs, attention_mask, labels) = data
+                inputs = inputs.to(device)
+                attention_mask = attention_mask.to(device)
+                labels = labels.to(device)
+
+                nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+                for i, (probs, class_name) in enumerate(
+                    zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
+                ):
+                    reshaped_tensor = (
+                        torch.transpose(probs, 1, 2)
+                        .contiguous()
+                        .view(probs.shape[0], probs.shape[2], probs.shape[1])
+                    )
+                    loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(
+                        device
+                    )
+                    un_masked = labels[:, :, i] != -1
+                    num_relevant = un_masked.sum()
+                    relevant_count[class_name] += num_relevant
+                    _, preds = torch.max(probs, 2)
+                    dev_loss[class_name] += loss.item() * num_relevant
+                    correct_preds[class_name] += torch.sum(
+                        preds[un_masked] == labels[:, :, i][un_masked]
+                    )
+                    un_masks[class_name] = un_masked
+                    predictions[class_name] = preds
+                    labels_class[class_name] = labels[:, :, i]
+
+                un_mask_all_or = torch.logical_or(
+                    torch.logical_or(un_masks["nikud"], un_masks["dagesh"]),
+                    un_masks["sin"],
+                )
+
+                correct = {
+                    class_name: (torch.ones(un_mask_all_or.shape) == 1).to(device)
+                    for class_name in CLASSES_LIST
+                }
+
+                for i, class_name in enumerate(CLASSES_LIST):
+                    correct[class_name][un_masks[class_name]] = (
+                        predictions[class_name][un_masks[class_name]]
+                        == labels_class[class_name][un_masks[class_name]]
+                    )
+
+                letter_correct_mask = torch.logical_and(
+                    torch.logical_and(correct["sin"], correct["dagesh"]),
+                    correct["nikud"],
+                )
+                all_nikud_types_correct_preds_letter += torch.sum(
+                    letter_correct_mask[un_mask_all_or]
+                )
+
+                letter_correct_mask[~un_mask_all_or] = True
+                correct_num, total_words_num = calc_num_correct_words(
+                    inputs.cpu(), letter_correct_mask
+                )
+
+                word_count += total_words_num
+                correct_words_count += correct_num
+                letter_count += un_mask_all_or.sum()
+
+        for class_name in CLASSES_LIST:
+            dev_loss[class_name] /= relevant_count[class_name]
+            dev_accuracy[class_name] = float(
+                correct_preds[class_name].double() / relevant_count[class_name]
+            )
+
+            dev_loss_values[class_name].append(float(dev_loss[class_name]))
+            dev_accuracy_values[class_name].append(float(dev_accuracy[class_name]))
+
+        dev_all_nikud_types_accuracy_letter = float(
+            all_nikud_types_correct_preds_letter / letter_count
+        )
+
+        dev_accuracy_values["all_nikud_letter"].append(
+            dev_all_nikud_types_accuracy_letter
+        )
+
+        word_all_nikud_accuracy = correct_words_count / word_count
+        dev_accuracy_values["all_nikud_word"].append(word_all_nikud_accuracy)
+
+        msg = (
+            f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
+            f'mean loss Dev nikud: {train_loss["nikud"]}, '
+            f'mean loss Dev dagesh: {train_loss["dagesh"]}, '
+            f'mean loss Dev sin: {train_loss["sin"]}, '
+            f"Dev all nikud types letter Accuracy: {dev_all_nikud_types_accuracy_letter}, "
+            f'Dev nikud letter Accuracy: {dev_accuracy["nikud"]}, '
+            f'Dev dagesh letter Accuracy: {dev_accuracy["dagesh"]}, '
+            f'Dev sin letter Accuracy: {dev_accuracy["sin"]}, '
+            f"Dev word Accuracy: {word_all_nikud_accuracy}"
+        )
+        logger.debug(msg)
+
+        save_progress_details(
+            dev_accuracy_values,
+            train_epochs_loss_values,
+            dev_loss_values,
+            train_steps_loss_values,
+        )
+
+        if dev_all_nikud_types_accuracy_letter > best_accuracy:
+            best_accuracy = dev_all_nikud_types_accuracy_letter
+            best_model = {
+                "epoch": epoch,
+                "model_state_dict": model.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+                "loss": loss,
+            }
+
+        if epoch % training_params["checkpoints_frequency"] == 0:
+            save_checkpoint_path = os.path.join(
+                output_checkpoints_path, f"checkpoint_model_epoch_{epoch + 1}.pth"
+            )
+            checkpoint = {
+                "epoch": epoch,
+                "model_state_dict": model.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+                "loss": loss,
+            }
+            torch.save(checkpoint["model_state_dict"], save_checkpoint_path)
+
+    save_model_path = os.path.join(output_model_path, "best_model.pth")
+    torch.save(best_model["model_state_dict"], save_model_path)
+    return (
+        best_model,
+        best_accuracy,
+        train_epochs_loss_values,
+        train_steps_loss_values,
+        dev_loss_values,
+        dev_accuracy_values,
+    )
+
+
+def save_progress_details(
+    accuracy_dev_values,
+    epochs_loss_train_values,
+    loss_dev_values,
+    steps_loss_train_values,
+):
+    epochs_data_path = "epochs_data"
+    create_missing_folders(epochs_data_path)
+
+    save_dict_as_json(
+        steps_loss_train_values, epochs_data_path, "steps_loss_train_values.json"
+    )
+    save_dict_as_json(
+        epochs_loss_train_values, epochs_data_path, "epochs_loss_train_values.json"
+    )
+    save_dict_as_json(loss_dev_values, epochs_data_path, "loss_dev_values.json")
+    save_dict_as_json(accuracy_dev_values, epochs_data_path, "accuracy_dev_values.json")
+
+
+def save_dict_as_json(dict, file_path, file_name):
+    json_data = json.dumps(dict, indent=4)
+    with open(os.path.join(file_path, file_name), "w") as json_file:
+        json_file.write(json_data)
+
+
+def evaluate(model, test_data, plots_folder=None, device="cpu"):
+    model.to(device)
+    model.eval()
+
+    true_labels = {"nikud": [], "dagesh": [], "sin": []}
+    predictions = {"nikud": 0, "dagesh": 0, "sin": 0}
+    predicted_labels_2_report = {"nikud": [], "dagesh": [], "sin": []}
+    not_masks = {"nikud": 0, "dagesh": 0, "sin": 0}
+    correct_preds = {"nikud": 0, "dagesh": 0, "sin": 0}
+    relevant_count = {"nikud": 0, "dagesh": 0, "sin": 0}
+    labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+
+    all_nikud_types_letter_level_correct = 0.0
+    nikud_letter_level_correct = 0.0
+    dagesh_letter_level_correct = 0.0
+    sin_letter_level_correct = 0.0
+
+    letters_count = 0.0
+    words_count = 0.0
+    correct_words_count = 0.0
+    with torch.no_grad():
+        for index_data, data in enumerate(test_data):
+            if DEBUG_MODE and index_data > 100:
+                break
+
+            (inputs, attention_mask, labels) = data
+
+            inputs = inputs.to(device)
+            attention_mask = attention_mask.to(device)
+            labels = labels.to(device)
+
+            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+            for i, (probs, class_name) in enumerate(
+                zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
+            ):
+                labels_class[class_name] = labels[:, :, i]
+                not_masked = labels_class[class_name] != -1
+                num_relevant = not_masked.sum()
+                relevant_count[class_name] += num_relevant
+                _, preds = torch.max(probs, 2)
+                correct_preds[class_name] += torch.sum(
+                    preds[not_masked] == labels_class[class_name][not_masked]
+                )
+                predictions[class_name] = preds
+                not_masks[class_name] = not_masked
+
+                if len(true_labels[class_name]) == 0:
+                    true_labels[class_name] = (
+                        labels_class[class_name][not_masked].cpu().numpy()
+                    )
+                else:
+                    true_labels[class_name] = np.concatenate(
+                        (
+                            true_labels[class_name],
+                            labels_class[class_name][not_masked].cpu().numpy(),
+                        )
+                    )
+
+                if len(predicted_labels_2_report[class_name]) == 0:
+                    predicted_labels_2_report[class_name] = (
+                        preds[not_masked].cpu().numpy()
+                    )
+                else:
+                    predicted_labels_2_report[class_name] = np.concatenate(
+                        (
+                            predicted_labels_2_report[class_name],
+                            preds[not_masked].cpu().numpy(),
+                        )
+                    )
+
+            not_mask_all_or = torch.logical_or(
+                torch.logical_or(not_masks["nikud"], not_masks["dagesh"]),
+                not_masks["sin"],
+            )
+
+            correct_nikud = (torch.ones(not_mask_all_or.shape) == 1).to(device)
+            correct_dagesh = (torch.ones(not_mask_all_or.shape) == 1).to(device)
+            correct_sin = (torch.ones(not_mask_all_or.shape) == 1).to(device)
+
+            correct_nikud[not_masks["nikud"]] = (
+                predictions["nikud"][not_masks["nikud"]]
+                == labels_class["nikud"][not_masks["nikud"]]
+            )
+            correct_dagesh[not_masks["dagesh"]] = (
+                predictions["dagesh"][not_masks["dagesh"]]
+                == labels_class["dagesh"][not_masks["dagesh"]]
+            )
+            correct_sin[not_masks["sin"]] = (
+                predictions["sin"][not_masks["sin"]]
+                == labels_class["sin"][not_masks["sin"]]
+            )
+
+            letter_correct_mask = torch.logical_and(
+                torch.logical_and(correct_sin, correct_dagesh), correct_nikud
+            )
+            all_nikud_types_letter_level_correct += torch.sum(
+                letter_correct_mask[not_mask_all_or]
+            )
+
+            letter_correct_mask[~not_mask_all_or] = True
+            total_correct_count, total_words_num = calc_num_correct_words(
+                inputs.cpu(), letter_correct_mask
+            )
+
+            words_count += total_words_num
+            correct_words_count += total_correct_count
+
+            letters_count += not_mask_all_or.sum()
+
+            nikud_letter_level_correct += torch.sum(correct_nikud[not_mask_all_or])
+            dagesh_letter_level_correct += torch.sum(correct_dagesh[not_mask_all_or])
+            sin_letter_level_correct += torch.sum(correct_sin[not_mask_all_or])
+
+    for i, name in enumerate(CLASSES_LIST):
+        index_labels = np.unique(true_labels[name])
+        cm = confusion_matrix(
+            true_labels[name], predicted_labels_2_report[name], labels=index_labels
+        )
+
+        vowel_label = [Nikud.id_2_label[name][l] for l in index_labels]
+        unique_vowels_names = [
+            Nikud.sign_2_name[int(vowel)] for vowel in vowel_label if vowel != "WITHOUT"
+        ]
+        if "WITHOUT" in vowel_label:
+            unique_vowels_names += ["WITHOUT"]
+        cm_df = pd.DataFrame(cm, index=unique_vowels_names, columns=unique_vowels_names)
+
+        # Display confusion matrix (rows are true labels, columns are predictions)
+        plt.figure(figsize=(10, 8))
+        sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
+        plt.title("Confusion Matrix")
+        plt.xlabel("Predicted Label")
+        plt.ylabel("True Label")
+        if plots_folder is None:
+            plt.show()
+        else:
+            plt.savefig(os.path.join(plots_folder, f"Confusion_Matrix_{name}.jpg"))
+
+    all_nikud_types_letter_level_correct = (
+        all_nikud_types_letter_level_correct / letters_count
+    )
+    all_nikud_types_word_level_correct = correct_words_count / words_count
+    nikud_letter_level_correct = nikud_letter_level_correct / letters_count
+    dagesh_letter_level_correct = dagesh_letter_level_correct / letters_count
+    sin_letter_level_correct = sin_letter_level_correct / letters_count
+    print("\n")
+    print(f"nikud_letter_level_correct = {nikud_letter_level_correct}")
+    print(f"dagesh_letter_level_correct = {dagesh_letter_level_correct}")
+    print(f"sin_letter_level_correct = {sin_letter_level_correct}")
+    print(f"word_level_correct = {all_nikud_types_word_level_correct}")
+
+    return all_nikud_types_word_level_correct, all_nikud_types_letter_level_correct
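
A toy illustration of the word-boundary logic in calc_num_correct_words (a sketch; token ids follow the constants inside the function: 1 = start, 2 = end, 104 = space, 0 = padding, and single-letter words are skipped):

    import torch
    from src.models_utils import calc_num_correct_words

    # one sequence: [BOS, 3-letter word, space, 2-letter word, EOS, padding]
    inputs = torch.tensor([[1, 50, 51, 52, 104, 60, 61, 2, 0, 0]])
    # per-letter correctness: one letter of the second word is wrong
    letter_correct_mask = torch.tensor(
        [[True, True, True, True, True, True, False, True, True, True]]
    )
    correct, total = calc_num_correct_words(inputs, letter_correct_mask)
    print(correct, total)  # 1 of 2 words fully correct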
src/plot_helpers.py ADDED
@@ -0,0 +1,58 @@
+# general
+import os
+
+# visual
+import matplotlib.pyplot as plt
+
+cols = ["precision", "recall", "f1-score", "support"]
+
+
+def generate_plot_by_nikud_dagesh_sin_dict(nikud_dagesh_sin_dict, title, y_axis, plot_folder=None):
+    # Create a figure and axis
+    plt.figure(figsize=(8, 6))
+    plt.title(title)
+
+    ax = plt.gca()
+    indexes = list(range(1, len(nikud_dagesh_sin_dict["nikud"]) + 1))
+
+    # Plot data series with different colors and labels
+    ax.plot(indexes, nikud_dagesh_sin_dict["nikud"], color='blue', label='Nikud')
+    ax.plot(indexes, nikud_dagesh_sin_dict["dagesh"], color='green', label='Dagesh')
+    ax.plot(indexes, nikud_dagesh_sin_dict["sin"], color='red', label='Sin')
+
+    # Add legend
+    ax.legend()
+
+    # Set labels and title
+    ax.set_xlabel('Epoch')
+    ax.set_ylabel(y_axis)
+
+    if plot_folder is None:
+        plt.show()
+    else:
+        plt.savefig(os.path.join(plot_folder, f'{title.replace(" ", "_")}_plot.jpg'))
+
+
+def generate_word_and_letter_accuracy_plot(word_and_letter_accuracy_dict, title, plot_folder=None):
+    # Create a figure and axis
+    plt.figure(figsize=(8, 6))
+    plt.title(title)
+
+    ax = plt.gca()
+    indexes = list(range(1, len(word_and_letter_accuracy_dict["all_nikud_letter"]) + 1))
+
+    # Plot data series with different colors and labels
+    ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_letter"], color='blue', label='Letter')
+    ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_word"], color='green', label='Word')
+
+    # Add legend
+    ax.legend()
+
+    # Set labels and title
+    ax.set_xlabel("Epoch")
+    ax.set_ylabel("Accuracy")
+
+    if plot_folder is None:
+        plt.show()
+    else:
+        plt.savefig(os.path.join(plot_folder, 'word_and_letter_accuracy_plot.jpg'))
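
A quick usage sketch for the epoch plots (dummy values covering three epochs; with plot_folder=None the figure is shown rather than saved):

    from src.plot_helpers import generate_plot_by_nikud_dagesh_sin_dict

    losses = {
        "nikud": [0.9, 0.5, 0.3],
        "dagesh": [0.7, 0.4, 0.2],
        "sin": [0.6, 0.35, 0.25],
    }
    generate_plot_by_nikud_dagesh_sin_dict(losses, "Train epochs loss", "Loss")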
src/running_params.py ADDED
@@ -0,0 +1,3 @@
+DEBUG_MODE = False
+BATCH_SIZE = 32
+MAX_LENGTH_SEN = 1024
src/utiles_data.py ADDED
@@ -0,0 +1,737 @@
1
+ # general
2
+ import os.path
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ from typing import List, Tuple
6
+ from uuid import uuid1
7
+ import re
8
+ import glob2
9
+
10
+ # visual
11
+ import matplotlib
12
+ import matplotlib.pyplot as plt
13
+ from tqdm import tqdm
14
+
15
+ # ML
16
+ import numpy as np
17
+ import torch
18
+ from torch.utils.data import Dataset
19
+
20
+ from src.running_params import DEBUG_MODE, MAX_LENGTH_SEN
21
+
22
+ matplotlib.use("agg")
23
+ unique_key = str(uuid1())
24
+
25
+
26
+ class Nikud:
27
+ """
28
+ 1456 HEBREW POINT SHEVA
29
+ 1457 HEBREW POINT HATAF SEGOL
30
+ 1458 HEBREW POINT HATAF PATAH
31
+ 1459 HEBREW POINT HATAF QAMATS
32
+ 1460 HEBREW POINT HIRIQ
33
+ 1461 HEBREW POINT TSERE
34
+ 1462 HEBREW POINT SEGOL
35
+ 1463 HEBREW POINT PATAH
36
+ 1464 HEBREW POINT QAMATS
37
+ 1465 HEBREW POINT HOLAM
38
+ 1466 HEBREW POINT HOLAM HASER FOR VAV ***EXTENDED***
39
+ 1467 HEBREW POINT QUBUTS
40
+ 1468 HEBREW POINT DAGESH OR MAPIQ
41
+ 1469 HEBREW POINT METEG ***EXTENDED***
42
+ 1470 HEBREW PUNCTUATION MAQAF ***EXTENDED***
43
+ 1471 HEBREW POINT RAFE ***EXTENDED***
44
+ 1472 HEBREW PUNCTUATION PASEQ ***EXTENDED***
45
+ 1473 HEBREW POINT SHIN DOT
46
+ 1474 HEBREW POINT SIN DOT
47
+ """
48
+
49
+ nikud_dict = {
50
+ "SHVA": 1456,
51
+ "REDUCED_SEGOL": 1457,
52
+ "REDUCED_PATAKH": 1458,
53
+ "REDUCED_KAMATZ": 1459,
54
+ "HIRIK": 1460,
55
+ "TZEIRE": 1461,
56
+ "SEGOL": 1462,
57
+ "PATAKH": 1463,
58
+ "KAMATZ": 1464,
59
+ "KAMATZ_KATAN": 1479,
60
+ "HOLAM": 1465,
61
+ "HOLAM HASER VAV": 1466,
62
+ "KUBUTZ": 1467,
63
+ "DAGESH OR SHURUK": 1468,
64
+ "METEG": 1469,
65
+ "PUNCTUATION MAQAF": 1470,
66
+ "RAFE": 1471,
67
+ "PUNCTUATION PASEQ": 1472,
68
+ "SHIN_YEMANIT": 1473,
69
+ "SHIN_SMALIT": 1474,
70
+ }
71
+
72
+ skip_nikud = (
73
+ []
74
+ ) # [nikud_dict["KAMATZ_KATAN"], nikud_dict["HOLAM HASER VAV"], nikud_dict["METEG"], nikud_dict["PUNCTUATION MAQAF"], nikud_dict["PUNCTUATION PASEQ"]]
75
+ sign_2_name = {sign: name for name, sign in nikud_dict.items()}
76
+ sin = [nikud_dict["RAFE"], nikud_dict["SHIN_YEMANIT"], nikud_dict["SHIN_SMALIT"]]
77
+ dagesh = [
78
+ nikud_dict["RAFE"],
79
+ nikud_dict["DAGESH OR SHURUK"],
80
+ ] # note that DAGESH and SHURUK are one and the same
81
+ nikud = []
82
+ for v in nikud_dict.values():
83
+ if v not in sin and v not in skip_nikud:
84
+ nikud.append(v)
85
+ all_nikud_ord = {v for v in nikud_dict.values()}
86
+ all_nikud_chr = {chr(v) for v in nikud_dict.values()}
87
+
88
+ label_2_id = {
89
+ "nikud": {label: i for i, label in enumerate(nikud + ["WITHOUT"])},
90
+ "dagesh": {label: i for i, label in enumerate(dagesh + ["WITHOUT"])},
91
+ "sin": {label: i for i, label in enumerate(sin + ["WITHOUT"])},
92
+ }
93
+ id_2_label = {
94
+ "nikud": {i: label for i, label in enumerate(nikud + ["WITHOUT"])},
95
+ "dagesh": {i: label for i, label in enumerate(dagesh + ["WITHOUT"])},
96
+ "sin": {i: label for i, label in enumerate(sin + ["WITHOUT"])},
97
+ }
98
+
99
+ DAGESH_LETTER = nikud_dict["DAGESH OR SHURUK"]
100
+ RAFE = nikud_dict["RAFE"]
101
+ PAD_OR_IRRELEVANT = -1
102
+
103
+ LEN_NIKUD = len(label_2_id["nikud"])
104
+ LEN_DAGESH = len(label_2_id["dagesh"])
105
+ LEN_SIN = len(label_2_id["sin"])
106
+
107
+ def id_2_char(self, c, class_type):
108
+ if c == -1:
109
+ return ""
110
+
111
+ label = self.id_2_label[class_type][c]
112
+
113
+ if label != "WITHOUT":
114
+ print("Label =", chr(self.id_2_label[class_type][c]))
115
+ return chr(self.id_2_label[class_type][c])
116
+ return ""
117
+
118
+
119
+ class Letters:
120
+ hebrew = [chr(c) for c in range(0x05D0, 0x05EA + 1)]
121
+ VALID_LETTERS = [
122
+ " ",
123
+ "!",
124
+ '"',
125
+ "'",
126
+ "(",
127
+ ")",
128
+ ",",
129
+ "-",
130
+ ".",
131
+ ":",
132
+ ";",
133
+ "?",
134
+ ] + hebrew
135
+ SPECIAL_TOKENS = ["H", "O", "5", "1"]
136
+ ENDINGS_TO_REGULAR = dict(zip("ืšืืŸืฃืฅ", "ื›ืžื ืคืฆ"))
137
+ vocab = VALID_LETTERS + SPECIAL_TOKENS
138
+ vocab_size = len(vocab)
139
+
140
+
141
+ class Letter:
142
+ def __init__(self, letter):
143
+ self.letter = letter
144
+ self.normalized = None
145
+ self.dagesh = None
146
+ self.sin = None
147
+ self.nikud = None
148
+
149
+ def normalize(self, letter):
150
+ if letter in Letters.VALID_LETTERS:
151
+ return letter
152
+ if letter in Letters.ENDINGS_TO_REGULAR:
153
+ return Letters.ENDINGS_TO_REGULAR[letter]
154
+ if letter in ["\n", "\t"]:
155
+ return " "
156
+ if letter in ["โ€’", "โ€“", "โ€”", "โ€•", "โˆ’", "+"]:
157
+ return "-"
158
+ if letter == "[":
159
+ return "("
160
+ if letter == "]":
161
+ return ")"
162
+ if letter in ["ยด", "โ€˜", "โ€™"]:
163
+ return "'"
164
+ if letter in ["โ€œ", "โ€", "ืด"]:
165
+ return '"'
166
+ if letter.isdigit():
167
+ if int(letter) == 1:
168
+ return "1"
169
+ else:
170
+ return "5"
171
+ if letter == "โ€ฆ":
172
+ return ","
173
+ if letter in ["ืฒ", "ืฐ", "ืฑ"]:
174
+ return "H"
175
+ return "O"
176
+
177
+ def can_dagesh(self, letter):
178
+ return letter in ("ื‘ื’ื“ื”ื•ื–ื˜ื™ื›ืœืžื ืกืคืฆืงืฉืช" + "ืšืฃ")
179
+
180
+ def can_sin(self, letter):
181
+ return letter == "ืฉ"
182
+
183
+ def can_nikud(self, letter):
184
+ return letter in ("ืื‘ื’ื“ื”ื•ื–ื—ื˜ื™ื›ืœืžื ืกืขืคืฆืงืจืฉืช" + "ืšืŸ")
185
+
186
    def get_label_letter(self, labels):
        dagesh_sin_nikud = [
            self.can_dagesh(self.letter),
            self.can_sin(self.letter),
            self.can_nikud(self.letter),
        ]

        labels_ids = {
            "nikud": Nikud.PAD_OR_IRRELEVANT,
            "dagesh": Nikud.PAD_OR_IRRELEVANT,
            "sin": Nikud.PAD_OR_IRRELEVANT,
        }

        normalized = self.normalize(self.letter)

        i = 0
        if Nikud.nikud_dict["PUNCTUATION PASEQ"] in labels:
            labels.remove(Nikud.nikud_dict["PUNCTUATION PASEQ"])
        if Nikud.nikud_dict["PUNCTUATION MAQAF"] in labels:
            labels.remove(Nikud.nikud_dict["PUNCTUATION MAQAF"])
        if Nikud.nikud_dict["HOLAM HASER VAV"] in labels:
            labels.remove(Nikud.nikud_dict["HOLAM HASER VAV"])
        if Nikud.nikud_dict["METEG"] in labels:
            labels.remove(Nikud.nikud_dict["METEG"])
        if Nikud.nikud_dict["KAMATZ_KATAN"] in labels:
            labels[labels.index(Nikud.nikud_dict["KAMATZ_KATAN"])] = Nikud.nikud_dict[
                "KAMATZ"
            ]
        for index, (class_name, group) in enumerate(
            zip(
                ["dagesh", "sin", "nikud"],
                [[Nikud.DAGESH_LETTER], Nikud.sin, Nikud.nikud],
            )
        ):
            # note: the order matters - dagesh, then sin, and then nikud
            if dagesh_sin_nikud[index]:
                if i < len(labels) and labels[i] in group:
                    labels_ids[class_name] = Nikud.label_2_id[class_name][labels[i]]
                    i += 1
                else:
                    labels_ids[class_name] = Nikud.label_2_id[class_name]["WITHOUT"]

        if (
            np.array(dagesh_sin_nikud).all()
            and len(labels) == 3
            and labels[0] in Nikud.sin
        ):
            labels_ids["nikud"] = Nikud.label_2_id["nikud"][labels[2]]
            labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]

        if (
            self.can_sin(self.letter)
            and len(labels) == 2
            and labels[1] == Nikud.DAGESH_LETTER
        ):
            labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]
            labels_ids["nikud"] = Nikud.label_2_id["nikud"]["WITHOUT"]

        if (
            self.letter == "ื•"
            and labels_ids["dagesh"] == Nikud.label_2_id["dagesh"][Nikud.DAGESH_LETTER]
            and labels_ids["nikud"] == Nikud.label_2_id["nikud"]["WITHOUT"]
        ):
            # a dagesh on a bare vav denotes a shuruk vowel: record it on the
            # nikud head instead of the dagesh head
            labels_ids["dagesh"] = Nikud.label_2_id["dagesh"]["WITHOUT"]
            labels_ids["nikud"] = Nikud.label_2_id["nikud"][Nikud.DAGESH_LETTER]

        self.normalized = normalized
        self.dagesh = labels_ids["dagesh"]
        self.sin = labels_ids["sin"]
        self.nikud = labels_ids["nikud"]
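
    # Example (illustrative): labelling a shin carrying a sin dot and a kamatz.
    #   l = Letter("ืฉ")
    #   l.get_label_letter([Nikud.nikud_dict["SHIN_SMALIT"], Nikud.nikud_dict["KAMATZ"]])
    #   l.sin, l.nikud and l.dagesh now hold class ids from Nikud.label_2_id.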

    def name_of(self, letter):
        if "ื" <= letter <= "ืช":
            return letter
        code = ord(letter)
        if code == Nikud.DAGESH_LETTER:
            return "ื“ื’ืฉ\\ืฉื•ืจื•ืง"
        if code == Nikud.nikud_dict["KAMATZ"]:
            return "ืงืžืฅ"
        if code == Nikud.nikud_dict["PATAKH"]:
            return "ืคืชื—"
        if code == Nikud.nikud_dict["TZEIRE"]:
            return "ืฆื™ืจื”"
        if code == Nikud.nikud_dict["SEGOL"]:
            return "ืกื’ื•ืœ"
        if code == Nikud.nikud_dict["SHVA"]:
            return "ืฉื•ื"
        if code == Nikud.nikud_dict["HOLAM"]:
            return "ื—ื•ืœื"
        if code == Nikud.nikud_dict["KUBUTZ"]:
            return "ืงื•ื‘ื•ืฅ"
        if code == Nikud.nikud_dict["HIRIK"]:
            return "ื—ื™ืจื™ืง"
        if code == Nikud.nikud_dict["REDUCED_KAMATZ"]:
            return "ื—ื˜ืฃ-ืงืžืฅ"
        if code == Nikud.nikud_dict["REDUCED_PATAKH"]:
            return "ื—ื˜ืฃ-ืคืชื—"
        if code == Nikud.nikud_dict["REDUCED_SEGOL"]:
            return "ื—ื˜ืฃ-ืกื’ื•ืœ"
        if code == Nikud.nikud_dict["SHIN_SMALIT"]:
            return "ืฉื™ืŸ-ืฉืžืืœื™ืช"
        if code == Nikud.nikud_dict["SHIN_YEMANIT"]:
            return "ืฉื™ืŸ-ื™ืžื ื™ืช"
        if letter.isprintable():
            return letter
        return "ืœื ื™ื“ื•ืข ({})".format(hex(code))


def text_contains_nikud(text):
    return len(set(text) & Nikud.all_nikud_chr) > 0


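# Example (illustrative):
#   text_contains_nikud("ืฉึธืœื•ึนื")  # -> True
#   text_contains_nikud("ืฉืœื•ื")   # -> False

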
def combine_sentences(list_sentences, max_length=0, is_train=False):
    all_new_sentences = []
    new_sen = ""
    index = 0
    while index < len(list_sentences):
        sen = list_sentences[index]

        if not text_contains_nikud(sen) and (
            "------------------" in sen or sen == "\n"
        ):
            if len(new_sen) > 0:
                all_new_sentences.append(new_sen)
            if not is_train:
                all_new_sentences.append(sen)
            new_sen = ""
            index += 1
            continue

        if not text_contains_nikud(sen) and is_train:
            index += 1
            continue

        if len(sen) > max_length:
            # split an over-long sentence at punctuation first, then at word boundaries
            update_sen = sen.replace(". ", f". {unique_key}")
            update_sen = update_sen.replace("? ", f"? {unique_key}")
            update_sen = update_sen.replace("! ", f"! {unique_key}")
            update_sen = update_sen.replace("โ€ ", f"โ€ {unique_key}")
            update_sen = update_sen.replace("\t", f"\t{unique_key}")
            part_sentence = update_sen.split(unique_key)

            good_parts = []
            for p in part_sentence:
                if len(p) < max_length:
                    good_parts.append(p)
                else:
                    prev = 0
                    while prev <= len(p):
                        part = p[prev : (prev + max_length)]
                        last_space = 0
                        if " " in part:
                            last_space = part[::-1].index(" ") + 1
                        next_index = prev + max_length - last_space
                        part = p[prev:next_index]
                        good_parts.append(part)
                        prev = next_index
            list_sentences = (
                list_sentences[:index] + good_parts + list_sentences[index + 1 :]
            )
            continue
        if new_sen == "":
            new_sen = sen
        elif len(new_sen) + len(sen) < max_length:
            new_sen += sen
        else:
            all_new_sentences.append(new_sen)
            new_sen = sen

        index += 1
    if len(new_sen) > 0:
        all_new_sentences.append(new_sen)
    return all_new_sentences


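# Example (illustrative): short vocalized fragments are merged up to max_length,
# while over-long ones are re-split at sentence punctuation and spaces:
#   combine_sentences(["ืฉึธืœื•ึนื ", "ืขื•ึนืœึธื"], max_length=1024)
#   # -> ["ืฉึธืœื•ึนื ืขื•ึนืœึธื"]

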
class NikudDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        folder=None,
        file=None,
        logger=None,
        max_length=0,
        is_train=False,
    ):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.data = None
        self.origin_data = None
        if folder is not None:
            self.data, self.origin_data = self.read_data_folder(folder, logger)
        elif file is not None:
            self.data, self.origin_data = self.read_data(file, logger)
        self.prepered_data = None

    def read_data_folder(self, folder_path: str, logger=None):
        all_files = glob2.glob(f"{folder_path}/**/*.txt", recursive=True)
        msg = f"number of files: {len(all_files)}"
        if logger:
            logger.debug(msg)
        else:
            print(msg)
        all_data = []
        all_origin_data = []
        if DEBUG_MODE:
            all_files = all_files[0:2]
        for file in all_files:
            if "not_use" in file or "NakdanResults" in file:
                continue
            data, origin_data = self.read_data(file, logger)
            all_data.extend(data)
            all_origin_data.extend(origin_data)
        return all_data, all_origin_data

    def read_data(
        self, filepath: str, logger=None
    ) -> Tuple[List[Tuple[str, list]], List[str]]:
        msg = f"read file: {filepath}"
        if logger:
            logger.debug(msg)
        else:
            print(msg)
        with open(filepath, "r", encoding="utf-8") as file:
            file_data = file.read()
        data_list = self.split_text(file_data)
        return self._extract_labels(
            data_list, desc=f"Source: {os.path.basename(filepath)}"
        )

    def read_single_text(
        self, text: str, logger=None
    ) -> Tuple[List[Tuple[str, list]], List[str]]:
        data_list = self.split_text(text)
        data, orig_data = self._extract_labels(data_list, desc="Source: single text")
        self.data = data
        self.origin_data = orig_data
        return data, orig_data

    def _extract_labels(self, data_list, desc):
        # shared parsing loop for read_data and read_single_text: walk each
        # sentence character by character and fold the nikud/dagesh/sin marks
        # that follow every Hebrew letter into per-letter label ids
        data = []
        orig_data = []
        for sen in tqdm(data_list, desc=desc):
            if sen == "":
                continue

            labels = []
            text = ""
            text_org = ""
            index = 0
            sentence_length = len(sen)
            while index < sentence_length:
                if ord(sen[index]) in (
                    Nikud.nikud_dict["PUNCTUATION MAQAF"],
                    Nikud.nikud_dict["PUNCTUATION PASEQ"],
                    Nikud.nikud_dict["METEG"],
                ):
                    index += 1
                    continue

                label = []
                l = Letter(sen[index])
                if l.letter in Nikud.all_nikud_chr:
                    # a stray nikud mark right after a newline is dropped
                    if sen[index - 1] == "\n":
                        index += 1
                        continue
                assert l.letter not in Nikud.all_nikud_chr
                if sen[index] in Letters.hebrew:
                    index += 1
                    while (
                        index < sentence_length
                        and ord(sen[index]) in Nikud.all_nikud_ord
                    ):
                        label.append(ord(sen[index]))
                        index += 1
                else:
                    index += 1

                l.get_label_letter(label)
                text += l.normalized
                text_org += l.letter
                labels.append(l)

            data.append((text, labels))
            orig_data.append(text_org)

        return data, orig_data

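    # Example (illustrative): parsing a single vocalized string in memory.
    #   ds = NikudDataset(tokenizer=None)
    #   data, orig = ds.read_single_text("ืฉึธืœื•ึนื")
    #   data[0][0]  # normalized, mark-free text: "ืฉืœื•ื"
    #   data[0][1]  # list of Letter objects carrying nikud/dagesh/sin class ids
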
    def split_text(self, file_data):
        file_data = file_data.replace("\n", f"\n{unique_key}")
        data_list = file_data.split(unique_key)
        data_list = combine_sentences(
            data_list, is_train=self.is_train, max_length=MAX_LENGTH_SEN
        )
        return data_list

    def show_data_labels(self, plots_folder=None):
        nikud = [
            Nikud.id_2_label["nikud"][label.nikud]
            for _, label_list in self.data
            for label in label_list
            if label.nikud != -1
        ]
        dagesh = [
            Nikud.id_2_label["dagesh"][label.dagesh]
            for _, label_list in self.data
            for label in label_list
            if label.dagesh != -1
        ]
        sin = [
            Nikud.id_2_label["sin"][label.sin]
            for _, label_list in self.data
            for label in label_list
            if label.sin != -1
        ]

        vowels = nikud + dagesh + sin
        unique_vowels, label_counts = np.unique(vowels, return_counts=True)
        unique_vowels_names = [
            Nikud.sign_2_name[int(vowel)]
            for vowel in unique_vowels
            if vowel != "WITHOUT"
        ] + ["WITHOUT"]
        fig, ax = plt.subplots(figsize=(16, 6))

        bar_positions = np.arange(len(unique_vowels))
        bar_width = 0.15
        ax.bar(bar_positions, list(label_counts), bar_width)

        ax.set_title("Distribution of vowels in the dataset")
        ax.set_xlabel("Vowels")
        ax.set_ylabel("Count")
        ax.set_xticks(bar_positions)
        ax.set_xticklabels(unique_vowels_names, rotation=30, ha="right", fontsize=8)

        if plots_folder is None:
            plt.show()
        else:
            plt.savefig(os.path.join(plots_folder, "show_data_labels.jpg"))

    def calc_max_length(self, maximum=MAX_LENGTH_SEN):
        if self.max_length > maximum:
            self.max_length = maximum
        return self.max_length

    def prepare_data(self, name="train"):
        dataset = []
        for sentence, label in tqdm(self.data, desc=f"prepare data {name}"):
            encoded_sequence = self.tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            label_lists = [
                [letter.nikud, letter.dagesh, letter.sin] for letter in label
            ]
            label = torch.tensor(
                [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                ]
                + label_lists[: (self.max_length - 1)]
                + [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                    for _ in range(self.max_length - len(label) - 1)
                ]
            )

            dataset.append(
                (
                    encoded_sequence["input_ids"][0],
                    encoded_sequence["attention_mask"][0],
                    label,
                )
            )

        self.prepered_data = dataset
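
    # Example (illustrative): each prepared item is a triple of tensors,
    #   input_ids:      shape (max_length,)
    #   attention_mask: shape (max_length,)
    #   label:          shape (max_length, 3) with [nikud, dagesh, sin] class
    #                   ids per position and -1 (PAD_OR_IRRELEVANT) for padding
    #                   and the leading special token.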

    def back_2_text(self, labels):
        nikud = Nikud()
        all_text = ""
        for indx_sentence in range(len(self.prepered_data)):
            new_line = ""
            for indx_char, c in enumerate(self.origin_data[indx_sentence]):
                new_line += (
                    c
                    + nikud.id_2_char(labels[indx_sentence, indx_char + 1, 1], "dagesh")
                    + nikud.id_2_char(labels[indx_sentence, indx_char + 1, 2], "sin")
                    + nikud.id_2_char(labels[indx_sentence, indx_char + 1, 0], "nikud")
                )
            all_text += new_line
        return all_text

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def get_sub_folders_paths(main_folder):
    list_paths = []
    for filename in os.listdir(main_folder):
        path = os.path.join(main_folder, filename)
        if os.path.isdir(path) and filename != ".git":
            list_paths.append(path)
            list_paths.extend(get_sub_folders_paths(path))
    return list_paths


def create_missing_folders(folder_path):
    # check if the folder doesn't exist and create it if needed
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)


def info_folder(folder, num_files, num_hebrew_letters):
    """
    Recursively counts the number of files and the number of Hebrew letters in all subfolders of the given folder path.

    Args:
        folder (str): The path of the folder to be analyzed.
        num_files (int): The running total of the number of files encountered so far.
        num_hebrew_letters (int): The running total of the number of Hebrew letters encountered so far.

    Returns:
        Tuple[int, int]: A tuple containing the total number of files and the total number of Hebrew letters.
    """
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            num_files += 1
            dataset = NikudDataset(None, file=file_path)
            for line in dataset.data:
                for c in line[0]:
                    if c in Letters.hebrew:
                        num_hebrew_letters += 1

        elif os.path.isdir(file_path) and filename != ".git":
            # recurse with fresh counters so the running totals are not added twice
            n1, n2 = info_folder(file_path, 0, 0)
            num_files += n1
            num_hebrew_letters += n2
    return num_files, num_hebrew_letters


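# Example (illustrative; "data/train" is a hypothetical path):
#   n_files, n_letters = info_folder("data/train", 0, 0)

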
def extract_text_to_compare_nakdimon(text):
    # normalize vocalization variants so the output is comparable with Nakdimon's
    res = text.replace("|", "")
    res = res.replace(
        chr(Nikud.nikud_dict["KUBUTZ"]) + "ื•" + chr(Nikud.nikud_dict["METEG"]),
        "ื•" + chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
    )
    res = res.replace(
        chr(Nikud.nikud_dict["HOLAM"]) + "ื•" + chr(Nikud.nikud_dict["METEG"]), "ื•"
    )
    res = res.replace(
        "ื•" + chr(Nikud.nikud_dict["HOLAM"]) + chr(Nikud.nikud_dict["KAMATZ"]),
        "ื•" + chr(Nikud.nikud_dict["KAMATZ"]),
    )
    res = res.replace(chr(Nikud.nikud_dict["METEG"]), "")
    res = res.replace(
        chr(Nikud.nikud_dict["KAMATZ"]) + chr(Nikud.nikud_dict["HIRIK"]),
        chr(Nikud.nikud_dict["KAMATZ"]) + "ื™" + chr(Nikud.nikud_dict["HIRIK"]),
    )
    res = res.replace(
        chr(Nikud.nikud_dict["PATAKH"]) + chr(Nikud.nikud_dict["HIRIK"]),
        chr(Nikud.nikud_dict["PATAKH"]) + "ื™" + chr(Nikud.nikud_dict["HIRIK"]),
    )
    res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION MAQAF"]), "")
    res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION PASEQ"]), "")
    res = res.replace(
        chr(Nikud.nikud_dict["KAMATZ_KATAN"]), chr(Nikud.nikud_dict["KAMATZ"])
    )

    res = re.sub(chr(Nikud.nikud_dict["KUBUTZ"]) + "ื•" + "(?=[ื-ืช])", "ื•", res)
    res = res.replace(chr(Nikud.nikud_dict["REDUCED_KAMATZ"]) + "ื•", "ื•")

    # collapse a doubled dagesh/shuruk mark and canonicalize maqaf and the Tetragrammaton
    res = res.replace(
        chr(Nikud.nikud_dict["DAGESH OR SHURUK"]) * 2,
        chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
    )
    res = res.replace("\u05be", "-")
    res = res.replace("ื™ึฐื”ื•ึนึธื”", "ื™ื”ื•ื”")

    return res

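# Example (illustrative): a meteg (U+05BD) is stripped before comparison.
#   extract_text_to_compare_nakdimon("ืขึทึฝืœ")  # -> "ืขึทืœ"
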

def orgenize_data(main_folder, logger):
    x = NikudDataset(None)
    x.delete_files(os.path.join(Path(main_folder).parent, "train"))
    x.delete_files(os.path.join(Path(main_folder).parent, "dev"))
    x.delete_files(os.path.join(Path(main_folder).parent, "test"))
    x.split_data(
        main_folder, main_folder_name=os.path.basename(main_folder), logger=logger
    )