File size: 1,366 Bytes
627060a
19748f6
 
 
 
 
 
 
9b5b26a
627060a
 
 
 
 
 
 
 
 
 
 
 
 
8559392
 
 
 
 
26ea95c
 
8559392
26ea95c
 
 
 
 
 
 
 
627060a
8559392
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def agent_translate(text: str) -> str:
    """Translate text between Russian, English and Polish.

    The target language is chosen with this logic:
        1) If text starts with 'PL:', translate the rest to Polish.
        2) If detected language == 'ru', translate to English.
        3) Else translate to Russian.

    Args:
        text: text in russian/english or polish languages.

    Returns:
        The translated text. On failure (empty input, language detection
        error, model loading or generation error) a message starting with
        "Error during translating: " is returned instead of raising.
    """
    polish_prefix = "PL:"
    to_polish = text.startswith(polish_prefix)
    original = text[len(polish_prefix):].strip() if to_polish else text

    # Nothing to translate: report it up front instead of letting language
    # detection or generation fail (or produce noise) on an empty string.
    if not original.strip():
        return "Error during translating: empty input text"

    try:
        # Detection lives inside the try so that its failures are reported
        # through the same error-string channel as translation failures.
        detected_lang = detect(original)

        if to_polish:
            target_lang = "pl"
        elif detected_lang == "ru":
            target_lang = "en"
        else:
            target_lang = "ru"

        # The documented inputs are Russian, English or Polish; any other
        # detection result falls back to English, the tokenizer's default.
        source_lang = detected_lang if detected_lang in ("ru", "pl") else "en"

        tokenizer, translation_model = _load_m2m100("facebook/m2m100_418M")

        # M2M100 prefixes the encoded input with a source-language token, so
        # the tokenizer must be told the source language before encoding.
        tokenizer.src_lang = source_lang
        inputs = tokenizer(original, return_tensors="pt")

        generated_tokens = translation_model.generate(
            **inputs,
            # Forces the decoder to start with the target-language token.
            forced_bos_token_id=tokenizer.get_lang_id(target_lang),
        )

        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        return f"Error during translating: {e}"


# Loaded (tokenizer, model) pairs keyed by model name, so the weights are
# read only once per process instead of on every agent_translate call.
_M2M100_CACHE: dict = {}


def _load_m2m100(model_name: str):
    """Return the (tokenizer, model) pair for model_name, loading it on first use."""
    if model_name not in _M2M100_CACHE:
        _M2M100_CACHE[model_name] = (
            M2M100Tokenizer.from_pretrained(model_name),
            M2M100ForConditionalGeneration.from_pretrained(model_name),
        )
    return _M2M100_CACHE[model_name]