JoeShingleton committed on
Commit
888d016
·
verified ·
1 Parent(s): 696f1c8

Upload translator.py

Browse files

Added translation module

Files changed (1) hide show
  1. translator.py +88 -0
translator.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from lingua import IsoCode639_1, Language, LanguageDetector, LanguageDetectorBuilder
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


class Translator:
    """Detect the language of a text and translate it with M2M100.

    Source-language detection is delegated to the Lingua package and the
    translation itself to a ``facebook/m2m100_*`` checkpoint loaded through
    ``transformers``.
    """

    def __init__(self, languages: list = None, model_size: str = '418M'):
        """Detects and translates text into a required language, using the
        M2M100 model and the Lingua package. If the language is being detected
        from a pool of possible languages these can be stated to improve
        computational efficiency, otherwise leave blank to translate from any
        language.

        Args:
            languages (list, optional): A list of potential source languages
                as ISO-639-1 codes (e.g. ``['en', 'fr']``). Leave as None if
                the source language is unknown. Defaults to None.
            model_size (str, optional): Size suffix of the M2M100 checkpoint
                to load. Can be '418M' or '1.2B'. Defaults to '418M'.

        Raises:
            ValueError: If an entry of ``languages`` is not a known ISO-639-1
                code.
        """
        if languages:
            # The codes must be resolved against IsoCode639_1: Lingua's
            # Language enum is keyed by full names (ENGLISH), not by codes.
            self.languages = [self._to_iso_code(code) for code in languages]
        else:
            self.languages = None

        self.detector = self.get_detector()
        self.model_str = f'facebook/m2m100_{model_size}'
        self.model = M2M100ForConditionalGeneration.from_pretrained(self.model_str)
        # Tokenizers already loaded by get_tokenizer, keyed by source language.
        self._tokenizers = {}

    @staticmethod
    def _to_iso_code(code: str) -> IsoCode639_1:
        """Map an ISO-639-1 string such as ``'en'`` to its Lingua enum member."""
        try:
            return getattr(IsoCode639_1, code.upper())
        except AttributeError:
            raise ValueError(
                f'Unknown ISO-639-1 language code: {code!r}'
            ) from None

    def get_detector(self) -> LanguageDetector:
        """Builds the language detection model. If a list of potential
        languages was provided at initialisation the detector only chooses
        between those languages, otherwise it considers every language that
        Lingua supports.

        Returns:
            LanguageDetector: initialised language detection model.
        """
        # NOTE(review): Lingua may reject a pool that is too small to choose
        # from (e.g. a single language) - confirm against the Lingua docs.
        if self.languages:
            builder = LanguageDetectorBuilder.from_iso_codes_639_1(*self.languages)
        else:
            builder = LanguageDetectorBuilder.from_all_languages()

        return builder.build()

    def translate(self, text: str, out_lang: str) -> dict:
        """Translates text to the language defined by out_lang. The source
        language is detected automatically.

        Args:
            text (str): text to be translated.
            out_lang (str): ISO-639-1 code of the target language (e.g. "en").

        Returns:
            dict: ``'language'`` holds the detected ISO-639-1 source code and
            ``'translation'`` holds the decoded output of
            ``tokenizer.batch_decode`` (a list of strings). The misspelled
            key ``'lanuage'`` carries the same value as ``'language'`` and is
            kept only so existing callers keep working.

        Raises:
            ValueError: If the source language of ``text`` cannot be detected.
        """
        src_lang = self.detect_language(text)
        tokenizer = self.get_tokenizer(src_lang)
        encoded = tokenizer(text, return_tensors='pt')
        # Forcing the first generated token to the target-language id is how
        # M2M100 is told which language to translate into.
        generated = self.model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.get_lang_id(out_lang),
        )
        out_text = tokenizer.batch_decode(generated, skip_special_tokens=True)

        return {
            'language': src_lang,
            'lanuage': src_lang,  # legacy misspelling, kept for compatibility
            'translation': out_text,
        }

    def get_tokenizer(self, src_lang: str) -> M2M100Tokenizer:
        """Retrieves the tokenizer configured for the given source language.
        If the tokenizer cannot be created for that language, the tokenizer
        with its default source language is used instead. Loaded tokenizers
        are cached per source language so repeated translations do not
        reload them.

        Args:
            src_lang (str): ISO-639-1 language code of the source text.

        Returns:
            M2M100Tokenizer: tokenizer for the configured M2M100 checkpoint.
        """
        if src_lang not in self._tokenizers:
            try:
                tokenizer = M2M100Tokenizer.from_pretrained(
                    self.model_str, src_lang=src_lang
                )
            except Exception:
                # Best-effort fallback (e.g. a language M2M100 does not
                # know); unlike a bare except this lets KeyboardInterrupt and
                # SystemExit propagate.
                tokenizer = M2M100Tokenizer.from_pretrained(self.model_str)
            self._tokenizers[src_lang] = tokenizer

        return self._tokenizers[src_lang]

    def detect_language(self, text: str) -> str:
        """Uses the Lingua package to detect the language of the text.

        Args:
            text (str): text to be analyzed.

        Returns:
            str: lower-case ISO-639-1 code of the detected language.

        Raises:
            ValueError: If the detector returns no language for ``text``.
        """
        lang = self.detector.detect_language_of(text)
        if lang is None:
            # The detector yields None when it cannot decide (e.g. empty
            # text); fail with a clear message instead of an AttributeError.
            raise ValueError('Could not detect the language of the supplied text.')
        return lang.iso_code_639_1.name.lower()