import argparse

import torch
from docarray import BaseDoc
from jina import Deployment, Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class Prompt(BaseDoc):
    """Request schema: an OpenAI-style message list plus generation settings."""

    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    """Response schema: the decoded text generated so far."""

    text: str


# Module-level state filled in by __main__ before the Deployment is started.
tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    """Serves a causal LM over gRPC: /chat returns one reply, /stream streams tokens."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # device_map/torch_dtype="auto" let transformers pick the device(s) and dtype.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            input = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }
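

# ---------------------------------------------------------------------------
# Example client (sketch): one possible way to call the /stream endpoint above
# over gRPC, assuming Jina's asyncio Client and its `stream_doc` streaming API.
# The helper name `stream_chat` and the gen_conf values are illustrative only.
# ---------------------------------------------------------------------------
async def stream_chat(messages: list[dict], port: int = 12345):
    """Yield the text generated so far for every token streamed by /stream."""
    from jina import Client  # the client class is only needed by this helper

    client = Client(port=port, protocol="grpc", asyncio=True)
    async for doc in client.stream_doc(
        on="/stream",
        inputs=Prompt(message=messages, gen_conf={"max_new_tokens": 256}),
        return_type=Generation,
    ):
        # Each Generation carries the full decoded text produced so far.
        yield doc.text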


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()

    # The tokenizer is loaded here; the executor loads the model itself when the
    # Deployment starts.
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
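
# Launch example (script name and model path are placeholders):
#   python <this_script>.py --model_name <model_name_or_path> --port 12345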