jedick committed
Commit 142bd00 · 1 Parent(s): f027363

Don't import tqdm for BM25S tokenizer used in retrieval

Files changed (3)
  1. app.py +2 -2
  2. mods/bm25s_retriever.py +5 -2
  3. mods/bm25s_tokenization.py +719 -0
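
In short, retrieval-time tokenization now goes through a vendored copy of the bm25s tokenizer whose tqdm fallback is a no-op, so per-query tokenization prints no progress bar. A minimal sketch of the new import path (the query string is illustrative; the module and arguments come from the diffs below):

from mods.bm25s_tokenization import tokenize as bm25s_tokenize

# The vendored module never imports tqdm, so this emits no "Split strings" progress bar.
processed_query = bm25s_tokenize("What controls protein stability?", return_ids=False)
print(processed_query)  # e.g. [['controls', 'protein', 'stability']]
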
app.py CHANGED
@@ -58,10 +58,10 @@ def cleanup_graph(request: gr.Request):
     timestamp = datetime.now().replace(microsecond=0).isoformat()
     if request.session_hash in graph_instances["local"]:
         del graph_instances["local"][request.session_hash]
-        print(f"{timestamp} - Del local graph for session {request.session_hash}")
+        print(f"{timestamp} - Delete local graph for session {request.session_hash}")
     if request.session_hash in graph_instances["remote"]:
         del graph_instances["remote"][request.session_hash]
-        print(f"{timestamp} - Del remote graph for session {request.session_hash}")
+        print(f"{timestamp} - Delete remote graph for session {request.session_hash}")
 
 
 def append_content(chunk_messages, history, thinking_about):
mods/bm25s_retriever.py CHANGED
@@ -155,13 +155,16 @@ class BM25SRetriever(BaseRetriever):
         *,
         run_manager: CallbackManagerForRetrieverRun,
     ) -> List[Document]:
-        from bm25s import tokenize as bm25s_tokenize
+        from mods.bm25s_tokenization import tokenize as bm25s_tokenize
 
         processed_query = bm25s_tokenize(query, return_ids=False)
         if self.activate_numba:
             self.vectorizer.activate_numba_scorer()
             return_docs = self.vectorizer.retrieve(
-                processed_query, k=self.k, backend_selection="numba"
+                processed_query,
+                k=self.k,
+                backend_selection="numba",
+                show_progress=False,
             )
             return [self.docs[i] for i in return_docs.documents[0]]
         else:
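
For context, the same arguments can be exercised against a standalone bm25s index; the sketch below assumes bm25s (and numba, for the numba backend) is installed, and the corpus plus index construction are illustrative rather than taken from this repository:

import bm25s
from mods.bm25s_tokenization import tokenize as bm25s_tokenize

corpus = ["a cat is a feline", "a bird can fly"]  # illustrative documents
bm25 = bm25s.BM25()
bm25.index(bm25s.tokenize(corpus))  # index once; progress output is fine here
bm25.activate_numba_scorer()  # requires numba

query_tokens = bm25s_tokenize("which animal can fly", return_ids=False)
results = bm25.retrieve(
    query_tokens,
    k=1,
    backend_selection="numba",
    show_progress=False,  # mirrors the change above: no per-query progress bar
)
print(results.documents[0])  # document indices, as used by BM25SRetriever
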
mods/bm25s_tokenization.py ADDED
@@ -0,0 +1,719 @@
+from ast import Tuple
+from pathlib import Path
+import re
+from typing import Any, Dict, List, Union, Callable, NamedTuple
+import typing
+
+from bm25s.utils import json_functions
+
+try:
+    # To hide progress bars, don't import tqdm
+    # from tqdm.auto import tqdm
+    raise ImportError("Not importing tqdm")
+except ImportError:
+
+    def tqdm(iterable, *args, **kwargs):
+        return iterable
+
+
+from bm25s.stopwords import (
+    STOPWORDS_EN,
+    STOPWORDS_EN_PLUS,
+    STOPWORDS_GERMAN,
+    STOPWORDS_DUTCH,
+    STOPWORDS_FRENCH,
+    STOPWORDS_SPANISH,
+    STOPWORDS_PORTUGUESE,
+    STOPWORDS_ITALIAN,
+    STOPWORDS_RUSSIAN,
+    STOPWORDS_SWEDISH,
+    STOPWORDS_NORWEGIAN,
+    STOPWORDS_CHINESE,
+)
+
+
+class Tokenized(NamedTuple):
+    """
+    NamedTuple with two fields: ids and vocab. The ids field is a list of list of token IDs
+    for each document. The vocab field is a dictionary mapping tokens to their index in the
+    vocabulary.
+    """
+
+    ids: List[List[int]]
+    vocab: Dict[str, int]
+
+    def __repr__(self):
+        """
+        Returns:
+            a string representation of the class.
+            for example, for a small corpus, it would be something like:
+            ----
+            Tokenized(
+                "ids": [
+                    0: [0, 1, 2, 3]
+                ],
+                "vocab": [
+                    '': 4
+                    'cat': 0
+                    'feline': 1
+                    'likes': 2
+                    'purr': 3
+                ],
+            )
+            ----
+
+            and, for example, for a large corpus, it would be something like:
+            ----
+            Tokenized(
+                "ids": [
+                    0: [0, 1, 2, 3]
+                    1: [4, 5, 6, 7, 8, 9]
+                    2: [10, 11, 12, 13, 14]
+                    3: [15, 16, 17, 18, 19]
+                    4: [0, 1, 2, 3, 0, 20, 21, 22, 23, 24, ...]
+                    5: [0, 1, 2, 3]
+                    6: [4, 5, 6, 7, 8, 9]
+                    7: [10, 11, 12, 13, 14]
+                    8: [15, 16, 17, 18, 19]
+                    9: [0, 1, 2, 3, 0, 20, 21, 22, 23, 24, ...]
+                    ... (total 500000 docs)
+                ],
+                "vocab": [
+                    '': 29
+                    'animal': 12
+                    'beautiful': 11
+                    'best': 6
+                    'bird': 10
+                    'can': 13
+                    'carefully': 27
+                    'casually': 28
+                    'cat': 0
+                    'creature': 16
+                    ... (total 30 tokens)
+                ],
+            )
+            ----
+        """
+        lines_print_max_num = 10
+        single_doc_print_max_len = 10
+        lines = ["Tokenized(", ' "ids": [']
+        for doc_idx, document in enumerate(self.ids[:lines_print_max_num]):
+            preview = document[:single_doc_print_max_len]
+            if len(document) > single_doc_print_max_len:
+                preview += ["..."]
+            lines.append(f" {doc_idx}: [{', '.join([str(x) for x in preview])}]")
+        if len(self.ids) > lines_print_max_num:
+            lines.append(f" ... (total {len(self.ids)} docs)")
+        lines.append(f' ],\n "vocab": [')
+        vocab_keys = sorted(list(self.vocab.keys()))
+        for vocab_idx, key_ in enumerate(vocab_keys[:lines_print_max_num]):
+            val_ = self.vocab[key_]
+            lines.append(f" {key_!r}: {val_}")
+        if len(list(vocab_keys)) > 10:
+            lines.append(f" ... (total {len(vocab_keys)} tokens)")
+        lines.append(" ],\n)")
+        return "\n".join(lines)
+
+
+class Tokenizer:
+    """
+    Tokenizer class for tokenizing a list of strings and converting them to token IDs.
+
+    Parameters
+    ----------
+    lower : bool, optional
+        Whether to convert the text to lowercase before tokenization
+
+    splitter : Union[str, Callable], optional
+        If a string is provided, the tokenizer will interpret it as a regex pattern,
+        and use the `re.compile` function to compile the pattern and use the `findall` method
+        to split the text. If a callable is provided, the tokenizer will use the callable to
+        split the text. The callable should take a string as input and return a list of strings.
+
+    stopwords : Union[str, List[str]], optional
+        The list of stopwords to remove from the text. If "english" or "en" is provided,
+        the function will use the default English stopwords. If None or False is provided,
+        no stopwords will be removed. If a list of strings is provided, the tokenizer will
+        use the list of strings as stopwords.
+
+    stemmer : Callable, optional
+        The stemmer to use for stemming the tokens. It is recommended
+        to use the PyStemmer library for stemming, but you can also any callable that
+        takes a list of strings and returns a list of strings.
+    """
+
+    def __init__(
+        self,
+        lower: bool = True,
+        splitter: Union[str, Callable] = r"(?u)\b\w\w+\b",
+        stopwords: Union[str, List[str]] = "english",
+        stemmer: Callable = None,  # type: ignore
+    ):
+        self.lower = lower
+        if isinstance(splitter, str):
+            splitter = re.compile(splitter).findall
+        if not callable(splitter):
+            raise ValueError("splitter must be a callable or a regex pattern.")
+
+        # Exception handling for stemmer when we are using PyStemmer, which has a stemWords method
+        if hasattr(stemmer, "stemWord"):
+            stemmer = stemmer.stemWord
+        if not callable(stemmer) and stemmer is not None:
+            raise ValueError("stemmer must be callable or have a `stemWord` method.")
+
+        self.stopwords = _infer_stopwords(stopwords)
+        self.splitter = splitter
+        self.stemmer = stemmer
+
+        self.reset_vocab()
+
+    def reset_vocab(self):
+        """
+        Reset the vocabulary dictionaries to empty dictionaries, allowing you to
+        tokenize a new set of texts without reusing the previous vocabulary.
+        """
+        self.word_to_stem = {}  # word -> stemmed word, e.g. "apple" -> "appl"
+        self.stem_to_sid = {}  # stem -> stemmed id, e.g. "appl" -> 0
+        # word -> {stemmed, unstemmed} id, e.g. "apple" -> 0 (appl) or "apple" -> 2 (apple)
+        self.word_to_id = {}
+
+    def save_vocab(self, save_dir: str, vocab_name: str = "vocab.tokenizer.json"):
+        """
+        Save the vocabulary dictionaries to a file. The file is saved in JSON format.
+
+        Parameters
+        ----------
+        save_dir : str
+            The directory where the vocabulary file is saved.
+
+        vocab_name : str, optional
+            The name of the vocabulary file. Default is "vocab.tokenizer.json". Make
+            sure to not use the same name as the vocab.index.json file saved by the BM25
+            model, as it will overwrite the vocab.index.json file and cause errors.
+        """
+        save_dir: Path = Path(save_dir)
+        path = save_dir / vocab_name
+
+        save_dir.mkdir(parents=True, exist_ok=True)
+        with open(path, "w", encoding="utf-8") as f:
+            d = {
+                "word_to_stem": self.word_to_stem,
+                "stem_to_sid": self.stem_to_sid,
+                "word_to_id": self.word_to_id,
+            }
+            f.write(json_functions.dumps(d, ensure_ascii=False))
+
+    def load_vocab(self, save_dir: str, vocab_name: str = "vocab.tokenizer.json"):
+        """
+        Load the vocabulary dictionaries from a file. The file should be saved in JSON format.
+
+        Parameters
+        ----------
+        save_dir : str
+            The directory where the vocabulary file is saved.
+
+        vocab_name : str, optional
+            The name of the vocabulary file.
+
+        Note
+        ----
+        The vocabulary file should be saved in JSON format, with the following keys:
+        - word_to_stem: a dictionary mapping words to their stemmed words
+        - stem_to_sid: a dictionary mapping stemmed words to their stemmed IDs
+        - word_to_id: a dictionary mapping words to their word
+        """
+        path = Path(save_dir) / vocab_name
+
+        with open(path, "r", encoding="utf-8") as f:
+            d = json_functions.loads(f.read())
+            self.word_to_stem = d["word_to_stem"]
+            self.stem_to_sid = d["stem_to_sid"]
+            self.word_to_id = d["word_to_id"]
+
+    def save_stopwords(
+        self, save_dir: str, stopwords_name: str = "stopwords.tokenizer.json"
+    ):
+        """
+        Save the stopwords to a file. The file is saved in JSON format.
+
+        Parameters
+        ----------
+        save_dir : str
+            The directory where the stopwords file is saved.
+
+        stopwords_name : str, optional
+            The name of the stopwords file. Default is "stopwords.tokenizer.json".
+        """
+        save_dir: Path = Path(save_dir)
+        path = save_dir / stopwords_name
+
+        save_dir.mkdir(parents=True, exist_ok=True)
+        with open(path, "w") as f:
+            f.write(json_functions.dumps(self.stopwords))
+
+    def load_stopwords(
+        self, save_dir: str, stopwords_name: str = "stopwords.tokenizer.json"
+    ):
+        """
+        Load the stopwords from a file. The file should be saved in JSON format.
+
+        Parameters
+        ----------
+        save_dir : str
+            The directory where the stopwords file is saved.
+
+        stopwords_name : str, optional
+            The name of the stopwords file.
+        """
+        path = Path(save_dir) / stopwords_name
+
+        with open(path, "r") as f:
+            self.stopwords = json_functions.loads(f.read())
+
+    def streaming_tokenize(
+        self,
+        texts: List[str],
+        update_vocab: Union[bool, str] = True,
+        allow_empty: bool = True,
+    ):
+        """
+        Tokenize a list of strings and return a generator of token IDs.
+
+        Parameters
+        ----------
+        texts : List[str]
+            A list of strings to tokenize.
+
+        update_vocab : bool, optional
+            Whether to update the vocabulary dictionary with the new tokens. If true,
+            the different dictionaries making up the vocabulary will be updated with the
+            new tokens. If False, the function will not update the vocabulary. Unless you have
+            a stemmer and the stemmed word is in the stem_to_sid dictionary. If "never",
+            the function will never update the vocabulary, even if the stemmed word is in
+            the stem_to_sid dictionary. Note that update_vocab="if_empty" is not supported
+            in this method, only in the `tokenize` method.
+
+        allow_empty : bool, optional
+            Whether to allow the splitter to return an empty string. If False, the splitter
+            will return an empty list, which may cause issues if the tokenizer is not expecting
+            an empty list. If True, the splitter will return a list with a single empty string.
+        """
+        stopwords_set = set(self.stopwords) if self.stopwords is not None else None
+        using_stopwords = stopwords_set is not None
+        using_stemmer = self.stemmer is not None
+
+        if allow_empty is True and update_vocab is True and "" not in self.word_to_id:
+            idx = max(self.word_to_id.values(), default=-1) + 1
+            self.word_to_id[""] = idx
+
+            if using_stemmer:
+                if "" not in self.word_to_stem:
+                    self.word_to_stem[""] = ""
+                if "" not in self.stem_to_sid:
+                    self.stem_to_sid[""] = idx
+
+        for text in texts:
+            if self.lower:
+                text = text.lower()
+
+            splitted_words = list(self.splitter(text))
+
+            if allow_empty is True and len(splitted_words) == 0:
+                splitted_words = [""]
+
+            doc_ids = []
+            for word in splitted_words:
+                if word in self.word_to_id:
+                    wid = self.word_to_id[word]
+                    doc_ids.append(wid)
+                    continue
+
+                if using_stopwords and word in stopwords_set:
+                    continue
+
+                # We are always updating the word_to_stem mapping since even new
+                # words that we have never seen before can be stemmed, with the
+                # possibility that the stemmed ID is already in the stem_to_sid
+                if using_stemmer:
+                    if word in self.word_to_stem:
+                        stem = self.word_to_stem[word]
+                    else:
+                        stem = self.stemmer(word)
+                        self.word_to_stem[word] = stem
+
+                    # if the stem is already in the stem_to_sid, we can just use the ID
+                    # and update the word_to_id dictionary, unless update_vocab is "never"
+                    # in which case we skip this word
+                    if update_vocab != "never" and stem in self.stem_to_sid:
+                        sid = self.stem_to_sid[stem]
+                        self.word_to_id[word] = sid
+                        doc_ids.append(sid)
+
+                    elif update_vocab is True:
+                        sid = len(self.stem_to_sid)
+                        self.stem_to_sid[stem] = sid
+                        self.word_to_id[word] = sid
+                        doc_ids.append(sid)
+                else:
+                    # if we are not using a stemmer, we can just update the word_to_id
+                    # directly rather than going through the stem_to_sid dictionary
+                    if update_vocab is True and word not in self.word_to_id:
+                        wid = len(self.word_to_id)
+                        self.word_to_id[word] = wid
+                        doc_ids.append(wid)
+
+            if len(doc_ids) == 0 and allow_empty is True and "" in self.word_to_id:
+                doc_ids = [self.word_to_id[""]]
+
+            yield doc_ids
+
+    def tokenize(
+        self,
+        texts: List[str],
+        update_vocab: Union[bool, str] = "if_empty",
+        leave_progress: bool = False,
+        show_progress: bool = True,
+        length: Union[int, None] = None,
+        return_as: str = "ids",
+        allow_empty: bool = True,
+    ) -> Union[List[List[int]], List[List[str]], typing.Generator, Tokenized]:
+        """
+        Tokenize a list of strings and return the token IDs.
+
+        Parameters
+        ----------
+        texts : List[str]
+            A list of strings to tokenize.
+
+        update_vocab : bool, optional
+            Whether to update the vocabulary dictionary with the new tokens. If true,
+            the different dictionaries making up the vocabulary will be updated with the
+            new tokens. If False, the vocabulary will not be updated unless you have a stemmer
+            and the stemmed word is in the stem_to_sid dictionary. If update_vocab="if_empty",
+            the function will only update the vocabulary if it is empty, i.e. when the
+            function is called for the first time, or if the vocabulary has been reset with
+            the `reset_vocab` method. If update_vocab="never", the "word_to_id" will never
+            be updated, even if the stemmed word is in the stem_to_sid dictionary. Only use
+            this if you are sure that the stemmed words are already in the stem_to_sid dictionary.
+
+        leave_progress : bool, optional
+            Whether to leave the progress bar after completion. If False, the progress bar
+            will disappear after completion. If True, the progress bar will stay on the screen.
+
+        show_progress : bool, optional
+            Whether to show the progress bar for tokenization. If False, the function will
+            not show the progress bar. If True, it will use tqdm.auto to show the progress bar.
+
+        length : int, optional
+            The length of the texts. If None, the function will call `len(texts)` to get the length.
+            This is mainly used when `texts` is a generator or a stream instead of a list, in which case
+            `len(texts)` will raise a TypeError, and you need to provide the length manually.
+
+        return_as : str, optional
+            The type of object to return by this function.
+            If "tuple", this returns a Tokenized namedtuple, which contains the token IDs
+            and the vocab dictionary.
+            If "string", this return a list of lists of strings, each string being a token.
+            If "ids", this return a list of lists of integers corresponding to the token IDs,
+            or stemmed IDs if a stemmer is used.
+
+        allow_empty : bool, optional
+            Whether to allow the splitter to return an empty string. If False, the splitter
+            will return an empty list, which may cause issues if the tokenizer is not expecting
+            an empty list. If True, the splitter will return a list with a single empty string.
+
+        Returns
+        -------
+        List[List[int]] or Generator[List[int]] or List[List[str]] or Tokenized object
+            If `return_as="stream"`, a Generator[List[int]] is returned, each integer being a token ID.
+            If `return_as="ids"`, a List[List[int]] is returned, each integer being a token ID.
+            If `return_as="string"`, a List[List[str]] is returned, each string being a token.
+            If `return_as="tuple"`, a Tokenized namedtuple is returned, with names `ids` and `vocab`.
+        """
+        incorrect_return_error = (
+            "return_as must be either 'tuple', 'string', 'ids', or 'stream'."
+        )
+        incorrect_update_vocab_error = (
+            "update_vocab must be either True, False, 'if_empty', or 'never'."
+        )
+        if return_as not in ["tuple", "string", "ids", "stream"]:
+            raise ValueError(incorrect_return_error)
+
+        if update_vocab not in [True, False, "if_empty", "never"]:
+            raise ValueError(incorrect_update_vocab_error)
+
+        if update_vocab == "if_empty":
+            update_vocab = len(self.word_to_id) == 0
+
+        stream_fn = self.streaming_tokenize(
+            texts=texts, update_vocab=update_vocab, allow_empty=allow_empty
+        )
+
+        if return_as == "stream":
+            return stream_fn
+
+        if length is None:
+            length = len(texts)
+
+        tqdm_kwargs = dict(
+            desc="Tokenize texts",
+            leave=leave_progress,
+            disable=not show_progress,
+            total=length,
+        )
+
+        token_ids = []
+        for doc_ids in tqdm(stream_fn, **tqdm_kwargs):
+            token_ids.append(doc_ids)
+
+        if return_as == "ids":
+            return token_ids
+        elif return_as == "string":
+            return self.decode(token_ids)
+        elif return_as == "tuple":
+            return self.to_tokenized_tuple(token_ids)
+        else:
+            raise ValueError(incorrect_return_error)
+
+    def get_vocab_dict(self) -> Dict[str, Any]:
+        if self.stemmer is None:
+            # if we are not using a stemmer, we return the word_to_id dictionary
+            # which maps the words to the word IDs
+            return self.word_to_id
+        else:
+            # if we are using a stemmer, we return the stem_to_sid dictionary,
+            # which we will use to map the stemmed words to the stemmed IDs
+            return self.stem_to_sid
+
+    def to_tokenized_tuple(self, docs: List[List[int]]) -> Tokenized:
+        """
+        Convert the token IDs to a Tokenized namedtuple, which contains the word IDs, or the stemmed IDs
+        if a stemmer is used. The Tokenized namedtuple contains two fields: ids and vocab. The latter
+        is a dictionary mapping the token IDs to the tokens, or a dictionary mapping the stemmed IDs to
+        the stemmed tokens (if a stemmer is used).
+        """
+        return Tokenized(ids=docs, vocab=self.get_vocab_dict())
+
+    def decode(self, docs: List[List[int]]) -> List[List[str]]:
+        """
+        Convert word IDs (or stemmed IDs if a stemmer is used) back to strings using the vocab dictionary,
+        which is a dictionary mapping the word IDs to the words or a dictionary mapping the stemmed IDs
+        to the stemmed words (if a stemmer is used).
+
+        Parameters
+        ----------
+        docs : List[List[int]]
+            A list of lists of word IDs or stemmed IDs.
+
+        Returns
+        -------
+        List[List[str]]
+            A list of lists of strings, each string being a word or a stemmed word if a stemmer is used.
+        """
+        vocab = self.get_vocab_dict()
+        reverse_vocab = {v: k for k, v in vocab.items()}
+        return [[reverse_vocab[token_id] for token_id in doc] for doc in docs]
+
+
+def convert_tokenized_to_string_list(tokenized: Tokenized) -> List[List[str]]:
+    """
+    Convert the token IDs back to strings using the vocab dictionary.
+    """
+    reverse_vocab = {v: k for k, v in tokenized.vocab.items()}
+
+    return [
+        [reverse_vocab[token_id] for token_id in doc_ids] for doc_ids in tokenized.ids
+    ]
+
+
+def _infer_stopwords(stopwords: Union[str, List[str]]) -> Union[List[str], tuple]:
+    # Source of stopwords: https://github.com/nltk/nltk/blob/96ee715997e1c8d9148b6d8e1b32f412f31c7ff7/nltk/corpus/__init__.py#L315
+    if stopwords in ["english", "en", True]:  # True is added to support the default
+        return STOPWORDS_EN
+    elif stopwords in ["english_plus", "en_plus"]:
+        return STOPWORDS_EN_PLUS
+    elif stopwords in ["german", "de"]:
+        return STOPWORDS_GERMAN
+    elif stopwords in ["dutch", "nl"]:
+        return STOPWORDS_DUTCH
+    elif stopwords in ["french", "fr"]:
+        return STOPWORDS_FRENCH
+    elif stopwords in ["spanish", "es"]:
+        return STOPWORDS_SPANISH
+    elif stopwords in ["portuguese", "pt"]:
+        return STOPWORDS_PORTUGUESE
+    elif stopwords in ["italian", "it"]:
+        return STOPWORDS_ITALIAN
+    elif stopwords in ["russian", "ru"]:
+        return STOPWORDS_RUSSIAN
+    elif stopwords in ["swedish", "sv"]:
+        return STOPWORDS_SWEDISH
+    elif stopwords in ["norwegian", "no"]:
+        return STOPWORDS_NORWEGIAN
+    elif stopwords in ["chinese", "zh"]:
+        return STOPWORDS_CHINESE
+    elif stopwords in [None, False]:
+        return []
+    elif isinstance(stopwords, str):
+        raise ValueError(
+            f"{stopwords} not recognized. Only English stopwords as default, German, Dutch, French, Spanish, Portuguese, Italian, Russian, Swedish, Norwegian, and Chinese are currently supported. "
+            "Please input a list of stopwords"
+        )
+    else:
+        return stopwords
+
+
+def tokenize(
+    texts: Union[str, List[str]],
+    lower: bool = True,
+    token_pattern: str = r"(?u)\b\w\w+\b",
+    stopwords: Union[str, List[str]] = "english",
+    stemmer: Callable = None,  # type: ignore
+    return_ids: bool = True,
+    show_progress: bool = True,
+    leave: bool = False,
+    allow_empty: bool = True,
+) -> Union[List[List[str]], Tokenized]:
+    """
+    Tokenize a list using the same method as the scikit-learn CountVectorizer,
+    and optionally apply a stemmer to the tokens or stopwords removal.
+
+    If you provide stemmer, it must have a `stemWords` method, or be callable
+    that takes a list of strings and returns a list of strings. If your stemmer
+    can only be called on a single word, you can use a lambda function to wrap it,
+    e.g. `lambda lst: list(map(stemmer.stem, lst))`.
+
+    If return_ids is True, the function will return a namedtuple with: (1) the tokenized
+    IDs and (2) the token_to_index dictionary. You can access the tokenized IDs using
+    the `ids` attribute and the token_to_index dictionary using the `vocab` attribute,
+    You can also destructure the namedtuple to get the ids and vocab_dict variables,
+    e.g. `token_ids, vocab = tokenize(...)`.
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        A list of strings to tokenize. If a single string is provided, it will be
+        converted to a list with a single element.
+
+    lower : bool, optional
+        Whether to convert the text to lowercase before tokenization
+
+    token_pattern : str, optional
+        The regex pattern to use for tokenization, by default, r"(?u)\\b\\w\\w+\\b"
+
+    stopwords : Union[str, List[str]], optional
+        The list of stopwords to remove from the text. If "english" or "en" is provided,
+        the function will use the default English stopwords
+
+    stemmer : Callable, optional
+        The stemmer to use for stemming the tokens. It is recommended
+        to use the PyStemmer library for stemming, but you can also any callable that
+        takes a list of strings and returns a list of strings.
+
+    return_ids : bool, optional
+        Whether to return the tokenized IDs and the vocab dictionary. If False, the
+        function will return the tokenized strings. If True, the function will return
+        a namedtuple with the tokenized IDs and the vocab dictionary.
+
+    show_progress : bool, optional
+        Whether to show the progress bar for tokenization. If False, the function will
+        not show the progress bar. If True, it will use tqdm.auto to show the progress bar.
+
+    leave : bool, optional
+        Whether to leave the progress bar after completion. If False, the progress bar
+        will disappear after completion. If True, the progress bar will stay on the screen.
+
+    allow_empty : bool, optional
+        Whether to allow the splitter to return an empty string. If False, the splitter
+        will return an empty list, which may cause issues if the tokenizer is not expecting
+        an empty list. If True, the splitter will return a list with a single empty string.
+
+    Note
+    -----
+    You may pass a single string or a list of strings. If you pass a single string,
+    this function will convert it to a list of strings with a single element.
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    split_fn = re.compile(token_pattern).findall
+    stopwords = _infer_stopwords(stopwords)
+
+    # Step 1: Split the strings using the regex pattern
+    corpus_ids = []
+    token_to_index = {}
+
+    for text in tqdm(
+        texts, desc="Split strings", leave=leave, disable=not show_progress
+    ):
+        stopwords_set = set(stopwords)
+        if lower:
+            text = text.lower()
+
+        splitted = split_fn(text)
+
+        if allow_empty is False and len(splitted) == 0:
+            splitted = [""]
+
+        doc_ids = []
+
+        for token in splitted:
+            if token in stopwords_set:
+                continue
+
+            if token not in token_to_index:
+                token_to_index[token] = len(token_to_index)
+
+            token_id = token_to_index[token]
+            doc_ids.append(token_id)
+
+        corpus_ids.append(doc_ids)
+
+    # Create a list of unique tokens that we will use to create the vocabulary
+    unique_tokens = list(token_to_index.keys())
+
+    # Step 2: Stem the tokens if a stemmer is provided
+    if stemmer is not None:
+        if hasattr(stemmer, "stemWords"):
+            stemmer_fn = stemmer.stemWords
+        elif callable(stemmer):
+            stemmer_fn = stemmer
+        else:
+            error_msg = "Stemmer must have a `stemWord` method, or be callable. For example, you can use the PyStemmer library."
+            raise ValueError(error_msg)
+
+        # Now, we use the stemmer on the token_to_index dictionary to get the stemmed tokens
+        tokens_stemmed = stemmer_fn(unique_tokens)
+        vocab = set(tokens_stemmed)
+        vocab_dict = {token: i for i, token in enumerate(vocab)}
+        stem_id_to_stem = {v: k for k, v in vocab_dict.items()}
+        # We create a dictionary mapping the stemmed tokens to their index
+        doc_id_to_stem_id = {
+            token_to_index[token]: vocab_dict[stem]
+            for token, stem in zip(unique_tokens, tokens_stemmed)
+        }
+
+        # Now, we simply need to replace the tokens in the corpus with the stemmed tokens
+        for i, doc_ids in enumerate(
+            tqdm(corpus_ids, desc="Stem Tokens", leave=leave, disable=not show_progress)
+        ):
+            corpus_ids[i] = [doc_id_to_stem_id[doc_id] for doc_id in doc_ids]
+    else:
+        vocab_dict = token_to_index
+
+    # Step 3: Return the tokenized IDs and the vocab dictionary or the tokenized strings
+    if return_ids:
+        return Tokenized(ids=corpus_ids, vocab=vocab_dict)
+    else:
+        # We need a reverse dictionary to convert the token IDs back to tokens
+        reverse_dict = stem_id_to_stem if stemmer is not None else unique_tokens
+        # We convert the token IDs back to tokens in-place
+        for i, token_ids in enumerate(
+            tqdm(
+                corpus_ids,
+                desc="Reconstructing token strings",
+                leave=leave,
+                disable=not show_progress,
+            )
+        ):
+            corpus_ids[i] = [reverse_dict[token_id] for token_id in token_ids]
+        return corpus_ids
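
The vendored module keeps the bm25s tokenization API (the `tokenize` function and the `Tokenizer` class) but never imports tqdm, so `show_progress` is accepted and simply has no visible effect. A small usage sketch under that assumption (the corpus strings and expected outputs are illustrative):

from mods.bm25s_tokenization import tokenize, Tokenizer

docs = ["The cat purrs.", "The bird flies far."]  # illustrative corpus

# Functional API: return_ids=True (default) gives a Tokenized namedtuple (ids + vocab),
# which can be destructured as the docstring describes.
token_ids, vocab = tokenize(docs)

# return_ids=False gives plain token strings, with stopwords removed.
tokens = tokenize(docs, return_ids=False)  # e.g. [['cat', 'purrs'], ['bird', 'flies', 'far']]

# Class-based API: the vocabulary persists across calls until reset_vocab().
tok = Tokenizer(stopwords="en")
ids = tok.tokenize(docs, return_as="ids", show_progress=False)  # no progress bar either way
print(vocab, tokens, ids)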