|
import re |
|
from functools import partial |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from modules import RoPE, shared |
|
from modules.callbacks import Iteratorize |
|
from modules.logging_colors import logger |
|
from modules.text_generation import get_max_prompt_length |
|
|
|
# Both llama.cpp bindings are optional: whichever one fails to import is
# recorded as None so llama_cpp_lib() can fall back to the other.
# `except Exception` (rather than a bare `except`) keeps that best-effort
# behaviour for import/loader failures while still letting
# KeyboardInterrupt and SystemExit propagate.
try:
    import llama_cpp
except Exception:
    llama_cpp = None

try:
    import llama_cpp_cuda
except Exception:
    llama_cpp_cuda = None
|
|
|
|
|
def llama_cpp_lib():
    """Return the llama.cpp binding module that should be used.

    ``llama_cpp_cuda`` is returned only when it was imported successfully
    and the CPU build was not explicitly selected (``shared.args.cpu`` set
    while ``llama_cpp`` is available). In every other case ``llama_cpp`` is
    returned, which is ``None`` if that import failed.
    """
    cpu_build_requested = shared.args.cpu and llama_cpp is not None
    if llama_cpp_cuda is not None and not cpu_build_requested:
        return llama_cpp_cuda

    return llama_cpp
|
|
|
|
|
def ban_eos_logits_processor(eos_token, input_ids, logits):
    """Logits processor that sets the EOS token's score to negative infinity.

    Args:
        eos_token: Index of the end-of-sequence entry in ``logits``.
        input_ids: Unused; present so the signature matches what the
            logits-processor caller passes.
        logits: Indexable, mutable score container; modified in place.

    Returns:
        The same ``logits`` object that was passed in.
    """
    logits[eos_token] = float('-inf')
    return logits
|
|
|
|
|
def custom_token_ban_logits_processor(token_ids, input_ids, logits):
    """Logits processor that sets every listed token's score to -inf.

    Args:
        token_ids: Iterable of indices into ``logits`` to suppress.
        input_ids: Unused; present so the signature matches what the
            logits-processor caller passes.
        logits: Indexable, mutable score container; modified in place.

    Returns:
        The same ``logits`` object that was passed in.
    """
    negative_infinity = float('-inf')
    for banned_id in token_ids:
        logits[banned_id] = negative_infinity

    return logits
|
|
|
|
|
class LlamaCppModel:
    """Wrapper around a llama.cpp ``Llama`` object.

    Exposes tokenisation (``encode``/``decode``), raw logits, optional
    grammar loading and text generation. Instances are normally built with
    ``from_pretrained``, which attaches the underlying model as
    ``self.model``.
    """

    def __init__(self):
        self.initialized = False
        # Source text of the currently loaded grammar; used by load_grammar()
        # to avoid re-parsing an unchanged grammar.
        self.grammar_string = ''
        self.grammar = None

    def __del__(self):
        # `model` is only assigned by from_pretrained(); if loading failed
        # (or the instance was created directly) there is nothing to release
        # and touching self.model would raise AttributeError.
        model = getattr(self, 'model', None)
        if model is not None:
            model.__del__()

    @staticmethod
    def _parse_cache_capacity(value):
        """Convert a cache-capacity setting into a number of bytes.

        Args:
            value: ``None``, a plain integer string (bytes), or a string
                containing ``GiB`` / ``MiB``.

        Returns:
            The capacity in bytes; 0 when ``value`` is ``None``.
        """
        if value is None:
            return 0

        if 'GiB' in value:
            multiplier = 1024 ** 3  # GiB is a binary unit (2**30 bytes)
        elif 'MiB' in value:
            multiplier = 1024 ** 2  # MiB is a binary unit (2**20 bytes)
        else:
            # No recognised unit: the whole value is taken as a byte count.
            return int(value)

        return int(re.sub(r'[a-zA-Z]', '', value)) * multiplier

    @classmethod
    def from_pretrained(cls, path):
        """Load a llama.cpp model from ``path`` using ``shared.args``.

        Args:
            path: Location of the model file; converted with ``str()``.

        Returns:
            A tuple ``(result, result)`` containing the same wrapper twice.
            NOTE(review): presumably so callers can unpack it as
            ``(model, tokenizer)`` -- confirm against the loader.
        """
        lib = llama_cpp_lib()
        Llama = lib.Llama
        LlamaCache = lib.LlamaCache

        result = cls()
        cache_capacity = cls._parse_cache_capacity(shared.args.cache_capacity)
        logger.info("Cache capacity is " + str(cache_capacity) + " bytes")

        tensor_split = shared.args.tensor_split
        if tensor_split is None or tensor_split.strip() == '':
            tensor_split_list = None
        else:
            tensor_split_list = [float(x) for x in tensor_split.strip().split(",")]

        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
            'seed': int(shared.args.llama_cpp_seed),
            # `or None` turns a falsy thread count (e.g. 0) into None.
            'n_threads': shared.args.threads or None,
            'n_threads_batch': shared.args.threads_batch or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'mul_mat_q': not shared.args.no_mul_mat_q,
            'numa': shared.args.numa,
            'n_gpu_layers': shared.args.n_gpu_layers,
            'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base),
            'tensor_split': tensor_split_list,
            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
        }

        result.model = Llama(**params)
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        return result, result

    def encode(self, string):
        """Tokenise ``string`` with the underlying model.

        ``str`` input is encoded to bytes first; anything else is passed to
        ``self.model.tokenize`` unchanged.
        """
        if isinstance(string, str):
            string = string.encode()

        return self.model.tokenize(string)

    def decode(self, ids, **kwargs):
        """Detokenise ``ids`` and decode the resulting bytes as UTF-8.

        Extra keyword arguments are accepted and ignored.

        Raises:
            UnicodeDecodeError: If the detokenised bytes are not valid UTF-8.
        """
        return self.model.detokenize(ids).decode('utf-8')

    def get_logits(self, tokens):
        """Evaluate ``tokens`` and return the model scores as a tensor.

        The scores are read from ``self.model._scores`` after evaluation,
        given a new leading axis, and converted to a ``float32`` tensor.
        """
        self.model.eval(tokens)
        logits = self.model._scores
        logits = np.expand_dims(logits, 0)
        return torch.tensor(logits, dtype=torch.float32)

    def load_grammar(self, string):
        """Update ``self.grammar`` when the grammar text has changed.

        A blank (whitespace-only) string clears the grammar; an unchanged
        string is a no-op so the grammar is not rebuilt on every call.
        """
        if string == self.grammar_string:
            return

        self.grammar_string = string
        if string.strip() != '':
            self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string)
        else:
            self.grammar = None

    def generate(self, prompt, state, callback=None):
        """Generate a completion for ``prompt``.

        Args:
            prompt: Prompt as ``str``, or an object with ``.decode()``
                (e.g. ``bytes``).
            state: Mapping of generation settings; the keys read here are
                ``grammar_string``, ``ban_eos_token``, ``custom_token_bans``
                and the sampling parameters passed to ``create_completion``.
            callback: Optional callable invoked with each streamed text chunk.

        Returns:
            The concatenation of all streamed text chunks.
        """
        LogitsProcessorList = llama_cpp_lib().LogitsProcessorList

        prompt = prompt if isinstance(prompt, str) else prompt.decode()

        # Truncate on token boundaries: tokenise, keep the last
        # get_max_prompt_length(state) tokens, then turn them back into text.
        prompt = self.encode(prompt)
        prompt = prompt[-get_max_prompt_length(state):]
        prompt = self.decode(prompt)

        self.load_grammar(state['grammar_string'])
        logit_processors = LogitsProcessorList()
        if state['ban_eos_token']:
            logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))

        if state['custom_token_bans']:
            # Skip empty entries so input such as "1, 2," does not crash on
            # int('').
            to_ban = [int(x) for x in state['custom_token_bans'].split(',') if x.strip()]
            if to_ban:
                logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))

        completion_chunks = self.model.create_completion(
            prompt=prompt,
            max_tokens=state['max_new_tokens'],
            temperature=state['temperature'],
            top_p=state['top_p'],
            top_k=state['top_k'],
            repeat_penalty=state['repetition_penalty'],
            presence_penalty=state['presence_penalty'],
            frequency_penalty=state['frequency_penalty'],
            tfs_z=state['tfs'],
            mirostat_mode=int(state['mirostat_mode']),
            mirostat_tau=state['mirostat_tau'],
            mirostat_eta=state['mirostat_eta'],
            stream=True,
            logits_processor=logit_processors,
            grammar=self.grammar,
        )

        output = ""
        for completion_chunk in completion_chunks:
            # Global stop flag: abandon the stream as soon as it is set.
            if shared.stop_everything:
                break

            text = completion_chunk['choices'][0]['text']
            output += text
            if callback:
                callback(text)

        return output

    def generate_with_streaming(self, *args, **kwargs):
        """Yield the cumulative reply as ``generate`` produces new text.

        ``generate`` is run through ``Iteratorize``; each item it yields is
        appended to the reply, and the reply so far is yielded.
        """
        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
|
|