File size: 1,300 Bytes
d536a02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
# Prompt template applied to every input before tokenization: the user's text
# is substituted for ``{instruction}`` via ``str.format``. The ``<|prompt|>`` /
# ``<|answer|>`` markers are also what ``postprocess`` searches for when it
# extracts the answer from the decoded output.
# NOTE(review): ``</s>`` is presumably the tokenizer's end-of-sequence token —
# confirm against the model's tokenizer config.
STYLE = "<|prompt|>{instruction}</s><|answer|>"
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that wraps inputs in the ``STYLE`` prompt template.

    ``preprocess`` formats the raw input into the template before delegating to
    the parent pipeline; ``postprocess`` reduces each decoded record to the
    text that follows the ``<|answer|>`` marker.
    """

    # Markers used to locate the answer inside the decoded text. They mirror
    # the literals embedded in the prompt template.
    _ANSWER_MARKER = "<|answer|>"
    _PROMPT_MARKER = "<|prompt|>"

    def __init__(self, *args, **kwargs):
        """Build the parent pipeline and store the prompt template on ``self.prompt``."""
        super().__init__(*args, **kwargs)
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Format ``prompt_text`` into the prompt template, then run the parent preprocess.

        Args:
            prompt_text: Raw instruction substituted for ``{instruction}``.
            prefix: Forwarded unchanged to the parent ``preprocess``.
            handle_long_generation: Forwarded unchanged to the parent ``preprocess``.
            **generate_kwargs: Forwarded unchanged to the parent ``preprocess``.

        Returns:
            Whatever the parent ``preprocess`` returns for the formatted prompt.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    @classmethod
    def _extract_answer(cls, text):
        """Return the answer portion of ``text`` with the template markers removed.

        When the answer marker is present, the segment between its first and
        second occurrence (or the end of the text) is used. When it is absent
        -- e.g. only newly generated text was decoded -- the whole text is
        treated as the answer instead of raising ``IndexError``. In both cases
        anything from a following ``<|prompt|>`` marker onwards is dropped.
        """
        segments = text.split(cls._ANSWER_MARKER)
        answer = segments[1] if len(segments) > 1 else segments[0]
        return answer.strip().split(cls._PROMPT_MARKER)[0].strip()

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Run the parent postprocess and strip the prompt scaffolding from each record.

        Args:
            model_outputs: Forwarded unchanged to the parent ``postprocess``.
            return_type: Forwarded unchanged to the parent ``postprocess``.
            clean_up_tokenization_spaces: Forwarded unchanged to the parent
                ``postprocess``.

        Returns:
            The records produced by the parent ``postprocess``; every record
            holding a string under ``"generated_text"`` has that value replaced
            by the extracted answer. Other records are returned untouched.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            text = rec.get("generated_text")
            # Records without decoded text (e.g. when only token ids were
            # requested) have nothing to trim; leave them as-is rather than
            # failing with KeyError.
            if isinstance(text, str):
                rec["generated_text"] = self._extract_answer(text)
        return records