TSjB commited on
Commit
f3fcd73
·
verified ·
1 Parent(s): e8cb8c1

Add supporting several sentences

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
6
  import pandas as pd
7
  import random
8
  import string
 
9
 
10
  # 2. Constants
11
  # Translation
@@ -56,7 +57,7 @@ dictionary_ru = dictionary[dictionary.til == "rus"]
56
  # Tranlation
57
  tokenizer = NllbTokenizer.from_pretrained(MODEL_TRANSLATE_PATH)
58
  model_translate = AutoModelForSeq2SeqLM.from_pretrained(MODEL_TRANSLATE_PATH)
59
-
60
  # TTS
61
  model_tts, _ = torch.hub.load(repo_or_dir = REPO_TTS_PATH,
62
  model = MODEL_TTS_PATH,
@@ -376,14 +377,14 @@ def translatePy(text, src_lang='rus_Cyrl', tgt_lang='krc_Cyrl',
376
  text, return_tensors='pt', padding=True, truncation=True,
377
  max_length=max_input_length
378
  )
379
- model_translate.eval() # turn off training mode
380
  result = model_translate.generate(
381
  **inputs.to(model_translate.device),
382
  forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
383
  max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
384
  num_beams=num_beams, **kwargs
385
  )
386
- return tokenizer.batch_decode(result, skip_special_tokens=True)[0]
387
 
388
 
389
  def translateDisp(text, from_, to, dialect):
@@ -405,8 +406,11 @@ def translateDisp(text, from_, to, dialect):
405
  if from_ == 'krc_Cyrl':
406
  text = toModel(text)
407
 
 
 
408
  str_ = translatePy(text, src_lang = from_, tgt_lang = to)
409
-
 
410
  if to == 'krc_Cyrl':
411
  str_ = fromModel(str_, dialect = dialect)
412
 
 
6
  import pandas as pd
7
  import random
8
  import string
9
+ import re
10
 
11
  # 2. Constants
12
  # Translation
 
57
  # Tranlation
58
  tokenizer = NllbTokenizer.from_pretrained(MODEL_TRANSLATE_PATH)
59
  model_translate = AutoModelForSeq2SeqLM.from_pretrained(MODEL_TRANSLATE_PATH)
60
+ model_translate.eval() # turn off training mode
61
  # TTS
62
  model_tts, _ = torch.hub.load(repo_or_dir = REPO_TTS_PATH,
63
  model = MODEL_TTS_PATH,
 
377
  text, return_tensors='pt', padding=True, truncation=True,
378
  max_length=max_input_length
379
  )
380
+
381
  result = model_translate.generate(
382
  **inputs.to(model_translate.device),
383
  forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
384
  max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
385
  num_beams=num_beams, **kwargs
386
  )
387
+ return tokenizer.batch_decode(result, skip_special_tokens=True)
388
 
389
 
390
  def translateDisp(text, from_, to, dialect):
 
406
  if from_ == 'krc_Cyrl':
407
  text = toModel(text)
408
 
409
+ # Разбиваем текст на предложения, сохраняя знаки препинания
410
+ text = re.findall(r'.+?[.!?\n](?:\s|$)', text)
411
  str_ = translatePy(text, src_lang = from_, tgt_lang = to)
412
+ str_ = ' '.join(str_).strip()
413
+
414
  if to == 'krc_Cyrl':
415
  str_ = fromModel(str_, dialect = dialect)
416