File size: 2,753 Bytes
d72c532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"Qwen/Qwen2-0.5B-Instruct"

from threading import Thread
from simulator import Simulator

from transformers import TextIteratorStreamer


class Qwen2Simulator(Simulator):
    """Chat simulator that can produce both sides of a conversation.

    The methods read ``self.tokenizer``, ``self.model`` and
    ``self.generation_kwargs``; these are not set in this class and are
    presumably provided by ``Simulator`` -- confirm against the base class.
    """

    @staticmethod
    def _history_to_messages(history):
        """Flatten ``(query, response)`` pairs into chat-template messages."""
        messages = []
        for query, response in history:
            messages.append({"role": "user", "content": query})
            messages.append({"role": "assistant", "content": response})
        return messages

    def generate_query(self, history):
        """Generate the next *user* turn given the conversation so far.

        Args:
            history: Iterable of ``(query, response)`` pairs; may be empty
                or ``None``.

        Returns:
            The decoded text generated after the prompt.
        """
        inputs = ""
        if history:
            inputs += self.tokenizer.apply_chat_template(
                self._history_to_messages(history),
                tokenize=False,
                add_generation_prompt=False,
            )
        # Open a user turn by hand so the model continues writing as the
        # user instead of as the assistant.
        inputs = inputs + "<|im_start|>user\n"
        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
        return self._generate(input_ids)

    def generate_response(self, query, history):
        """Generate the assistant reply to ``query`` given ``history``.

        Args:
            query: The latest user message.
            history: Iterable of earlier ``(query, response)`` pairs.

        Returns:
            The decoded text generated after the prompt.
        """
        # NOTE(review): a ``None`` response in ``history`` is forwarded to the
        # chat template unchanged (the original code checked for it but took
        # no action) -- confirm whether such turns should be skipped.
        messages = self._history_to_messages(history)
        messages.append({"role": "user", "content": query})

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)
        return self._generate(input_ids)

    def _generate(self, input_ids):
        """Run ``model.generate`` and decode only the newly generated tokens."""
        input_ids_length = input_ids.shape[-1]
        response = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
        # Only the first returned sequence is decoded; slicing drops the
        # prompt tokens that precede the generated ones.
        return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)

    def _stream_generate(self, input_ids):
        """Yield generated text pieces while generation runs in a thread.

        Args:
            input_ids: Prompt token ids passed to ``model.generate``.

        Yields:
            Text chunks produced by the ``TextIteratorStreamer``.
        """
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            timeout=60.0,
            skip_special_tokens=True,
        )

        # ``dict.update`` returns None, so the merged kwargs must be built as
        # a value rather than by chaining ``.update`` on a fresh dict.
        # ``generation_kwargs`` is applied last, matching the original order.
        stream_generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            **self.generation_kwargs,
        }
        thread = Thread(target=self.model.generate, kwargs=stream_generation_kwargs)
        thread.start()

        for new_text in streamer:
            yield new_text


# Module-level simulator instance, created at import time from a hard-coded
# local checkpoint directory.
_LOCAL_CHECKPOINT_DIR = r"E:\data_model\Qwen2-0.5B-Instruct"
bot = Qwen2Simulator(_LOCAL_CHECKPOINT_DIR)
# bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct")


#
# history = [["hi, what your name", "rhino"]]
# generated_query = bot.generate_query(history)
# for char in generated_query:
#     print(char)
#
# bot.generate_response("1+2*3=", history)