File size: 2,352 Bytes
6dc0c9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gc
from threading import Thread

import torch
from transformers import TextIteratorStreamer


@torch.inference_mode()
def generate_stream_xft(
    model,
    tokenizer,
    params,
    device,
    context_len=8192,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream the text produced by ``model.model.generate`` for one prompt.

    Generation runs on a background thread that feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer and yields
    the accumulated text after every streamed chunk.

    Args:
        model: Wrapper exposing ``model.model.generate`` and a ``model.config``
            with ``padding``, ``beam_width``, ``num_return_sequences``,
            ``early_stopping``, ``eos_token_id`` and ``pad_token_id``.
        tokenizer: Callable tokenizer; it is also handed to the streamer for
            decoding.
        params: Request dict. ``"prompt"`` is required. Optional keys:
            ``"repetition_penalty"`` (default 1.0), ``"max_new_tokens"``
            (default 4096) and ``"echo"`` (default True; when truthy the
            prompt is prepended to the yielded text).
        device: Not used by this function.
        context_len: Not used by this function.
        stream_interval: Not used by this function.
        judge_sent_end: Not used by this function.

    Yields:
        dict: ``{"text", "usage", "finish_reason"}``. Intermediate items have
        ``finish_reason`` set to ``None``; the last item carries the stripped
        text and a ``finish_reason`` of ``"length"`` or ``"stop"``.

    Raises:
        Exception: Whatever ``model.model.generate`` raised on the worker
            thread, re-raised here once the stream has been drained.
    """
    prompt = params["prompt"]
    repetition_penalty = float(params.get("repetition_penalty", 1.0))

    # Not used yet; kept as placeholders for future sampling support.
    # temperature = float(params.get("temperature", 1.0))
    # top_p = float(params.get("top_p", 1.0))

    max_new_tokens = int(params.get("max_new_tokens", 4096))
    echo = params.get("echo", True)

    inputs = tokenizer(
        prompt, return_tensors="pt", padding=model.config.padding
    ).input_ids
    input_echo_len = len(inputs[0])
    # The generate call is given a total length budget: prompt + new tokens.
    max_len = max_new_tokens + input_echo_len

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
    generation_kwargs = {
        "input_ids": inputs,
        "streamer": streamer,
        "max_length": max_len,
        "num_beams": model.config.beam_width,
        # NOTE(review): the request's repetition_penalty is forwarded as
        # length_penalty. Left as-is because the backend's accepted keyword
        # arguments are not visible here -- confirm this is intentional.
        "length_penalty": repetition_penalty,
        "num_return_sequences": model.config.num_return_sequences,
        "early_stopping": model.config.early_stopping,
        "eos_token_id": model.config.eos_token_id,
        "pad_token_id": model.config.pad_token_id,
    }

    worker_errors = []

    def _run_generation():
        # If generation fails, the streamer would never be told to stop and
        # the consumer loop below would block forever. Record the error first,
        # then end the stream so the consumer can surface it.
        try:
            model.model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    # With echo enabled the prompt is kept at the front of the output.
    output = prompt if echo else ""

    # `i` is the index of the last streamed chunk and stays 0 when nothing is
    # streamed. NOTE(review): it is reported as the completion token count;
    # whether one chunk equals one token depends on the streamer -- confirm.
    i = 0
    for i, new_text in enumerate(streamer):
        output += new_text
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
            "finish_reason": None,
        }

    # The stream is finished, so the worker is done (or has failed); reap it
    # and propagate any failure instead of reporting a normal completion.
    thread.join()
    if worker_errors:
        raise worker_errors[0]

    output = output.strip()
    finish_reason = "length" if i == max_new_tokens - 1 else "stop"
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()