Fabriwin commited on
Commit
5677700
·
verified ·
1 Parent(s): 3175dca

Upload bambara_utils.py

Browse files
Files changed (1) hide show
  1. bambara_utils.py +50 -0
bambara_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from tokenizers import AddedToken
4
+ from transformers import WhisperTokenizer, WhisperProcessor
5
+ import transformers.models.whisper.tokenization_whisper as whisper_tokenization
6
+ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, TASK_IDS
7
+
8
+ CUSTOM_TO_LANGUAGE_CODE = {**TO_LANGUAGE_CODE, "bambara": "bm"}
9
+
10
+ # Note: We update the whisper tokenizer constants. Not ideal but at least it works
11
+ whisper_tokenization.TO_LANGUAGE_CODE.update(CUSTOM_TO_LANGUAGE_CODE)
12
+
13
+
14
+ class BambaraWhisperTokenizer(WhisperTokenizer):
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__(*args, **kwargs)
17
+ self.add_tokens(AddedToken(content="<|bm|>", lstrip=False, rstrip=False, normalized=False, special=True))
18
+
19
+ @property
20
+ def prefix_tokens(self) -> List[int]:
21
+ bos_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
22
+ translate_token_id = self.convert_tokens_to_ids("<|translate|>")
23
+ transcribe_token_id = self.convert_tokens_to_ids("<|transcribe|>")
24
+ notimestamps_token_id = self.convert_tokens_to_ids("<|notimestamps|>")
25
+
26
+ if self.language is not None:
27
+ self.language = self.language.lower()
28
+ if self.language in CUSTOM_TO_LANGUAGE_CODE:
29
+ language_id = CUSTOM_TO_LANGUAGE_CODE[self.language]
30
+ elif self.language in CUSTOM_TO_LANGUAGE_CODE.values():
31
+ language_id = self.language
32
+ else:
33
+ is_language_code = len(self.language) == 2
34
+ raise ValueError(
35
+ f"Unsupported language: {self.language}. Language should be one of:"
36
+ f" {list(CUSTOM_TO_LANGUAGE_CODE.values()) if is_language_code else list(CUSTOM_TO_LANGUAGE_CODE.keys())}."
37
+ )
38
+
39
+ if self.task is not None:
40
+ if self.task not in TASK_IDS:
41
+ raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")
42
+
43
+ bos_sequence = [bos_token_id]
44
+ if self.language is not None:
45
+ bos_sequence.append(self.convert_tokens_to_ids(f"<|{language_id}|>"))
46
+ if self.task is not None:
47
+ bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
48
+ if not self.predict_timestamps:
49
+ bos_sequence.append(notimestamps_token_id)
50
+ return bos_sequence