from functools import partial
from enum import Enum

from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Any, Dict, Iterator, List, Optional
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator
from exllama.lora import ExLlamaLora
import os
import glob

from pydantic.v1 import root_validator

# U+FFFD REPLACEMENT CHARACTER. Exllama.stream() does not forward a chunk equal to this
# to the LangChain callback; presumably the decoder produces it while a multi-token
# character is still incomplete -- TODO confirm.
BROKEN_UNICODE = b'\\ufffd'.decode('unicode_escape')


class H2OExLlamaTokenizer(ExLlamaTokenizer):
    """ExLlamaTokenizer that can be called directly, returning ``{"input_ids": ...}``.

    Presumably lets it be used where a Hugging Face-style tokenizer call is expected.
    """

    def __call__(self, text, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        return dict(input_ids=self.encode(text))


class H2OExLlamaGenerator(ExLlamaGenerator):
    """ExLlamaGenerator carrying an ``is_exlama`` marker method."""

    def is_exlama(self):
        # Presumably used by callers to detect an exllama-backed generator.
        return True


class Exllama(LLM):
    """LangChain LLM backed by an exllama generator, with h2oGPT prompt handling.

    Either pass a ready ``H2OExLlamaGenerator``/``ExLlamaGenerator`` as ``model``, or
    leave ``model`` as None and give ``model_path`` so ``validate_environment`` builds
    the config, tokenizer, model, cache and generator itself.
    """

    client: Any  #: :meta private:
    # Folder holding the GPTQ weights (*.safetensors/*.bin/*.pt), config.json and
    # tokenizer.model; only used when ``model`` is None.
    model_path: str = None
    # Pre-built generator; when set, model_path/lora/config/tuning options are not used.
    model: Any = None
    sanitize_bot_response: bool = False
    # h2oGPT prompter: must provide generate_prompt(...) and get_response(...).
    prompter: Any = None
    context: Any = ''
    iinput: Any = ''
    chat_conversation: Any = []
    user_prompt_for_fake_system_prompt: Any = None
    exllama_cache: ExLlamaCache = None  #: :meta private:
    config: ExLlamaConfig = None  #: :meta private:
    generator: ExLlamaGenerator = None  #: :meta private:
    tokenizer: ExLlamaTokenizer = None  #: :meta private:

    ## Langchain parameters
    # Log sink; validate_environment replaces it with a no-op when verbose is False.
    logfunc = print
    # Sequences that immediately stop the generator (normalized to stripped lower-case).
    stop_sequences: Optional[List[str]] = []
    # Whether to stream the results, token by token.
    streaming: Optional[bool] = True

    ## Generator parameters (applied to generator.settings; None/0/False are not applied)
    # List of tokens to disallow during generation.
    disallowed_tokens: Optional[List[int]] = None
    # Temperature for sampling diversity.
    temperature: Optional[float] = None
    # Consider the most probable top_k samples, 0 to disable top_k sampling.
    top_k: Optional[int] = None
    # Consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling.
    top_p: Optional[float] = None
    # Do not consider tokens with probability less than this.
    min_p: Optional[float] = None
    # Locally typical sampling threshold, 0.0 to disable typical sampling.
    typical: Optional[float] = None
    # Repetition penalty for most recent tokens.
    token_repetition_penalty_max: Optional[float] = None
    # No. most recent tokens to repeat penalty for, -1 to apply to whole context.
    token_repetition_penalty_sustain: Optional[int] = None
    # Gradually decrease penalty over this many tokens.
    token_repetition_penalty_decay: Optional[int] = None
    # Number of beams for beam search.
    beams: Optional[int] = None
    # Length of beams for beam search.
    beam_length: Optional[int] = None

    ## Config overrides (only applied when the model is built here)
    # Reduce to save memory. Can also be increased, ideally while also using
    # compress_pos_emb and a compatible model/LoRA. Also bounds generation in stream().
    max_seq_len: Optional[int] = 2048
    # Amount of compression to apply to the positional embedding.
    compress_pos_emb: Optional[float] = 1.0
    # Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7
    set_auto_map: Optional[str] = None
    # Prevent direct copies of data between GPUs.
    gpu_peer_fix: Optional[bool] = None
    # Rope context extension alpha.
    alpha_value: Optional[float] = 1.0

    ## Tuning (copied onto ExLlamaConfig when the model is built here)
    matmul_recons_thd: Optional[int] = None
    fused_mlp_thd: Optional[int] = None
    sdp_thd: Optional[int] = None
    fused_attn: Optional[bool] = None
    matmul_fused_remap: Optional[bool] = None
    rmsnorm_no_half2: Optional[bool] = None
    rope_no_half2: Optional[bool] = None
    matmul_no_half2: Optional[bool] = None
    silu_no_half2: Optional[bool] = None
    concurrent_streams: Optional[bool] = None

    ## Lora Parameters
    # Folder holding adapter_config.json and the LoRA weights file.
    lora_path: Optional[str] = None

    @staticmethod
    def get_model_path_at(path):
        """Return the first weights file found in ``path``, or None if there is none.

        Patterns are tried in order ``*.safetensors``, ``*.bin``, ``*.pt``; the first
        pattern with any match wins, and the first file ``glob`` returns for it is used.
        """
        for pattern in ("*.safetensors", "*.bin", "*.pt"):
            model_paths = glob.glob(os.path.join(path, pattern))
            if model_paths:
                return model_paths[0]
        return None

    @staticmethod
    def configure_object(params, values, logfunc):
        """Build a function that copies selected entries of ``values`` onto an object.

        :param params: attribute names to copy (looked up in ``values`` with ``.get``).
        :param values: the pydantic values dict.
        :param logfunc: called as ``logfunc(f"{key} {value}")`` for each attribute set.
        :return: ``apply_to(obj)``, which sets each attribute whose value is truthy.
            Falsy values (None, 0, 0.0, False, "") are skipped, so e.g. ``top_k=0`` or
            ``fused_attn=False`` are NOT applied through this path.
        :raises AttributeError: from ``apply_to`` if ``obj`` lacks a truthy-valued key.
        """
        obj_params = {k: values.get(k) for k in params}

        def apply_to(obj):
            for key, value in obj_params.items():
                if value:
                    if hasattr(obj, key):
                        setattr(obj, key, value)
                        logfunc(f"{key} {value}")
                    else:
                        raise AttributeError(f"{key} does not exist in {obj}")

        return apply_to

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Build (or adopt) the exllama generator and store its parts in ``values``.

        When ``values['model']`` is None, loads config.json, tokenizer.model and the
        weights from ``model_path``, applies config/tuning overrides and ``set_auto_map``,
        creates model, cache and generator, and optionally attaches a LoRA. Otherwise
        ``values['model']`` is treated as a generator and its cache/model/config/tokenizer
        are reused. In both cases the generation settings, normalized stop sequences and
        disallowed tokens are applied to the generator.
        """
        model_param_names = [
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "typical",
            "token_repetition_penalty_max",
            "token_repetition_penalty_sustain",
            "token_repetition_penalty_decay",
            "beams",
            "beam_length",
        ]

        config_param_names = [
            "max_seq_len",
            "compress_pos_emb",
            "gpu_peer_fix",
            "alpha_value",
        ]

        tuning_parameters = [
            "matmul_recons_thd",
            "fused_mlp_thd",
            "sdp_thd",
            "matmul_fused_remap",
            "rmsnorm_no_half2",
            "rope_no_half2",
            "matmul_no_half2",
            "silu_no_half2",
            "concurrent_streams",
            "fused_attn",
        ]

        # Keep the configured log function when verbose, otherwise silence it.
        verbose = values['verbose']
        if not verbose:
            values['logfunc'] = lambda *args, **kwargs: None
        logfunc = values['logfunc']

        if values['model'] is None:
            model_path = values["model_path"]
            lora_path = values["lora_path"]

            tokenizer_path = os.path.join(model_path, "tokenizer.model")
            model_config_path = os.path.join(model_path, "config.json")
            model_path = Exllama.get_model_path_at(model_path)

            config = ExLlamaConfig(model_config_path)
            tokenizer = ExLlamaTokenizer(tokenizer_path)
            config.model_path = model_path

            configure_config = Exllama.configure_object(config_param_names, values, logfunc)
            configure_config(config)
            configure_tuning = Exllama.configure_object(tuning_parameters, values, logfunc)
            configure_tuning(config)

            # set_auto_map is a method on the config, not a plain attribute.
            if values['set_auto_map']:
                config.set_auto_map(values['set_auto_map'])
                logfunc(f"set_auto_map {values['set_auto_map']}")

            model = ExLlama(config)
            exllama_cache = ExLlamaCache(model)
            generator = ExLlamaGenerator(model, tokenizer, exllama_cache)

            # Load and apply lora to generator
            if lora_path is not None:
                lora_config_path = os.path.join(lora_path, "adapter_config.json")
                lora_path = Exllama.get_model_path_at(lora_path)
                lora = ExLlamaLora(model, lora_config_path, lora_path)
                generator.lora = lora
                logfunc(f"Loaded LORA @ {lora_path}")
        else:
            # Adopt an already-built generator passed in as `model`.
            generator = values['model']
            exllama_cache = generator.cache
            model = generator.model
            config = model.config
            tokenizer = generator.tokenizer

        # Generation-time parameters apply whether or not the model existed before.
        configure_model = Exllama.configure_object(model_param_names, values, logfunc)
        # Normalize once here; match_status() compares against stripped lower-case text.
        # `or []` tolerates an explicit stop_sequences=None (the field is Optional).
        values["stop_sequences"] = [x.strip().lower() for x in (values.get("stop_sequences") or [])]
        configure_model(generator.settings)
        setattr(generator.settings, "stop_sequences", values["stop_sequences"])
        logfunc(f"stop_sequences {values['stop_sequences']}")

        disallowed = values.get("disallowed_tokens")
        if disallowed:
            generator.disallow_tokens(disallowed)
            # Route through logfunc like the other settings so verbose=False silences it.
            logfunc(f"Disallowed Tokens: {generator.disallowed_tokens}")

        values["client"] = model
        values["generator"] = generator
        values["config"] = config
        values["tokenizer"] = tokenizer
        values["exllama_cache"] = exllama_cache

        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Exllama"

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text, using the exllama tokenizer."""
        return self.generator.tokenizer.num_tokens(text)

    def get_token_ids(self, text: str) -> List[int]:
        """Tokenize ``text`` with the exllama tokenizer.

        Overrides the base method, which is not aware of how to properly tokenize
        (it uses GPT2). NOTE(review): this returns whatever ``tokenizer.encode`` returns;
        stream() treats that value as array-like (``.shape[-1]``), not a plain list.
        """
        return self.generator.tokenizer.encode(text)

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Wrap ``prompt`` with the h2oGPT prompter, generate, and return the response.

        :param prompt: raw user instruction; limited via ``limit_prompt`` and then
            formatted with ``context``, ``iinput`` and ``chat_conversation``.
        :param stop: forwarded to ``stream``, which currently does not consult it.
        :param run_manager: receives new-token callbacks during streaming.
        :return: the last value yielded by ``stream`` ('' if it yielded nothing).
        """
        assert self.tokenizer is not None
        from h2oai_pipeline import H2OTextGenerationPipeline
        prompt, _num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)

        # NOTE: exllama does not add prompting, so must do here
        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
        prompt = self.prompter.generate_prompt(
            data_point,
            chat_conversation=self.chat_conversation,
            user_prompt_for_fake_system_prompt=self.user_prompt_for_fake_system_prompt,
        )

        text = ''
        for text1 in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
            text = text1
        return text

    class MatchStatus(Enum):
        """Result of comparing the response buffer against the stop sequences."""
        EXACT_MATCH = 1
        PARTIAL_MATCH = 0
        NO_MATCH = 2

    def match_status(self, sequence: str, banned_sequences: List[str]):
        """Classify ``sequence`` against the stop sequences.

        ``sequence`` is stripped and lower-cased before comparison; ``banned_sequences``
        is expected to be normalized the same way already (validate_environment does
        this for ``self.stop_sequences``). Empty entries are ignored: a stop sequence
        that normalizes to "" (e.g. whitespace only) would otherwise "exactly match" an
        empty response and end generation before anything is produced.

        :return: ``EXACT_MATCH`` if any stop sequence equals ``sequence``; otherwise
            ``PARTIAL_MATCH`` if ``sequence`` is a prefix of some stop sequence;
            otherwise ``NO_MATCH``.
        """
        sequence = sequence.strip().lower()
        candidates = [banned_seq for banned_seq in banned_sequences if banned_seq]
        # Check all candidates for an exact match before any prefix test, so the result
        # does not depend on the order of banned_sequences.
        if sequence in candidates:
            return self.MatchStatus.EXACT_MATCH
        if any(banned_seq.startswith(sequence) for banned_seq in candidates):
            return self.MatchStatus.PARTIAL_MATCH
        return self.MatchStatus.NO_MATCH

    def stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> Iterator[str]:
        """Generate a continuation of ``prompt`` token by token.

        After each token that does not (partially) match a stop sequence, yields the
        whole response so far, as post-processed by ``prompter.get_response``; the
        ``run_manager`` callback instead receives only the newly decoded chunk.
        Generation ends on the EOS token, on an exact stop-sequence match (after
        rewinding the generator), or once ``generator.gen_num_tokens()`` exceeds
        ``max_seq_len - 4``.

        NOTE(review): ``stop`` is accepted but not consulted; only
        ``self.stop_sequences`` is checked.
        """
        generator = self.generator
        beam_search = (self.beams and self.beams >= 1 and self.beam_length and self.beam_length >= 1)

        ids = generator.tokenizer.encode(prompt)
        generator.gen_begin_reuse(ids)

        if beam_search:
            generator.begin_beam_search()
            token_getter = generator.beam_search
        else:
            generator.end_beam_search()
            token_getter = generator.gen_single_token

        last_newline_pos = 0
        # Characters already decoded (the prompt); new text starts after this offset.
        seq_length = len(generator.tokenizer.decode(generator.sequence_actual[0]))
        response_start = seq_length
        cursor_head = response_start

        text_callback = None
        if run_manager:
            text_callback = partial(
                run_manager.on_llm_new_token, verbose=self.verbose
            )
        # The callback receives only new text (not the prompt first) so that various
        # langchain handlers work.

        text = ""
        # Slight extra padding space as we seem to occasionally get a few more than 1-2 tokens
        while generator.gen_num_tokens() <= (self.max_seq_len - 4):
            # Fetch a token
            token = token_getter()

            # If it's the ending token replace it and end the generation.
            if token.item() == generator.tokenizer.eos_token_id:
                generator.replace_last_token(generator.tokenizer.newline_token_id)
                if beam_search:
                    generator.end_beam_search()
                return

            # Decode from the last newline: the last token cannot be decoded on its own
            # because of how sentencepiece decodes.
            stuff = generator.tokenizer.decode(generator.sequence_actual[0][last_newline_pos:])
            cursor_tail = len(stuff)
            # The decoded text got shorter, i.e. earlier pieces merged into one character.
            has_unicode_combined = cursor_tail < cursor_head
            text_chunk = stuff[cursor_head:cursor_tail]
            if has_unicode_combined:
                # Replace the broken unicode character with the combined one.
                # NOTE(review): drops two characters from the buffer -- TODO confirm why two.
                text = text[:-2]
                text_chunk = stuff[cursor_tail - 1:cursor_tail]
            cursor_head = cursor_tail

            # Append the generated chunk to our stream buffer, then let the prompter
            # post-process the full response so far.
            text += text_chunk
            text = self.prompter.get_response(prompt + text, prompt=prompt,
                                              sanitize_bot_response=self.sanitize_bot_response)

            if token.item() == generator.tokenizer.newline_token_id:
                last_newline_pos = len(generator.sequence_actual[0])
                cursor_head = 0
                cursor_tail = 0

            # Check if the stream buffer is one of the stop sequences
            status = self.match_status(text, self.stop_sequences)

            if status == self.MatchStatus.EXACT_MATCH:
                # Encountered a stop, rewind our generator to before we hit the match and end generation.
                rewind_length = generator.tokenizer.encode(text).shape[-1]
                generator.gen_rewind(rewind_length)
                if beam_search:
                    generator.end_beam_search()
                return
            elif status == self.MatchStatus.PARTIAL_MATCH:
                # Partially matched a stop, continue buffering but don't yield.
                continue
            elif status == self.MatchStatus.NO_MATCH:
                if text_callback and not (text_chunk == BROKEN_UNICODE):
                    text_callback(text_chunk)
                yield text  # Not a stop, yield the match buffer.

        return