|
from functools import partial |
|
|
|
from langchain.llms.base import LLM |
|
from langchain.callbacks.manager import CallbackManagerForLLMRun |
|
from typing import Any, Dict, List, Optional |
|
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig |
|
from exllama.tokenizer import ExLlamaTokenizer |
|
from exllama.generator import ExLlamaGenerator |
|
from exllama.lora import ExLlamaLora |
|
import os, glob |
|
|
|
from pydantic.v1 import root_validator |
|
|
|
# U+FFFD (REPLACEMENT CHARACTER). Exllama.stream compares each decoded chunk
# with this value and does not forward a matching chunk to the token callback.
BROKEN_UNICODE = "\ufffd"
|
|
|
|
|
class H2OExLlamaTokenizer(ExLlamaTokenizer):
    """ExLlama tokenizer that can also be invoked directly on a text."""

    def __call__(self, text, *args, **kwargs):
        """Encode *text* and return the result under the ``input_ids`` key.

        Any extra positional or keyword arguments are accepted and ignored.
        """
        token_ids = self.encode(text)
        return {"input_ids": token_ids}
|
|
|
|
|
class H2OExLlamaGenerator(ExLlamaGenerator):
    """ExLlama generator that identifies itself through ``is_exlama``."""

    def is_exlama(self):
        """Return ``True`` unconditionally."""
        return True
|
|
|
|
|
class Exllama(LLM):
    """LangChain ``LLM`` backed by an ExLlama model.

    The root validator either loads a model from ``model_path`` or adopts an
    already-built generator passed as ``model``, and stores the resulting
    objects in ``client``, ``generator``, ``config``, ``tokenizer`` and
    ``exllama_cache``.
    """

    # Set by the root validator to the underlying ExLlama model.
    client: Any
    # The path to the GPTQ model folder.
    model_path: str = None
    # Optional ready-made generator; when given, nothing is loaded from model_path.
    model: Any = None
    # Forwarded to prompter.get_response() while streaming.
    sanitize_bot_response: bool = False
    # Builds the final prompt in _call() and extracts the response in stream().
    prompter: Any = None
    # Passed as "context" / "input" of the data point handed to the prompter.
    context: Any = ''
    iinput: Any = ''
    # Both are forwarded unchanged to prompter.generate_prompt().
    chat_conversation: Any = []
    user_prompt_for_fake_system_prompt: Any = None

    # Filled in by the root validator.
    exllama_cache: ExLlamaCache = None
    config: ExLlamaConfig = None
    generator: ExLlamaGenerator = None
    tokenizer: ExLlamaTokenizer = None

    # Logging callable; the root validator swaps in a no-op when verbose is falsy.
    logfunc = print
    # The root validator strips and lower-cases every entry.
    stop_sequences: Optional[List[str]] = []
    streaming: Optional[bool] = True

    # Passed to generator.disallow_tokens() by the root validator when non-empty.
    disallowed_tokens: Optional[List[int]] = None

    # Sampling settings: copied onto generator.settings by the root validator
    # when set.
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    min_p: Optional[float] = None
    typical: Optional[float] = None
    token_repetition_penalty_max: Optional[float] = None
    token_repetition_penalty_sustain: Optional[int] = None
    token_repetition_penalty_decay: Optional[int] = None
    # stream() uses beam search only when both values are >= 1.
    beams: Optional[int] = None
    beam_length: Optional[int] = None

    # Config overrides: copied onto the ExLlamaConfig when a model is loaded
    # from model_path. max_seq_len also bounds the generation loop in stream().
    max_seq_len: Optional[int] = 2048
    compress_pos_emb: Optional[float] = 1.0
    # Passed to config.set_auto_map() when set.
    set_auto_map: Optional[str] = None
    gpu_peer_fix: Optional[bool] = None
    alpha_value: Optional[float] = 1.0

    # Tuning overrides: also copied onto the ExLlamaConfig when a model is
    # loaded from model_path.
    matmul_recons_thd: Optional[int] = None
    fused_mlp_thd: Optional[int] = None
    sdp_thd: Optional[int] = None
    fused_attn: Optional[bool] = None
    matmul_fused_remap: Optional[bool] = None
    rmsnorm_no_half2: Optional[bool] = None
    rope_no_half2: Optional[bool] = None
    matmul_no_half2: Optional[bool] = None
    silu_no_half2: Optional[bool] = None
    concurrent_streams: Optional[bool] = None

    # Folder holding adapter_config.json and the LoRA weights; None disables LoRA.
    lora_path: Optional[str] = None
|
|
|
@staticmethod
def get_model_path_at(path):
    """Return one weight file located directly inside *path*, or ``None``.

    Extensions are tried in priority order (``.safetensors``, ``.bin``,
    ``.pt``); the first extension with any match wins and the first path
    reported by ``glob`` for it is returned.
    """
    for extension_glob in ("*.safetensors", "*.bin", "*.pt"):
        matches = glob.glob(os.path.join(path, extension_glob))
        if matches:
            return matches[0]
    return None
|
|
|
@staticmethod
def configure_object(params, values, logfunc):
    """Build a function that copies selected settings onto a target object.

    Args:
        params: Names of the settings to copy.
        values: Mapping the settings are read from; a missing name counts
            as unset.
        logfunc: Called with ``"<name> <value>"`` for every setting applied.

    Returns:
        A one-argument callable. Given an object, it sets every selected
        setting whose value is not ``None`` as an attribute of that object.

    Raises:
        AttributeError: Raised by the returned callable when the object has
            no attribute named after a setting that is set.
    """
    # Read the values once, up front; the closure works from this snapshot.
    selected = {name: values.get(name) for name in params}

    def apply_to(obj):
        for key, value in selected.items():
            # Only None means "unset". 0, 0.0 and False are real choices
            # (e.g. switching a boolean flag off) and must be applied.
            if value is None:
                continue
            if not hasattr(obj, key):
                raise AttributeError(f"{key} does not exist in {obj}")
            setattr(obj, key, value)
            logfunc(f"{key} {value}")

    return apply_to
|
|
|
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
    """Build (or adopt) the ExLlama objects and apply the user's settings.

    When ``values['model']`` is ``None`` the config, tokenizer, model, cache,
    generator and optional LoRA are created from ``model_path`` /
    ``lora_path``; config and tuning overrides are applied to the config
    before the model is created. Otherwise ``values['model']`` is used as
    the generator, its ``cache``, ``model``, ``model.config`` and
    ``tokenizer`` are reused, and config/tuning overrides are not applied.

    In both cases the sampling settings, stop sequences and disallowed
    tokens are then applied to the generator.

    Args:
        values: Field values collected for the instance.

    Returns:
        ``values`` with ``client``, ``generator``, ``config``, ``tokenizer``
        and ``exllama_cache`` filled in and ``stop_sequences`` stripped and
        lower-cased.

    Raises:
        ValueError: If a model has to be loaded but ``model_path`` is unset,
            or no weight file is found in ``model_path`` / ``lora_path``.
    """
    # Settings copied onto generator.settings.
    model_param_names = [
        "temperature",
        "top_k",
        "top_p",
        "min_p",
        "typical",
        "token_repetition_penalty_max",
        "token_repetition_penalty_sustain",
        "token_repetition_penalty_decay",
        "beams",
        "beam_length",
    ]

    # Settings copied onto the ExLlamaConfig.
    config_param_names = [
        "max_seq_len",
        "compress_pos_emb",
        "gpu_peer_fix",
        "alpha_value",
    ]

    # Tuning switches, also copied onto the ExLlamaConfig.
    tuning_parameters = [
        "matmul_recons_thd",
        "fused_mlp_thd",
        "sdp_thd",
        "matmul_fused_remap",
        "rmsnorm_no_half2",
        "rope_no_half2",
        "matmul_no_half2",
        "silu_no_half2",
        "concurrent_streams",
        "fused_attn",
    ]

    # Silence all logging below unless verbose was requested.
    if not values['verbose']:
        values['logfunc'] = lambda *args, **kwargs: None
    logfunc = values['logfunc']

    if values['model'] is None:
        model_dir = values["model_path"]
        if model_dir is None:
            raise ValueError("model_path must be set when no model is supplied")
        lora_dir = values["lora_path"]

        tokenizer_path = os.path.join(model_dir, "tokenizer.model")
        model_config_path = os.path.join(model_dir, "config.json")
        weights_path = Exllama.get_model_path_at(model_dir)
        if weights_path is None:
            # Fail here with a clear message instead of deep inside ExLlama.
            raise ValueError(
                f"No *.safetensors, *.bin or *.pt file found in {model_dir}"
            )

        config = ExLlamaConfig(model_config_path)
        tokenizer = ExLlamaTokenizer(tokenizer_path)
        config.model_path = weights_path

        # Overrides must be on the config before the model is constructed.
        Exllama.configure_object(config_param_names, values, logfunc)(config)
        Exllama.configure_object(tuning_parameters, values, logfunc)(config)

        if values['set_auto_map']:
            config.set_auto_map(values['set_auto_map'])
            logfunc(f"set_auto_map {values['set_auto_map']}")

        model = ExLlama(config)
        exllama_cache = ExLlamaCache(model)
        generator = ExLlamaGenerator(model, tokenizer, exllama_cache)

        if lora_dir is not None:
            lora_config_path = os.path.join(lora_dir, "adapter_config.json")
            lora_weights = Exllama.get_model_path_at(lora_dir)
            if lora_weights is None:
                raise ValueError(
                    f"No *.safetensors, *.bin or *.pt file found in {lora_dir}"
                )
            generator.lora = ExLlamaLora(model, lora_config_path, lora_weights)
            logfunc(f"Loaded LORA @ {lora_weights}")
    else:
        # Reuse everything hanging off the supplied generator.
        generator = values['model']
        exllama_cache = generator.cache
        model = generator.model
        config = model.config
        tokenizer = generator.tokenizer

    # A None value is treated the same as "no stop sequences".
    values["stop_sequences"] = [
        entry.strip().lower() for entry in (values.get("stop_sequences") or [])
    ]
    Exllama.configure_object(model_param_names, values, logfunc)(generator.settings)

    generator.settings.stop_sequences = values["stop_sequences"]
    logfunc(f"stop_sequences {values['stop_sequences']}")

    disallowed = values.get("disallowed_tokens")
    if disallowed:
        generator.disallow_tokens(disallowed)
        # Routed through logfunc like every other message in this validator.
        logfunc(f"Disallowed Tokens: {generator.disallowed_tokens}")

    values["client"] = model
    values["generator"] = generator
    values["config"] = config
    values["tokenizer"] = tokenizer
    values["exllama_cache"] = exllama_cache

    return values
|
|
|
@property
def _llm_type(self) -> str:
    """Name identifying this LLM implementation: always ``"Exllama"``."""
    llm_type = "Exllama"
    return llm_type
|
|
|
def get_num_tokens(self, text: str) -> int:
    """Return the token count the generator's tokenizer reports for *text*."""
    tokenizer = self.generator.tokenizer
    return tokenizer.num_tokens(text)
|
|
|
def get_token_ids(self, text: str) -> List[int]:
    """Return *text* encoded by the generator's tokenizer."""
    tokenizer = self.generator.tokenizer
    return tokenizer.encode(text)
|
|
|
|
|
|
|
def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
) -> str:
    """Run one full generation for *prompt* and return the final text.

    The prompt is first passed through
    ``H2OTextGenerationPipeline.limit_prompt``, then wrapped by the prompter,
    and finally streamed; the last value produced by ``stream`` is returned
    (an empty string if it produces nothing).

    Args:
        prompt: The user instruction.
        stop: Forwarded to ``stream``.
        run_manager: Forwarded to ``stream``.
        **kwargs: Accepted and ignored.
    """
    assert self.tokenizer is not None
    # Imported at call time rather than at module import.
    from h2oai_pipeline import H2OTextGenerationPipeline

    prompt, _num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)

    data_point = {
        "context": self.context,
        "instruction": prompt,
        "input": self.iinput,
    }
    full_prompt = self.prompter.generate_prompt(
        data_point,
        chat_conversation=self.chat_conversation,
        user_prompt_for_fake_system_prompt=self.user_prompt_for_fake_system_prompt,
    )

    # Drain the stream; the loop variable keeps the last value produced.
    latest = ''
    for latest in self.stream(prompt=full_prompt, stop=stop, run_manager=run_manager):
        pass
    return latest
|
|
|
from enum import Enum

class MatchStatus(Enum):
    """Result of comparing generated text with the stop sequences."""

    # The text equals a stop sequence.
    EXACT_MATCH = 1
    # The text is a prefix of a stop sequence.
    PARTIAL_MATCH = 0
    # Neither of the above.
    NO_MATCH = 2
|
|
|
def match_status(self, sequence: str, banned_sequences: List[str]):
    """Classify *sequence* against the banned (stop) sequences.

    *sequence* is stripped and lower-cased before comparison; the banned
    sequences are compared as given.

    Args:
        sequence: Text to test.
        banned_sequences: Stop sequences to test against.

    Returns:
        ``MatchStatus.EXACT_MATCH`` if the normalised text equals any banned
        sequence, else ``MatchStatus.PARTIAL_MATCH`` if any banned sequence
        starts with it, else ``MatchStatus.NO_MATCH``.
    """
    normalized = sequence.strip().lower()
    prefix_hit = False
    for banned_seq in banned_sequences:
        if banned_seq == normalized:
            return self.MatchStatus.EXACT_MATCH
        # Keep scanning after a prefix hit: an exact match later in the list
        # must win regardless of the order of the stop sequences.
        if banned_seq.startswith(normalized):
            prefix_hit = True
    if prefix_hit:
        return self.MatchStatus.PARTIAL_MATCH
    return self.MatchStatus.NO_MATCH
|
|
|
def stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
    """Generate a response to *prompt*, yielding the response text so far.

    This is a generator function (despite the ``str`` annotation). Each
    yielded value is the accumulated response as returned by
    ``prompter.get_response``, not just the newest chunk.

    Generation ends when the EOS token is produced, when the accumulated
    text is an exact match for one of ``self.stop_sequences``, or when the
    generator's token count exceeds ``self.max_seq_len - 4``.

    Args:
        prompt: Prompt text to encode and feed to the generator.
        stop: Not used by this method; stop sequences are taken from
            ``self.stop_sequences``.
        run_manager: If given, its ``on_llm_new_token`` receives each newly
            decoded chunk that is yielded.

    Yields:
        The accumulated response text after each token that is neither an
        exact nor a partial stop-sequence match.
    """
    generator = self.generator
    beam_search = (self.beams and self.beams >= 1 and self.beam_length and self.beam_length >= 1)

    ids = generator.tokenizer.encode(prompt)
    generator.gen_begin_reuse(ids)

    if beam_search:
        generator.begin_beam_search()
        token_getter = generator.beam_search
    else:
        generator.end_beam_search()
        token_getter = generator.gen_single_token

    # Text is always decoded from the token after the last newline onwards;
    # cursor_head is how many characters of that decoded text have already
    # been consumed. It starts at the decoded length of the whole prompt.
    last_newline_pos = 0
    cursor_head = len(generator.tokenizer.decode(generator.sequence_actual[0]))

    text_callback = None
    if run_manager:
        text_callback = partial(
            run_manager.on_llm_new_token, verbose=self.verbose
        )

    # Loop-invariant token budget.
    token_budget = self.max_seq_len - 4

    text = ""
    while generator.gen_num_tokens() <= token_budget:
        token = token_getter()

        # EOS: replace it with a newline token and stop.
        if token.item() == generator.tokenizer.eos_token_id:
            generator.replace_last_token(generator.tokenizer.newline_token_id)
            if beam_search:
                generator.end_beam_search()
            return

        # Re-decode everything since the last newline and slice off the part
        # that is new since the previous iteration.
        stuff = generator.tokenizer.decode(generator.sequence_actual[0][last_newline_pos:])
        cursor_tail = len(stuff)
        has_unicode_combined = cursor_tail < cursor_head
        text_chunk = stuff[cursor_head:cursor_tail]
        if has_unicode_combined:
            # The decoded text got shorter: drop the last two characters
            # already emitted and take the final decoded character instead.
            # NOTE(review): presumably earlier placeholder characters merged
            # into one real character -- confirm against the tokenizer.
            text = text[:-2]
            text_chunk = stuff[cursor_tail - 1:cursor_tail]
        cursor_head = cursor_tail

        text += text_chunk
        text = self.prompter.get_response(prompt + text, prompt=prompt,
                                          sanitize_bot_response=self.sanitize_bot_response)

        # After a newline token, decoding restarts from the next token.
        if token.item() == generator.tokenizer.newline_token_id:
            last_newline_pos = len(generator.sequence_actual[0])
            cursor_head = 0

        status = self.match_status(text, self.stop_sequences)

        if status == self.MatchStatus.EXACT_MATCH:
            # Rewind the generator by the token length of the matched text.
            rewind_length = generator.tokenizer.encode(text).shape[-1]
            generator.gen_rewind(rewind_length)
            if beam_search:
                generator.end_beam_search()
            return
        elif status == self.MatchStatus.PARTIAL_MATCH:
            # Might still become a stop sequence: keep generating, yield nothing.
            continue
        elif status == self.MatchStatus.NO_MATCH:
            if text_callback and not (text_chunk == BROKEN_UNICODE):
                text_callback(text_chunk)
            yield text

    # Token budget exhausted: leave beam search like the other exits do.
    if beam_search:
        generator.end_beam_search()
|
|