AriNubar committed (verified)
Commit 068f983 · 1 Parent(s): 2f130e3

Didn't work, reverting

Files changed (1): translation.py (+13 -31)
translation.py CHANGED

@@ -120,10 +120,9 @@ class Translator:
         self.hyw_splitter = pysbd.Segmenter(language="hy", clean=False)
         self.eng_splitter = pysbd.Segmenter(language="en", clean=False)
         self.languages = LANGUAGES
-        self.BATCH_SIZE = 8
 
 
-    def translate_batch(
+    def translate_single(
         self,
         text,
         src_lang,
@@ -135,7 +134,7 @@ class Translator:
     ):
         self.tokenizer.src_lang = src_lang
         encoded = self.tokenizer(
-            text, return_tensors="pt", truncation=True, max_length=256, padding=True,
+            text, return_tensors="pt", truncation=True, max_length=256
         )
         if max_length == "auto":
             max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
@@ -164,42 +163,25 @@ class Translator:
 
         if by_sentence:
             if src_lang =="eng_Latn":
-                # sents, fillers = sentenize_with_fillers(text, self.eng_splitter, ignore_errors=True)
-                sentences = self.eng_splitter.segment(text)
+                sents, fillers = sentenize_with_fillers(text, self.eng_splitter, ignore_errors=True)
             elif src_lang == "hyw_Armn":
-                # sents, fillers = sentenize_with_fillers(text, self.hyw_splitter, ignore_errors=True)
-                sentences = self.hyw_splitter.segment(text)
-
+                sents, fillers = sentenize_with_fillers(text, self.hyw_splitter, ignore_errors=True)
+
             else:
-                sentences = [text]
-                # fillers = ["", ""]
+                sents = [text]
+                fillers = ["", ""]
 
         if clean:
-            sentences = [clean_text(sent, src_lang) for sent in sentences]
+            sents = [clean_text(sent, src_lang) for sent in sents]
 
-
-        num_batches = len(sentences) // self.BATCH_SIZE
-        if len(sentences) % self.BATCH_SIZE != 0:
-            num_batches += 1
-
         results = []
-
-        for batch_num in range(num_batches):
-            start = batch_num * self.BATCH_SIZE
-            end = start + self.BATCH_SIZE
-            batch = sentences[start:end]
-            translated = self.translate_batch(batch, src_lang, tgt_lang)
-            results.extend(translated)
-        return " ".join(results).strip()
-
-
-        # for sent, sep in zip(sents, fillers):
-        #     results.append(sep)
-        #     results.append(self.translate_batch(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs))
+        for sent, sep in zip(sents, fillers):
+            results.append(sep)
+            results.append(self.translate_single(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs))
 
-        # results.append(fillers[-1])
+        results.append(fillers[-1])
 
-        # return " ".join(results)
+        return " ".join(results)
 
 if __name__ == "__main__":
     print("Initializing translator...")