from functools import lru_cache
from logging import getLogger

from llama_cpp import Llama

logger = getLogger(__name__)


class QwenTranslator:
|
    def __init__(self, model_path, system_prompt_en="", system_prompt_zh="") -> None:
        self.llm = Llama(
            model_path=model_path,
            chat_format="chatml",
            verbose=False,
        )
        self.sys_prompt_en = system_prompt_en
        self.sys_prompt_zh = system_prompt_zh
|
    def to_message(self, prompt, src_lang, dst_lang):
        """Build the chat messages, selecting the system prompt by source language."""
        # dst_lang is accepted to mirror translate()'s signature but is unused here.
        return [
            {"role": "system", "content": self.sys_prompt_en if src_lang == "en" else self.sys_prompt_zh},
            {"role": "user", "content": prompt},
        ]
|
    @lru_cache(maxsize=10)
    def translate(self, prompt, src_lang, dst_lang) -> str:
        # Cache is keyed on (self, prompt, src_lang, dst_lang); note that
        # lru_cache on a method keeps the instance alive while entries remain.
        message = self.to_message(prompt, src_lang, dst_lang)
        output = self.llm.create_chat_completion(messages=message, temperature=0)
        return output["choices"][0]["message"]["content"]
|
    def __call__(self, prompt, *args, **kwargs):
        # Pass through to the raw (non-chat) completion interface of the model.
        return self.llm(prompt, *args, **kwargs)
|
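
# Usage sketch, not part of the module: assumes a local Qwen GGUF file exists
# at the path below; the model path and both prompts are placeholders.
if __name__ == "__main__":
    translator = QwenTranslator(
        model_path="models/qwen-chat.gguf",  # hypothetical path
        system_prompt_en="Translate the user's English text into Chinese.",
        system_prompt_zh="Translate the user's Chinese text into English.",
    )
    # Chat-style translation; deterministic at temperature=0 and cached.
    print(translator.translate("Hello, world!", src_lang="en", dst_lang="zh"))
    # Raw completion passthrough via __call__.
    print(translator("Q: What is 2 + 2?\nA:", max_tokens=16))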