# christian-compass/src/llama_cpp_chat_engine.py
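"""Streaming chat engine on top of llama-cpp-python.

Loads a GGUF model, renders conversations with the chat template embedded in
the model file, and streams completions constrained to the context window.
"""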
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Jinja2ChatFormatter

from src.const import SYSTEM_PROMPT


class LlamaCPPChatEngine:
    def __init__(self, model_path):
        # n_ctx=0 lets llama.cpp take the context length from the GGUF metadata
        self._model = Llama(
            model_path=model_path,
            n_ctx=0,
            verbose=False,
        )
        self.n_ctx = self._model.context_params.n_ctx
        eos_token_id = int(self._model.metadata['tokenizer.ggml.eos_token_id'])
        self._eos_token = self._model._model.token_get_text(eos_token_id)
        # Build a formatter from the chat template shipped inside the model file
        self._formatter = Jinja2ChatFormatter(
            template=self._model.metadata['tokenizer.chat_template'],
            bos_token=self._model._model.token_get_text(
                int(self._model.metadata['tokenizer.ggml.bos_token_id'])
            ),
            eos_token=self._eos_token,
            # stop_token_ids expects a list of ints, not the raw metadata string
            stop_token_ids=[eos_token_id],
        )
        self._tokenizer = self._model.tokenizer()

    def chat(self, messages, user_message, context):
        # Prepend retrieved context passages to the user's question, if any
        if context:
            user_message_extended = "\n".join(context + [f"Question: {user_message}"])
        else:
            user_message_extended = user_message
        messages = (
            [{"role": "system", "content": SYSTEM_PROMPT}]
            + messages
            + [{"role": "user", "content": user_message_extended}]
        )
        # Render the conversation with the model's own chat template; the
        # rendered prompt already contains the BOS token, so don't add another
        prompt = self._formatter(messages=messages).prompt
        tokens = self._tokenizer.encode(prompt, add_bos=False)
        n_tokens = len(tokens)
        # Stream the completion, capping generation at whatever fits in the
        # remaining context window
        response_generator = self._model.create_completion(
            tokens,
            stop=self._eos_token,
            max_tokens=self.n_ctx - n_tokens,
            stream=True,
            temperature=0,
        )
        return response_generator, n_tokens
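

# Minimal usage sketch (illustrative, not part of the original module).
# The GGUF path below is a placeholder; substitute any chat model that
# llama.cpp can load.
if __name__ == "__main__":
    engine = LlamaCPPChatEngine("models/model.gguf")
    generator, n_prompt_tokens = engine.chat(
        messages=[],
        user_message="What does this project do?",
        context=[],
    )
    # With stream=True, create_completion yields chunks whose generated
    # text sits at choices[0]["text"]
    for chunk in generator:
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()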