import os
from collections.abc import Mapping

import torch
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType, Chat

from stopping import get_stopping
from prompter import Prompter, convert_messages_and_extract_images, get_prompt  # keep for export_hf_checkpoint.py


class H2OTextGenerationPipeline(TextGenerationPipeline):
    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=False,
                 use_prompter=True, prompter=None,
                 context='', iinput='', chat_conversation=None,
                 user_prompt_for_fake_system_prompt=None,
                 prompt_type=None, prompt_dict=None,
                 max_input_tokens=2048 - 256,
                 base_model=None,
                 stop=None,
                 truncation_generation=None,
                 max_time=None,
                 image_file=None,
                 image_control=None,
                 images_num_max=None,
                 image_resolution=None,
                 image_format=None,
                 rotate_align_resize_image=None,
                 video_frame_period=None,
                 image_batch_image_prompt=None,
                 image_batch_final_prompt=None,
                 image_batch_stream=None,
                 visible_vision_models=None,
                 video_file=None,
                 verbose=False,
                 **kwargs):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)

        :param args: positional args forwarded unchanged to TextGenerationPipeline.__init__
        :param debug: if True, postprocess() prints the prompt and the cleaned outputs
        :param chat: accepted for API compatibility; not used by this class
        :param stream_output: passed to Prompter when one is constructed here
        :param sanitize_bot_response: passed to prompter.get_response() in postprocess()
        :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param context: ``context`` field handed to prompter.generate_prompt()
        :param iinput: ``input`` field handed to prompter.generate_prompt()
        :param chat_conversation: prior conversation passed to prompter.generate_prompt() and
               copied as history for the image path; None (default) means an empty list
        :param user_prompt_for_fake_system_prompt: passed to prompter.generate_prompt()
        :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
                             If use_prompter, then will make prompter and use it.
        :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
        :param max_input_tokens: stored as self.max_input_tokens; not used for generate
        :param base_model: model name passed to Prompter, get_stopping() and AutoProcessor (image path)
        :param stop: pipeline-level stop strings, merged with any per-call ``stop`` in _forward()
        :param truncation_generation: passed to get_stopping()
        :param max_time: passed to get_stopping()
        :param image_file: images for the vision path in _preprocess(), which currently raises
               NotImplementedError; also disables prompter formatting in preprocess()
        :param image_control, images_num_max, image_resolution, image_format, rotate_align_resize_image,
               video_frame_period, image_batch_image_prompt, image_batch_final_prompt, image_batch_stream,
               visible_vision_models, video_file: stored on the instance; not otherwise used by this class
        :param verbose: stored as self.verbose
        :param kwargs: forwarded to TextGenerationPipeline.__init__
        """
        super().__init__(*args, **kwargs)
        self.prompt_text = None
        self.use_prompter = use_prompter
        # every prompt seen by preprocess() is appended here
        self.prompts = []
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        self.prompter = prompter
        self.context = context
        self.iinput = iinput
        # avoid a shared mutable default: each instance gets its own list
        self.chat_conversation = [] if chat_conversation is None else chat_conversation
        self.user_prompt_for_fake_system_prompt = user_prompt_for_fake_system_prompt
        self.debug = debug
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
            else:
                self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug,
                                         stream_output=stream_output,
                                         tokenizer=self.tokenizer,
                                         base_model=base_model)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.stop = stop
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens  # not for generate, so ok that not kwargs
        self.base_model = base_model
        self.verbose = verbose
        self.truncation_generation = truncation_generation
        self.max_time = max_time
        self.image_file = image_file
        self.image_control = image_control
        self.images_num_max = images_num_max
        self.image_resolution = image_resolution
        self.image_format = image_format
        self.rotate_align_resize_image = rotate_align_resize_image
        self.video_frame_period = video_frame_period
        self.image_batch_image_prompt = image_batch_image_prompt
        self.image_batch_final_prompt = image_batch_final_prompt
        self.image_batch_stream = image_batch_stream
        self.visible_vision_models = visible_vision_models
        self.video_file = video_file

    @staticmethod
    def get_token_count(x, tokenizer):
        """
        Return the number of tokens ``tokenizer`` produces for ``x``.

        NOTE: Somewhat duplicates get_token_count()

        Uses ``tokenizer.encode`` when available, else calls the tokenizer directly.
        Accepts a list of ids, a mapping holding ``input_ids``, or an array-like with
        a 1-D or 2-D ``shape`` (for 2-D, the second dimension is counted).

        :raises RuntimeError: if the token container has any other shape
        """
        # handle ambiguity in if get dict or list
        if hasattr(tokenizer, 'encode'):
            tokens = tokenizer.encode(x)
        else:
            tokens = tokenizer(x)
        # Check for Mapping, not dict: HF BatchEncoding subclasses UserDict,
        # which is a Mapping but not a dict.
        if isinstance(tokens, Mapping) and 'input_ids' in tokens:
            tokens = tokens['input_ids']
        if isinstance(tokens, list):
            n_tokens = len(tokens)
        elif len(tokens.shape) == 2:
            n_tokens = tokens.shape[1]
        elif len(tokens.shape) == 1:
            n_tokens = tokens.shape[0]
        else:
            raise RuntimeError("Cannot handle tokens: %s" % tokens)
        return n_tokens

    @staticmethod
    def limit_prompt(prompt_text, tokenizer, max_prompt_length=None, buffer=256):
        """
        Cut ``prompt_text`` from the front so its token count fits the known limit.

        The tail of the text is kept, where the question is if using langchain.

        The limit is ``tokenizer.model_max_length`` when the tokenizer has it. If
        ``max_prompt_length`` is also given, the smaller of the two is used and
        ``buffer`` is set to 0. Without a tokenizer limit, ``max_prompt_length`` is
        used (keeping ``buffer``). If neither is known, no token-based cutting happens.

        :param prompt_text: prompt string; None is treated as ''
        :param tokenizer: tokenizer used for counting (see get_token_count)
        :param max_prompt_length: optional explicit token limit
        :param buffer: tokens subtracted from the limit when a cut is needed
        :return: (possibly shortened prompt_text, last measured token count, or None
                 when no limit is known)
        """
        if prompt_text is None:
            prompt_text = ''
        verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))

        if hasattr(tokenizer, 'model_max_length'):
            # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
            model_max_length = int(tokenizer.model_max_length)
            if max_prompt_length is not None:
                model_max_length = int(min(model_max_length, max_prompt_length))
                buffer = 0
            # cut at some upper likely limit to avoid excessive tokenization etc
            # upper bound of 10 chars/token, e.g. special chars sometimes are long
            if model_max_length == 0:
                len0 = len(prompt_text)
                prompt_text = ''
                if verbose:
                    print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
            elif len(prompt_text) > model_max_length * 10:
                len0 = len(prompt_text)
                prompt_text = prompt_text[-model_max_length * 10:]
                if verbose:
                    print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
        elif max_prompt_length is not None:
            model_max_length = max_prompt_length
        else:
            # unknown
            model_max_length = None

        num_prompt_tokens = None
        if model_max_length is not None:
            # can't wait for "hole" if not plain prompt_type, since would lose the prompt-type prefix
            # For https://github.com/h2oai/h2ogpt/issues/192
            for _trial in range(5):
                if prompt_text:
                    num_prompt_tokens = H2OTextGenerationPipeline.get_token_count(prompt_text, tokenizer)
                else:
                    num_prompt_tokens = 0
                if num_prompt_tokens > model_max_length and num_prompt_tokens > 0:
                    # conservative by using int()
                    chars_per_token = len(prompt_text) / num_prompt_tokens
                    # keep tail, where question is if using langchain
                    model_max_length_with_buffer = max(0, model_max_length - buffer)
                    keep_chars = int(model_max_length_with_buffer * chars_per_token)
                    # prompt_text[-0:] would return the WHOLE string, so a zero budget must
                    # explicitly yield '' instead of silently skipping the cut
                    prompt_text = prompt_text[-keep_chars:] if keep_chars > 0 else ''
                    if verbose:
                        print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
                            num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
                else:
                    if verbose:
                        print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
                    break
            if num_prompt_tokens is not None and num_prompt_tokens > model_max_length and model_max_length > 0:
                print(
                    "Failed to reduce %s tokens with %s chars: %s" % (num_prompt_tokens, len(prompt_text), prompt_text),
                    flush=True)

        return prompt_text, num_prompt_tokens

    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """
        Bound the raw prompt, wrap it with the prompter template, then tokenize.

        The prompter is skipped when no prompter is configured or when ``image_file`` is set.
        The final prompt is stored in ``self.prompt_text`` and appended to ``self.prompts``.
        """
        prompt_text, _ = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)

        data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
        if self.prompter is not None and not self.image_file:
            prompt_text = self.prompter.generate_prompt(
                data_point,
                chat_conversation=self.chat_conversation,
                user_prompt_for_fake_system_prompt=self.user_prompt_for_fake_system_prompt,
            )
        self.prompt_text = prompt_text
        self.prompts.append(prompt_text)
        # "hole" truncation is intentionally not forced here (disabled with new approaches);
        # handle_long_generation is passed through exactly as the caller supplied it.
        return self._preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                **generate_kwargs)

    def _preprocess(
            self,
            prompt_text,
            prefix="",
            handle_long_generation=None,
            add_special_tokens=False,
            truncation=None,
            padding=False,
            max_length=None,
            **generate_kwargs,
    ):
        """
        Tokenize ``prefix + prompt_text`` (or apply the chat template for a ``Chat``) into model inputs.

        With ``handle_long_generation == "hole"``, input_ids/attention_mask are left-truncated
        so prompt + new tokens fit ``tokenizer.model_max_length``.

        :raises NotImplementedError: when ``self.image_file`` is set (image path not functioning yet)
        :raises ValueError: when "hole" cannot infer or accommodate the requested new tokens
        """
        if self.image_file:
            from transformers.image_utils import load_image
            # each entry: `http://` or `https://`, a valid path to an image file, or a base64 encoded string.
            images = [load_image(x) for x in self.image_file]
            # Create inputs
            from transformers import AutoProcessor
            processor = AutoProcessor.from_pretrained(self.base_model)
            history = self.chat_conversation.copy()
            history.append([(prompt_text, images), None])
            messages, images = convert_messages_and_extract_images(history)
            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = processor(text=prompt, images=images, return_tensors="pt")
            raise NotImplementedError("Not functioning yet.")
        elif isinstance(prompt_text, Chat):
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors=self.framework,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_special_tokens=add_special_tokens,
                return_tensors=self.framework,
            )
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
                        " models max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs

    def _postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True,
                     conditional_type=False):
        """
        Decode generated ids into records.

        Only the first batch entry of ``generated_sequence`` is decoded.
        Returns ``{"generated_token_ids": ...}`` records for ReturnType.TENSORS, else
        ``{"generated_text": ...}``. For non-conditional models the decoded prompt
        length is stripped from the front; FULL_TEXT then re-prefixes the original prompt_text.
        """
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # Decode text
                text = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )
                if conditional_type:
                    all_text = text
                else:
                    # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
                    if input_ids is None:
                        prompt_length = 0
                    else:
                        prompt_length = len(
                            self.tokenizer.decode(
                                input_ids[0],
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                            )
                        )

                    if return_type == ReturnType.FULL_TEXT:
                        all_text = prompt_text + text[prompt_length:]
                    else:
                        all_text = text[prompt_length:]

                record = {"generated_text": all_text}
            records.append(record)

        return records

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """
        Decode outputs and strip prompt/role markers from ``generated_text``.

        With a prompter, prompter.get_response() extracts the response. Without one,
        text after the last bot marker (and before any human marker) is kept when a bot
        marker is configured and present; otherwise the text is returned unchanged.
        Afterwards ``self.model.memory.reset()`` is called if the model exposes it.
        """
        conditional_type = hasattr(self.model, 'conditional_type') and self.model.conditional_type
        records = self._postprocess(model_outputs, return_type=return_type,
                                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                                    conditional_type=conditional_type)
        key = 'generated_text'
        for rec in records:
            if self.use_prompter:
                outputs = rec[key]
                if return_type == ReturnType.NEW_TEXT:
                    output_with_prompt = outputs
                    prompt = None
                    only_new_text = True
                elif conditional_type:
                    if self.prompter.botstr:
                        prompt = self.prompter.botstr
                        output_with_prompt = prompt + outputs
                        only_new_text = False
                    else:
                        prompt = None
                        output_with_prompt = outputs
                        only_new_text = True
                else:
                    output_with_prompt = outputs
                    prompt = self.prompt_text
                    only_new_text = False
                outputs = self.prompter.get_response(output_with_prompt, prompt=prompt,
                                                     only_new_text=only_new_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.bot in rec[key]:
                # guard on self.bot: it is None when use_prompter=False, and `None in str` raises TypeError
                if self.human:
                    outputs = rec[key].split(self.bot)[-1].split(self.human)[0]
                else:
                    outputs = rec[key].split(self.bot)[-1]
            else:
                outputs = rec[key]
            rec[key] = outputs
            if self.debug:
                print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
        if hasattr(self.model, 'memory') and hasattr(self.model.memory, 'reset'):
            self.model.memory.reset()
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """
        Attach stopping criteria, then generate via __forward().

        Stop strings are the de-duplicated union of the per-call ``stop`` kwarg and
        ``self.stop``. The ``stop`` kwarg is always removed before calling generate.
        """
        stop = []
        for source in (generate_kwargs.get('stop'), self.stop):
            if not source:
                continue
            # a bare string would otherwise be spread into single-character stops
            stop += [source] if isinstance(source, str) else list(source)
        stop = sorted(set(stop))

        if self.can_stop or stop:
            self.stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
                                                  self.tokenizer, self.device,
                                                  self.base_model,
                                                  human=self.human, bot=self.bot,
                                                  model_max_length=self.tokenizer.model_max_length,
                                                  prompter=self.prompter,
                                                  stop=stop,
                                                  truncation_generation=self.truncation_generation,
                                                  max_time=self.max_time)
            generate_kwargs['stopping_criteria'] = self.stopping_criteria
        generate_kwargs.pop('stop', None)
        # return super()._forward(model_inputs, **generate_kwargs)
        return self.__forward(model_inputs, **generate_kwargs)

    # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
    # FIXME: https://github.com/h2oai/h2ogpt/issues/172
    def __forward(self, model_inputs, **generate_kwargs):
        """
        Run model.generate on the prepared inputs and reshape to (in_batch, out_per_input, ...).

        Pops ``prompt_text`` from model_inputs, and ``prefix_length`` / ``seed`` (default 1234)
        from generate_kwargs; torch is seeded before generation.
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        # generate_kwargs = copy.deepcopy(generate_kwargs)
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        seed = generate_kwargs.pop('seed', 1234)
        torch.manual_seed(seed)
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                                 **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not avaialble.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}