import torch
import transformers

from .get_device import get_device
from .streaming_generation_utils import Iteratorize, Stream


def generate(
    # model
    model,
    tokenizer,
    # input
    prompt,
    generation_config,
    max_new_tokens,
    stopping_criteria=None,
    # output options
    stream_output=False
):
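    """Generate text for `prompt`, yielding (decoded_output, output_token_ids, is_finished) tuples.

    With `stream_output=True`, partial (growing) outputs are yielded as tokens
    are generated, each with `is_finished=False`, and the final output is
    yielded with `is_finished=True`. Otherwise a single final tuple is yielded.

    Example (illustrative sketch; `load_model_and_tokenizer`, the model name
    and the argument values are placeholders, not part of this module):

        model, tokenizer = load_model_and_tokenizer("some-base-model")
        generation_config = transformers.GenerationConfig(
            temperature=0.7, top_p=0.9, do_sample=True
        )
        for decoded_output, _, is_finished in generate(
            model,
            tokenizer,
            prompt="Hello, how are you?",
            generation_config=generation_config,
            max_new_tokens=128,
            stream_output=True,
        ):
            print(decoded_output)
    """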
    device = get_device()

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generate_params = {
        "input_ids": input_ids,
        "generation_config": generation_config,
        "return_dict_in_generate": True,
        "output_scores": True,
        "max_new_tokens": max_new_tokens,
        "stopping_criteria": transformers.StoppingCriteriaList() + stopping_criteria
    }

    # Strip special tokens from the decoded output by default.
    skip_special_tokens = True

    if '/dolly' in tokenizer.name_or_path:
        # dolly registers ['### End', '### Instruction:', '### Response:'] as additional_special_tokens; skipping them would break the prompter's reply extraction.
        skip_special_tokens = False
        # Ensure generation stops once it generates "### End"
        end_key_token_id = tokenizer.encode("### End")
        end_key_token_id = end_key_token_id[0]  # 50277
        # Normalize eos_token_id to a list so the "### End" token id can be appended.
        if isinstance(generate_params['generation_config'].eos_token_id, int):
            generate_params['generation_config'].eos_token_id = [generate_params['generation_config'].eos_token_id]
        elif not generate_params['generation_config'].eos_token_id:
            generate_params['generation_config'].eos_token_id = []
        generate_params['generation_config'].eos_token_id.append(end_key_token_id)

    if stream_output:
        # Stream the reply 1 token at a time.
        # This is based on the trick of using 'stopping_criteria' to create an iterator,
        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
        generation_output = None

        def generate_with_callback(callback=None, **kwargs):
            nonlocal generation_output
            kwargs["stopping_criteria"].insert(
                0,
                Stream(callback_func=callback)
            )
            with torch.no_grad():
                generation_output = model.generate(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(
                generate_with_callback, kwargs, callback=None
            )

        with generate_with_streaming(**generate_params) as generator:
            for output in generator:
                decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
                yield decoded_output, output, False

        # After the streaming generator is exhausted, yield the final full sequence once.
        if generation_output is not None:
            output = generation_output.sequences[0]
            decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
            yield decoded_output, output, True

        return  # early return for stream_output

    # Without streaming
    with torch.no_grad():
        generation_output = model.generate(**generate_params)
    output = generation_output.sequences[0]
    decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
    yield decoded_output, output, True
    return