|
from transformers import TextGenerationPipeline |
|
from transformers.pipelines.text_generation import ReturnType |
|
|
|
from stopping import get_stopping |
|
from prompter import Prompter |
|
|
|
|
|
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """TextGenerationPipeline that wraps inputs in an instruction prompt and adds stopping criteria."""

    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=True,
                 use_prompter=True, prompter=None, prompt_type=None,
                 max_input_tokens=2048 - 256, **kwargs):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)

        :param args: positional arguments forwarded to ``TextGenerationPipeline.__init__``
        :param debug: forwarded to ``Prompter`` when one is constructed here
        :param chat: forwarded to ``Prompter`` when one is constructed here
        :param stream_output: forwarded to ``Prompter`` when one is constructed here
        :param sanitize_bot_response: forwarded to ``Prompter.get_response`` in ``postprocess``
        :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in prompter.py.
               If use_prompter and no prompter is given, a Prompter is built from it.
               If a prompter is given and prompt_type is None, prompter.prompt_type is used instead.
        :param max_input_tokens: stored as ``self.max_input_tokens``; not read by any method of this
               class (TODO(review): wire into prompt truncation or remove)
        :param kwargs: keyword arguments forwarded to ``TextGenerationPipeline.__init__``
        """
        super().__init__(*args, **kwargs)
        # Last prompt seen by preprocess; kept for backward compatibility only.
        self.prompt_text = None
        self.use_prompter = use_prompter
        self.prompt_type = prompt_type
        self.prompter = prompter
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
                if self.prompt_type is None:
                    # _forward passes self.prompt_type to get_stopping(); inherit it from the
                    # supplied prompter rather than passing None.
                    self.prompt_type = self.prompter.prompt_type
            else:
                self.prompter = Prompter(self.prompt_type, debug=debug, chat=chat, stream_output=stream_output)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens

    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """Wrap the raw instruction with the prompter's template, then tokenize via the parent.

        :param prompt_text: raw user instruction
        :param prefix: forwarded to ``TextGenerationPipeline.preprocess``
        :param handle_long_generation: forwarded to the parent; defaults to ``'hole'`` here
        :param generate_kwargs: forwarded to the parent
        :return: whatever ``TextGenerationPipeline.preprocess`` returns
        """
        if self.prompter is not None:
            data_point = dict(context='', instruction=prompt_text, input='')
            prompt_text = self.prompter.generate_prompt(data_point)
        self.prompt_text = prompt_text
        if handle_long_generation is None:
            # 'hole' asks the parent TextGenerationPipeline to truncate over-long inputs
            # (see its handle_long_generation docs) instead of letting generation fail.
            handle_long_generation = 'hole'
        return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                  **generate_kwargs)

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """Decode via the parent, then strip the prompt/markers from each ``generated_text``.

        :param model_outputs: dict produced by ``_forward`` (includes ``prompt_text``)
        :param return_type: forwarded to ``TextGenerationPipeline.postprocess``
        :param clean_up_tokenization_spaces: forwarded to ``TextGenerationPipeline.postprocess``
        :return: list of record dicts whose ``generated_text`` holds only the response
        """
        records = super().postprocess(model_outputs, return_type=return_type,
                                      clean_up_tokenization_spaces=clean_up_tokenization_spaces)
        # Use the prompt that travelled with *this* output. self.prompt_text only holds the most
        # recently preprocessed prompt, which is wrong once several inputs are preprocessed before
        # any is postprocessed (e.g. batched or iterator inputs).
        prompt_text = model_outputs.get("prompt_text", self.prompt_text)
        for rec in records:
            outputs = rec['generated_text']
            if self.use_prompter:
                outputs = self.prompter.get_response(outputs, prompt=prompt_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.human:
                # Only reachable if bot/human are set after construction (__init__ leaves them None
                # when use_prompter is False). Take the text after the first bot marker when it is
                # present (avoids IndexError otherwise), then cut at the next human marker.
                parts = outputs.split(self.bot)
                if len(parts) > 1:
                    outputs = parts[1]
                outputs = outputs.strip().split(self.human)[0].strip()
            rec['generated_text'] = outputs
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """Attach prompt-specific stopping criteria (when a prompter is used) and generate.

        NOTE(review): this overwrites any ``stopping_criteria`` the caller passed in
        ``generate_kwargs`` whenever ``self.can_stop`` is True.
        """
        if self.can_stop:
            stopping_criteria = get_stopping(self.prompt_type, self.tokenizer, self.device, human=self.human,
                                             bot=self.bot)
            generate_kwargs['stopping_criteria'] = stopping_criteria
        return self.__forward(model_inputs, **generate_kwargs)

    def __forward(self, model_inputs, **generate_kwargs):
        """Run ``self.model.generate`` on one preprocessed input and regroup the output per input.

        Follows the structure of transformers' ``TextGenerationPipeline._forward``
        (review note: keep in sync with the installed transformers version).

        :param model_inputs: dict from ``preprocess`` with ``input_ids``, optionally
            ``attention_mask``, and ``prompt_text``; ``prompt_text`` is popped (mutates the dict).
        :param generate_kwargs: forwarded to ``generate``; ``prefix_length`` is popped and used to
            extend ``max_length``/``min_length``.
        :return: dict with ``generated_sequence`` reshaped to ``(in_b, out_b // in_b, *rest)``,
            the (possibly None) ``input_ids`` and ``prompt_text``.
        :raises ValueError: framework is "tf" but TensorFlow is not available.
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # An empty prompt is passed to generate() as input_ids=None, treated as a single input.
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # Prefix tokens count toward max_length/min_length, so extend those limits by the prefix
        # length unless the caller bounded generation by new tokens instead.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        # Group rows by input; assumes out_b is a multiple of in_b.
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not available.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
|