Add support for several sentences
app.py (CHANGED)

@@ -6,6 +6,7 @@ from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
 import pandas as pd
 import random
 import string
+import re

 # 2. Constants
 # Translation
@@ -56,7 +57,7 @@ dictionary_ru = dictionary[dictionary.til == "rus"]
 # Tranlation
 tokenizer = NllbTokenizer.from_pretrained(MODEL_TRANSLATE_PATH)
 model_translate = AutoModelForSeq2SeqLM.from_pretrained(MODEL_TRANSLATE_PATH)
-
+model_translate.eval() # turn off training mode
 # TTS
 model_tts, _ = torch.hub.load(repo_or_dir = REPO_TTS_PATH,
                               model = MODEL_TTS_PATH,
@@ -376,14 +377,14 @@ def translatePy(text, src_lang='rus_Cyrl', tgt_lang='krc_Cyrl',
         text, return_tensors='pt', padding=True, truncation=True,
         max_length=max_input_length
     )
-
+
     result = model_translate.generate(
         **inputs.to(model_translate.device),
         forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
         max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
         num_beams=num_beams, **kwargs
     )
-    return tokenizer.batch_decode(result, skip_special_tokens=True)
+    return tokenizer.batch_decode(result, skip_special_tokens=True)


 def translateDisp(text, from_, to, dialect):
@@ -405,8 +406,11 @@ def translateDisp(text, from_, to, dialect):
     if from_ == 'krc_Cyrl':
         text = toModel(text)

+    # Split the text into sentences, preserving the punctuation marks
+    text = re.findall(r'.+?[.!?\n](?:\s|$)', text)
     str_ = translatePy(text, src_lang = from_, tgt_lang = to)
-
+    str_ = ' '.join(str_).strip()
+
     if to == 'krc_Cyrl':
         str_ = fromModel(str_, dialect = dialect)

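
For illustration, here is a minimal, self-contained sketch of the sentence-splitting behaviour this commit adds to translateDisp. The regex and the join/strip step are copied from the diff above; split_sentences and translate_stub are hypothetical names introduced only for this sketch, with translate_stub standing in for translatePy (which batches the sentence list through model_translate.generate).

import re

def split_sentences(text):
    # Same pattern as in the commit: take everything up to a sentence-ending
    # '.', '!', '?' or newline that is followed by whitespace or end of string.
    return re.findall(r'.+?[.!?\n](?:\s|$)', text)

def translate_stub(sentences):
    # Hypothetical stand-in for translatePy: one output string per sentence.
    return ['<' + s.strip() + '>' for s in sentences]

text = "Привет. Как дела?\nЭто третье предложение."
sentences = split_sentences(text)
# ['Привет. ', 'Как дела?\n', 'Это третье предложение.']

result = ' '.join(translate_stub(sentences)).strip()
print(result)  # <Привет.> <Как дела?> <Это третье предложение.>

Note that the pattern only matches text that ends in '.', '!', '?' or a newline, so a trailing fragment without terminal punctuation is dropped rather than translated.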