# saiga-api-cuda-v2 / llm_backend.py
from llama_cpp import Llama
import gc
import threading


class LlmBackend:
    # English gloss of SYSTEM_PROMPT: "You are a Russian-language automatic assistant.
    # You answer the user's requests as accurately as possible, using Russian."
    SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."

    # Special token ids used by the Saiga prompt format.
    SYSTEM_TOKEN = 1788
    USER_TOKEN = 1404
    BOT_TOKEN = 9225
    LINEBREAK_TOKEN = 13

    ROLE_TOKENS = {
        "user": USER_TOKEN,
        "bot": BOT_TOKEN,
        "system": SYSTEM_TOKEN
    }

    _instance = None
    _model = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(LlmBackend, cls).__new__(cls)
        return cls._instance

    def is_model_loaded(self):
        return self._model is not None

    def load_model(self, model_path, context_size=2000, enable_gpu=True, gpu_layer_number=35, n_gqa=8, chat_format='llama-2'):
        if self._model is not None:
            self.unload_model()

        with self._lock:
            model_kwargs = dict(
                model_path=model_path,
                chat_format=chat_format,
                n_ctx=context_size,
                n_parts=1,
                # n_batch=100,
                logits_all=True,
                # n_threads=12,
                verbose=True,
                n_gqa=n_gqa  # must be set for 70B models
            )
            if enable_gpu:
                model_kwargs['n_gpu_layers'] = gpu_layer_number
            self._model = Llama(**model_kwargs)
            return self._model

    def set_system_prompt(self, prompt):
        with self._lock:
            self.SYSTEM_PROMPT = prompt

    def unload_model(self):
        with self._lock:
            if self._model is not None:
                # Drop the reference and let llama_cpp's destructor free the native model.
                self._model = None
                gc.collect()

    def generate_tokens(self, generator):
        print('generate_tokens called')
        with self._lock:
            print('generate_tokens started')
            try:
                for token in generator:
                    if token == self._model.token_eos():
                        print('End generating')
                        yield b''  # End of chunk
                        break
                    token_str = self._model.detokenize([token])  # raw bytes; callers can .decode("utf-8", errors="ignore")
                    yield token_str
            except Exception as e:
                print('generator exception')
                print(e)
                yield b''  # End of chunk

    def create_chat_completion(self, messages, stream=True):
        print('create_chat_completion called')
        with self._lock:
            print('create_chat_completion started')
            try:
                return self._model.create_chat_completion(messages=messages, stream=stream)
            except Exception as e:
                print('create_chat_completion exception')
                print(e)
                return None

    def get_message_tokens(self, role, content):
        # Saiga message layout: <s> <role_token> \n <content tokens> </s>
        message_tokens = self._model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, self.ROLE_TOKENS[role])
        message_tokens.insert(2, self.LINEBREAK_TOKEN)
        message_tokens.append(self._model.token_eos())
        return message_tokens

    def get_system_tokens(self):
        return self.get_message_tokens(role="system", content=self.SYSTEM_PROMPT)

    def create_chat_generator_for_saiga(self, messages, parameters):
        print('create_chat_generator_for_saiga called')
        with self._lock:
            tokens = self.get_system_tokens()
            for message in messages:
                message_tokens = self.get_message_tokens(role=message.get("from"), content=message.get("content", ""))
                tokens.extend(message_tokens)

            # Open the bot turn so the model generates the assistant reply.
            tokens.extend([self._model.token_bos(), self.BOT_TOKEN, self.LINEBREAK_TOKEN])

            generator = self._model.generate(
                tokens,
                top_k=parameters['top_k'],
                top_p=parameters['top_p'],
                temp=parameters['temperature'],
                repeat_penalty=parameters['repetition_penalty']
            )
            return generator
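

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original API surface: it assumes a local
    # model file at a hypothetical path and placeholder sampling parameters.
    backend = LlmBackend()
    backend.load_model(
        model_path="models/saiga-q4.bin",  # hypothetical path, adjust to your checkpoint
        context_size=2000,
        enable_gpu=True,
        gpu_layer_number=35
    )

    messages = [{"from": "user", "content": "Привет! Кто ты?"}]  # "Hi! Who are you?"
    parameters = {
        "top_k": 30,  # placeholder sampling values
        "top_p": 0.9,
        "temperature": 0.2,
        "repetition_penalty": 1.1
    }

    generator = backend.create_chat_generator_for_saiga(messages, parameters)
    for chunk in backend.generate_tokens(generator):
        if not chunk:  # b'' marks the end of the stream
            break
        print(chunk.decode("utf-8", errors="ignore"), end="", flush=True)

    backend.unload_model()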