from typing import List, Union

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from ...tool import Tool


def build_tool(config) -> Tool:
    tool = Tool(
        "Translator Info",
        "Translate a given text from one language to another.",
        name_for_model="Translator",
        description_for_model="Plugin for translating text from one language to another.",
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="[email protected]",
        legal_info_url="[email protected]",
    )

    # Default to the distilled 600M NLLB-200 checkpoint; both the model and the
    # source language can be overridden via the config dict.
    BASE_MODEL = config.get("model", "facebook/nllb-200-distilled-600M")
    SRC_LANG = config.get("src_lang", "eng_Latn")
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, use_auth_token=True, src_lang=SRC_LANG
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL, use_auth_token=True)

    @tool.get("/get_translation")
    def get_translation(
        input_text: str or list, tgt_lang: str, max_length: int
    ) -> str or list:
        inputs = tokenizer(input_text, return_tensors="pt", padding=True)
        translated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length
        )
        if isinstance(input_text, str):
            translations = tokenizer.batch_decode(
                translated_tokens, skip_special_tokens=True
            )[0]
        elif isinstance(input_text, list):
            translations = tokenizer.batch_decode(
                translated_tokens, skip_special_tokens=True
            )
        return translations

    return tool
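

# --- Hedged usage sketch (not part of the original module) ---
# If the Tool class wraps a FastAPI app (as in BMTools-style tool plugins),
# the endpoint registered above could be exercised with Starlette's TestClient
# roughly as below. The config keys ("model", "src_lang") and this serving
# setup are assumptions, not something this file defines.
#
#   from fastapi.testclient import TestClient
#
#   tool = build_tool({"src_lang": "eng_Latn"})
#   client = TestClient(tool)
#   resp = client.get(
#       "/get_translation",
#       params={"input_text": "Hello, world!", "tgt_lang": "fra_Latn", "max_length": 64},
#   )
#   print(resp.json())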