from llama_cpp import Llama
from llama_cpp.llama_chat_format import Jinja2ChatFormatter

from src.const import SYSTEM_PROMPT


class LlamaCPPChatEngine:
    """Thin chat wrapper around a GGUF model loaded with llama-cpp-python."""

    def __init__(self, model_path):
        self._model = Llama(
            model_path=model_path,
            n_ctx=0,  # 0 = use the context length stored in the GGUF metadata
            verbose=False,
        )
        self.n_ctx = self._model.context_params.n_ctx
        # The GGUF metadata stores token ids as strings; resolve the actual
        # token text through the low-level model handle.
        eos_token_id = int(self._model.metadata['tokenizer.ggml.eos_token_id'])
        bos_token_id = int(self._model.metadata['tokenizer.ggml.bos_token_id'])
        self._eos_token = self._model._model.token_get_text(eos_token_id)
        # Render prompts with the chat template shipped inside the GGUF file.
        self._formatter = Jinja2ChatFormatter(
            template=self._model.metadata['tokenizer.chat_template'],
            bos_token=self._model._model.token_get_text(bos_token_id),
            eos_token=self._eos_token,
            stop_token_ids=[eos_token_id],  # expects a list of ints, not the raw metadata string
        )
        self._tokenizer = self._model.tokenizer()

    def chat(self, messages, user_message, context):
        # Prepend retrieved context passages, if any, to the user question.
        if context:
            user_message_extended = "\n".join(context + [f"Question: {user_message}"])
        else:
            user_message_extended = user_message
        messages = (
            [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT,
                }
            ] + messages + [
                {
                    "role": "user",
                    "content": user_message_extended,
                }
            ]
        )
        # Render the conversation through the chat template, then tokenize.
        # add_bos=False: the chat template is responsible for the BOS token.
        prompt = self._formatter(messages=messages).prompt
        tokens = self._tokenizer.encode(prompt, add_bos=False)
        n_tokens = len(tokens)
        # Stream a greedy (temperature=0) completion, capped so that prompt
        # plus generation never exceeds the model's context window.
        response_generator = self._model.create_completion(
            tokens,
            stop=self._eos_token,
            max_tokens=self.n_ctx - n_tokens,
            stream=True,
            temperature=0,
        )
        return response_generator, n_tokens
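

# A minimal usage sketch, not part of the original module: the model path and
# the retrieved-context passage below are illustrative placeholders only.
if __name__ == "__main__":
    engine = LlamaCPPChatEngine("models/model.gguf")
    generator, n_prompt_tokens = engine.chat(
        messages=[],  # no prior conversation turns
        user_message="What does this project do?",
        context=["Context: a hypothetical retrieved passage."],
    )
    # create_completion(stream=True) yields chunk dicts; print deltas as they arrive.
    for chunk in generator:
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()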