from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
import gradio as gr

# Load the NLLB-200 checkpoint fine-tuned for Kabardian; append .cpu() to force CPU.
model = AutoModelForSeq2SeqLM.from_pretrained('alimboff/nllb-200-kbd')
tokenizer = NllbTokenizer.from_pretrained('alimboff/nllb-200-kbd')
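
# Optional device placement (an addition, not in the original Space): move the
# model to a GPU when one is available; translate() below already sends its
# inputs to model.device, so nothing else has to change.
import torch
model.to('cuda' if torch.cuda.is_available() else 'cpu')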

def fix_tokenizer(tokenizer, new_lang='kbd_Cyrl'):
    """
    Add a new language token to the tokenizer vocabulary
    (this has to be re-done every time the tokenizer is re-initialized).
    Note: relies on NllbTokenizer internals (lang_code_to_id and the
    fairseq_* maps) that exist in older transformers releases.
    """
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added-token maps; otherwise the new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}

fix_tokenizer(tokenizer)
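
# Sanity check (added for illustration, not in the original code): after the
# fix, the new language code should resolve to a real token id, not <unk>.
assert tokenizer.convert_tokens_to_ids('kbd_Cyrl') != tokenizer.unk_token_id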

# Map the UI labels to NLLB language codes.
language_codes = {
    "Kabardian-Circassian": "kbd_Cyrl",
    "Russian": "rus_Cyrl"
}

def translate(
    text, input_language, output_language,
    a=32, b=3, max_input_length=1024, num_beams=8, **kwargs
):
    """Translate a text (or a list of texts) and return the first translation."""
    src_lang = language_codes[input_language]
    tgt_lang = language_codes[output_language]
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(
        text, return_tensors='pt', padding=True, truncation=True,
        max_length=max_input_length
    )
    model.eval()  # turn off training mode
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        # cap generation at a + b * input_length new tokens
        max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
        num_beams=num_beams, **kwargs
    )
    # drop the [0] to return the full list of translations instead of the first one
    return tokenizer.batch_decode(result, skip_special_tokens=True)[0]
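
# Worked example of the length cap above: with the defaults a=32 and b=3,
# a 10-token input allows at most int(32 + 3 * 10) = 62 new tokens.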

with gr.Blocks() as demo:
    gr.Markdown("### AI-powered Translator")
    with gr.Row():
        input_language = gr.Radio(choices=["Kabardian-Circassian", "Russian"], label="Select the source language", value="Kabardian-Circassian")
        output_language = gr.Radio(choices=["Kabardian-Circassian", "Russian"], label="Select the target language", value="Russian")
    with gr.Row():
        text_input = gr.Textbox(label="Enter text to translate")
        text_output = gr.Textbox(label="Translation", interactive=False)
    with gr.Row():
        translate_button = gr.Button("Translate")
    translate_button.click(
        fn=translate,
        inputs=[text_input, input_language, output_language],
        outputs=text_output
    )

demo.launch()

# Example usage (note: translate() expects the UI labels from language_codes,
# not the raw NLLB codes):
# t = 'пэшым лӀы зыбжанэ щӀэсщ'  # Kabardian sample; note the palochka letter Ӏ
# kbdru = translate(t, 'Kabardian-Circassian', 'Russian')
# rukbd = translate(kbdru, 'Russian', 'Kabardian-Circassian')
# print(kbdru)
# print(rukbd)