# aiben / src/h2oai_pipeline.py
# Uploaded by abugaber using huggingface_hub (commit 3943768, 20.3 kB).
import os
import torch
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType, Chat
from stopping import get_stopping
from prompter import Prompter, convert_messages_and_extract_images, get_prompt # keep for export_hf_checkpoint.py
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """TextGenerationPipeline with h2oGPT prompt construction and stopping.

    Compared to the transformers base class this pipeline:
      * trims over-long input with ``limit_prompt`` and, when a Prompter is used, wraps it in the
        model's instruction template (``preprocess``);
      * builds stopping criteria via ``get_stopping`` before generating (``_forward``);
      * strips the prompt / bot framing from the decoded text (``postprocess``).
    """

    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=False,
                 use_prompter=True, prompter=None,
                 context='', iinput='',
                 chat_conversation=None,
                 user_prompt_for_fake_system_prompt=None,
                 prompt_type=None, prompt_dict=None,
                 max_input_tokens=2048 - 256,
                 base_model=None,
                 stop=None,
                 truncation_generation=None,
                 max_time=None,
                 image_file=None,
                 image_control=None,
                 images_num_max=None,
                 image_resolution=None,
                 image_format=None,
                 rotate_align_resize_image=None,
                 video_frame_period=None,
                 image_batch_image_prompt=None,
                 image_batch_final_prompt=None,
                 image_batch_stream=None,
                 visible_vision_models=None,
                 video_file=None,
                 verbose=False,
                 **kwargs):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)
        :param args: positional arguments forwarded to TextGenerationPipeline
        :param debug: if True, print prompt and outputs in postprocess; also forwarded to Prompter
        :param chat: accepted for compatibility; not used by this class
        :param stream_output: forwarded to Prompter when one is built here
        :param sanitize_bot_response: forwarded to Prompter.get_response in postprocess
        :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param context: ``context`` field of the data point given to Prompter.generate_prompt
        :param iinput: ``input`` field of the data point given to Prompter.generate_prompt
        :param chat_conversation: prior conversation given to Prompter.generate_prompt;
                                  None (the default) means an empty conversation
        :param user_prompt_for_fake_system_prompt: forwarded to Prompter.generate_prompt
        :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
                            If use_prompter, then will make prompter and use it.
        :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
        :param max_input_tokens: stored as self.max_input_tokens; not passed to generate
        :param base_model: model name; forwarded to Prompter and get_stopping, and used to load an
                           AutoProcessor on the (unfinished) image path
        :param stop: extra stop strings, merged with any per-call ``stop`` kwarg in _forward
        :param truncation_generation: forwarded to get_stopping
        :param max_time: forwarded to get_stopping
        :param image_file: iterable of images for load_image; when set, _preprocess takes the
                           vision path, which currently raises NotImplementedError
        :param verbose: stored as self.verbose
        :param kwargs: forwarded to TextGenerationPipeline
        The remaining image/video parameters are stored on the instance unchanged.
        """
        super().__init__(*args, **kwargs)
        self.prompt_text = None  # last prompt produced by preprocess
        self.use_prompter = use_prompter
        self.prompts = []  # every prompt produced by preprocess; grows for the pipeline's lifetime
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        self.prompter = prompter
        self.context = context
        self.iinput = iinput
        # fresh list per instance: a shared mutable default would leak state between pipelines
        self.chat_conversation = chat_conversation if chat_conversation is not None else []
        self.user_prompt_for_fake_system_prompt = user_prompt_for_fake_system_prompt
        self.debug = debug
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
            else:
                self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug,
                                         stream_output=stream_output, tokenizer=self.tokenizer,
                                         base_model=base_model)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.stop = stop
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens  # not for generate, so ok that not kwargs
        self.base_model = base_model
        self.verbose = verbose
        self.truncation_generation = truncation_generation
        self.max_time = max_time
        self.image_file = image_file
        self.image_control = image_control
        self.images_num_max = images_num_max
        self.image_resolution = image_resolution
        self.image_format = image_format
        self.rotate_align_resize_image = rotate_align_resize_image
        self.video_frame_period = video_frame_period
        self.image_batch_image_prompt = image_batch_image_prompt
        self.image_batch_final_prompt = image_batch_final_prompt
        self.image_batch_stream = image_batch_stream
        self.visible_vision_models = visible_vision_models
        self.video_file = video_file

    @staticmethod
    def get_token_count(x, tokenizer):
        """Return how many tokens ``tokenizer`` produces for the text ``x``.

        NOTE: Somewhat duplicates get_token_count() elsewhere in the project.
        Uses ``tokenizer.encode`` when available, otherwise calls ``tokenizer(x)``.
        Accepts results that are a list/tuple, a mapping containing ``input_ids``
        (e.g. a dict or a UserDict-based BatchEncoding), or an object with a rank-1
        or rank-2 ``shape`` (for rank 2 the second dimension is counted).

        :raises RuntimeError: if the tokenizer output has none of those forms
        """
        from collections.abc import Mapping

        if hasattr(tokenizer, 'encode'):
            tokens = tokenizer.encode(x)
        else:
            tokens = tokenizer(x)
        # Mapping (not just dict) so UserDict-based results such as BatchEncoding are unwrapped too
        if isinstance(tokens, Mapping) and 'input_ids' in tokens:
            tokens = tokens['input_ids']
        if isinstance(tokens, (list, tuple)):
            return len(tokens)
        shape = getattr(tokens, 'shape', None)
        if shape is not None and len(shape) == 2:
            return shape[1]
        if shape is not None and len(shape) == 1:
            return shape[0]
        raise RuntimeError("Cannot handle tokens: %s" % tokens)

    @staticmethod
    def limit_prompt(prompt_text, tokenizer, max_prompt_length=None, buffer=256):
        """Trim the front of ``prompt_text`` so its token count fits the model.

        The token limit is ``tokenizer.model_max_length`` (capped by ``max_prompt_length`` when
        given, in which case ``buffer`` is ignored), or ``max_prompt_length`` alone if the tokenizer
        has no ``model_max_length``. When over the limit, the text is cut to roughly
        ``limit - buffer`` tokens using the measured chars/token ratio, keeping the tail, and
        re-measured, for up to 5 attempts. Set env VERBOSE_PIPELINE=1 to log the trimming.

        :param prompt_text: text to trim; None is treated as ''
        :param tokenizer: tokenizer used for counting (see get_token_count)
        :param max_prompt_length: optional explicit token limit
        :param buffer: tokens of head-room kept free when trimming
        :return: (possibly trimmed prompt_text, its token count, or None if no limit is known)
        """
        if prompt_text is None:
            prompt_text = ''
        verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))

        def count_tokens(text):
            # skip tokenizing empty text
            return H2OTextGenerationPipeline.get_token_count(text, tokenizer) if text else 0

        if hasattr(tokenizer, 'model_max_length'):
            # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
            model_max_length = int(tokenizer.model_max_length)
            if max_prompt_length is not None:
                model_max_length = int(min(model_max_length, max_prompt_length))
                buffer = 0
            # cut at some upper likely limit to avoid excessive tokenization etc
            # upper bound of 10 chars/token, e.g. special chars sometimes are long
            if model_max_length == 0:
                len0 = len(prompt_text)
                prompt_text = ''
                if verbose:
                    print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
            elif len(prompt_text) > model_max_length * 10:
                len0 = len(prompt_text)
                prompt_text = prompt_text[-model_max_length * 10:]
                if verbose:
                    print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
        elif max_prompt_length is not None:
            model_max_length = max_prompt_length
        else:
            # unknown
            model_max_length = None

        num_prompt_tokens = None
        if model_max_length is not None:
            # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
            # For https://github.com/h2oai/h2ogpt/issues/192
            for _trial in range(5):
                num_prompt_tokens = count_tokens(prompt_text)
                if num_prompt_tokens <= model_max_length or num_prompt_tokens <= 0:
                    if verbose:
                        print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
                    break
                chars_per_token = len(prompt_text) / num_prompt_tokens
                # keep tail, where question is if using langchain
                model_max_length_with_buffer = max(0, model_max_length - buffer)
                keep_chars = int(model_max_length_with_buffer * chars_per_token)  # conservative by using int()
                # prompt_text[-0:] would keep the WHOLE string, so a zero budget must empty it explicitly
                prompt_text = prompt_text[-keep_chars:] if keep_chars > 0 else ''
                if verbose:
                    print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
                        num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
            else:
                # every attempt truncated: re-measure so the returned count matches the returned text
                num_prompt_tokens = count_tokens(prompt_text)
            if num_prompt_tokens > model_max_length and model_max_length > 0:
                print(
                    "Failed to reduce %s tokens with %s chars: %s" % (num_prompt_tokens, len(prompt_text), prompt_text),
                    flush=True)
        return prompt_text, num_prompt_tokens

    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """Trim the input, apply the instruction template, then tokenize via _preprocess.

        The template (Prompter.generate_prompt) is skipped when there is no prompter or when
        ``self.image_file`` is set. The final prompt is recorded in ``self.prompt_text`` and
        appended to ``self.prompts``.
        """
        prompt_text, _num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)

        data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
        if self.prompter is not None and not self.image_file:
            prompt_text = self.prompter.generate_prompt(data_point,
                                                        chat_conversation=self.chat_conversation,
                                                        user_prompt_for_fake_system_prompt=self.user_prompt_for_fake_system_prompt,
                                                        )
        self.prompt_text = prompt_text
        self.prompts.append(prompt_text)
        # "hole" truncation is no longer forced here (limit_prompt already trims the input);
        # an explicitly passed handle_long_generation is forwarded unchanged.
        return self._preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                **generate_kwargs)

    def _preprocess(
            self,
            prompt_text,
            prefix="",
            handle_long_generation=None,
            add_special_tokens=False,
            truncation=None,
            padding=False,
            max_length=None,
            **generate_kwargs,
    ):
        """Tokenize ``prompt_text`` into model inputs (adapted from TextGenerationPipeline.preprocess).

        A ``Chat`` goes through ``tokenizer.apply_chat_template``; plain text is tokenized as
        ``prefix + prompt_text``. The original ``prompt_text`` is stored under ``"prompt_text"``.

        :return: the tokenizer output with an added ``"prompt_text"`` entry
        :raises NotImplementedError: whenever ``self.image_file`` is set (vision path unfinished)
        :raises ValueError: with ``handle_long_generation="hole"`` when the new-token count cannot
                            be inferred or cannot fit within ``tokenizer.model_max_length``
        """
        if self.image_file:
            # NOTE(review): work-in-progress vision path. It builds processor inputs (loading the
            # processor for self.base_model) and then always raises below.
            from transformers.image_utils import load_image
            images = [load_image(x) for x in self.image_file]

            # Create inputs
            from transformers import AutoProcessor
            # `http://` or `https://`, a valid path to an image file, or a base64 encoded string.
            processor = AutoProcessor.from_pretrained(self.base_model)
            history = self.chat_conversation.copy()
            history.append([(prompt_text, images), None])
            messages, images = convert_messages_and_extract_images(history)
            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = processor(text=prompt, images=images, return_tensors="pt")
            raise NotImplementedError("Not functioning yet.")
        elif isinstance(prompt_text, Chat):
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors=self.framework,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_special_tokens=add_special_tokens,
                return_tensors=self.framework,
            )
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            # keep only the last tokens so that prompt + new tokens fits model_max_length
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
                        " models max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs

    def _postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True,
                     conditional_type=False):
        """Decode the first batch entry of ``model_outputs`` into one record per returned sequence.

        :param return_type: TENSORS -> ``{"generated_token_ids": ids}``; NEW_TEXT / FULL_TEXT ->
                            ``{"generated_text": text}``, with FULL_TEXT re-prefixed by ``prompt_text``
        :param conditional_type: if True, the decoded text is used as-is without removing a decoded
                                 prompt (presumably for models whose output does not echo the input)
        :return: list of record dicts
        """
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # Decode text
                text = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )
                if conditional_type:
                    all_text = text
                else:
                    # character length of the decoded prompt, which is cut off the decoded sequence
                    if input_ids is None:
                        prompt_length = 0
                    else:
                        prompt_length = len(
                            self.tokenizer.decode(
                                input_ids[0],
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                            )
                        )

                    if return_type == ReturnType.FULL_TEXT:
                        all_text = prompt_text + text[prompt_length:]
                    else:
                        all_text = text[prompt_length:]

                record = {"generated_text": all_text}
            records.append(record)

        return records

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """Decode outputs and strip prompt / bot framing from each ``generated_text``.

        With a Prompter, Prompter.get_response extracts the response. Without one, if ``self.bot``
        is set and occurs in the text, the part after its last occurrence (cut at ``self.human``
        when that is set) is kept; otherwise the text is returned unchanged. Afterwards, calls
        ``self.model.memory.reset()`` when the model exposes it.
        """
        conditional_type = getattr(self.model, 'conditional_type', False)
        records = self._postprocess(model_outputs, return_type=return_type,
                                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                                    conditional_type=conditional_type)
        key = 'generated_text'
        for rec in records:
            if self.use_prompter:
                outputs = rec[key]
                if return_type == ReturnType.NEW_TEXT:
                    output_with_prompt = outputs
                    prompt = None
                    only_new_text = True
                elif conditional_type:
                    if self.prompter.botstr:
                        # output lacks the bot marker, so re-add it for get_response to split on
                        prompt = self.prompter.botstr
                        output_with_prompt = prompt + outputs
                        only_new_text = False
                    else:
                        prompt = None
                        output_with_prompt = outputs
                        only_new_text = True
                else:
                    output_with_prompt = outputs
                    prompt = self.prompt_text
                    only_new_text = False
                outputs = self.prompter.get_response(output_with_prompt, prompt=prompt,
                                                     only_new_text=only_new_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.bot in rec[key]:
                # self.bot is None without a prompter; `None in str` would raise TypeError
                outputs = rec[key].split(self.bot)[-1]
                if self.human:
                    outputs = outputs.split(self.human)[0]
            else:
                outputs = rec[key]
            rec[key] = outputs
            if self.debug:
                print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
        if hasattr(self.model, 'memory') and hasattr(self.model.memory, 'reset'):
            self.model.memory.reset()
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """Attach stopping criteria, then run generation via __forward.

        Stop strings from the per-call ``stop`` kwarg and from ``self.stop`` are merged,
        de-duplicated and sorted. Stopping criteria from get_stopping are built (and stored in
        ``self.stopping_criteria``) when a prompter is used or any stop strings exist. The ``stop``
        kwarg is always removed before generate is called.
        """
        stop = []
        for extra_stop in (generate_kwargs.get('stop'), self.stop):
            if not extra_stop:
                continue
            # a bare string would otherwise be added character by character
            stop.extend([extra_stop] if isinstance(extra_stop, str) else extra_stop)
        stop = sorted(set(stop))
        if self.can_stop or stop:
            self.stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
                                                  self.tokenizer, self.device,
                                                  self.base_model,
                                                  human=self.human, bot=self.bot,
                                                  model_max_length=self.tokenizer.model_max_length,
                                                  prompter=self.prompter,
                                                  stop=stop,
                                                  truncation_generation=self.truncation_generation,
                                                  max_time=self.max_time)
            generate_kwargs['stopping_criteria'] = self.stopping_criteria
        generate_kwargs.pop('stop', None)
        # not super()._forward(): the local copy avoids deep-copying generate_kwargs (see below)
        return self.__forward(model_inputs, **generate_kwargs)

    # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
    # FIXME: https://github.com/h2oai/h2ogpt/issues/172
    def __forward(self, model_inputs, **generate_kwargs):
        """Call ``self.model.generate`` and reshape the result per input example.

        Pops ``"prompt_text"`` from ``model_inputs`` and ``prefix_length`` / ``seed`` (default 1234;
        None skips seeding) from ``generate_kwargs``.

        :return: dict with ``generated_sequence`` shaped (in_batch, out_batch // in_batch, ...),
                 ``input_ids`` (None for an empty prompt) and ``prompt_text``
        :raises ValueError: for framework "tf" when TensorFlow is not available
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If there is a prefix, we may need to adjust the generation length. Unlike upstream,
        # generate_kwargs is NOT deep-copied first (see FIXME above), so these adjustments
        # modify this call's kwargs dict in place.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        seed = generate_kwargs.pop('seed', 1234)
        if seed is not None:
            # torch.manual_seed(None) raises; None means "do not reseed"
            torch.manual_seed(seed)
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not available.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}