File size: 1,300 Bytes
6aca619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType

# Prompt template applied to every input: the caller's text is substituted for
# ``{instruction}`` via ``str.format``, placing it between the ``<|prompt|>``
# and ``<|answer|>`` markers.
# NOTE(review): ``</s>`` is presumably the tokenizer's end-of-sequence token --
# confirm against the model's tokenizer config.
STYLE = "<|prompt|>{instruction}</s><|answer|>"


class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that applies the ``STYLE`` prompt template.

    Inputs are wrapped in the template before tokenization, and the generated
    text is trimmed afterwards so that only the answer portion (the text after
    the ``<|answer|>`` marker and before any further marker) is returned.
    """

    # Markers delimiting the answer in the decoded text; they mirror the
    # literals embedded in the STYLE template.
    _ANSWER_MARKER = "<|answer|>"
    _PROMPT_MARKER = "<|prompt|>"

    def __init__(self, *args, **kwargs):
        """Initialise the base pipeline and store the prompt template.

        All positional and keyword arguments are forwarded unchanged to
        ``TextGenerationPipeline.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Wrap ``prompt_text`` in the prompt template, then defer to the base class.

        Args:
            prompt_text: The raw instruction supplied by the caller.
            prefix: Forwarded unchanged to the base ``preprocess``.
            handle_long_generation: Forwarded unchanged to the base ``preprocess``.
            **generate_kwargs: Forwarded unchanged to the base ``preprocess``.

        Returns:
            Whatever the base class ``preprocess`` returns for the templated text.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Run the base post-processing and reduce each text to the answer only.

        Args:
            model_outputs: Forwarded unchanged to the base ``postprocess``.
            return_type: Forwarded unchanged to the base ``postprocess``; also
                decides whether a leading prompt has to be removed.
            clean_up_tokenization_spaces: Forwarded unchanged to the base
                ``postprocess``.

        Returns:
            The records produced by the base class. Every record that carries a
            ``"generated_text"`` entry has it replaced by the stripped answer;
            records without that entry are returned untouched.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            # Records without decoded text (e.g. when token ids were requested
            # instead of text) have nothing to trim; indexing them used to
            # raise KeyError.
            if "generated_text" not in rec:
                continue
            text = rec["generated_text"]
            # Only full-text output still contains the templated prompt. When
            # just the newly generated text is returned the answer marker is
            # absent, and an unconditional ``split(...)[1]`` raised IndexError.
            if return_type == ReturnType.FULL_TEXT and self._ANSWER_MARKER in text:
                text = text.split(self._ANSWER_MARKER, 1)[1]
            # Cut the answer at the first marker the model may have emitted
            # while continuing to generate a further prompt/answer turn.
            for marker in (self._ANSWER_MARKER, self._PROMPT_MARKER):
                text = text.split(marker, 1)[0]
            rec["generated_text"] = text.strip()
        return records