# Qwen2 chat simulator — base model: "Qwen/Qwen2-0.5B-Instruct"
# (source snapshot: commit d72c532, 2,753 bytes)
from threading import Thread
from simulator import Simulator
from transformers import TextIteratorStreamer
class Qwen2Simulator(Simulator):
    """Simulates both sides of a Qwen2 chat.

    ``generate_query`` makes the model write the next *user* turn;
    ``generate_response`` makes it write the *assistant* turn.
    Relies on ``self.tokenizer``, ``self.model`` and ``self.generation_kwargs``
    provided by the ``Simulator`` base class.
    """

    def generate_query(self, history):
        """Generate the next user query given the conversation so far.

        history: list of ``(query, response)`` pairs; may be empty or None.
        Returns the generated query text (special tokens stripped).
        """
        inputs = ""
        if history:
            messages = []
            for query, response in history:
                messages += [
                    {"role": "user", "content": query},
                    {"role": "assistant", "content": response},
                ]
            # Render past turns without opening a generation prompt;
            # we open the *user* turn manually below.
            inputs += self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
        # Open a new user turn so the model completes the user's side.
        inputs = inputs + "<|im_start|>user\n"
        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
        return self._generate(input_ids)

    def generate_response(self, query, history):
        """Generate the assistant response to ``query`` given ``history``.

        history: list of ``(query, response)`` pairs; pairs whose response is
        still ``None`` (unfinished turns) are skipped.
        Returns the generated response text.
        """
        messages = []
        for _query, _response in history:
            if _response is None:
                # Skip unfinished turns: the original `pass` placeholder let a
                # None-content assistant message reach apply_chat_template.
                continue
            messages += [
                {"role": "user", "content": _query},
                {"role": "assistant", "content": _response},
            ]
        messages.append({"role": "user", "content": query})
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)
        return self._generate(input_ids)

    def _generate(self, input_ids):
        """Run blocking generation and return only the newly generated text."""
        input_ids_length = input_ids.shape[-1]
        response = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
        # model.generate returns prompt + completion; slice off the prompt.
        return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)

    def _stream_generate(self, input_ids):
        """Yield generated text incrementally.

        Generation runs on a background thread; the streamer yields decoded
        chunks as they become available (60s timeout per chunk).
        """
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            timeout=60.0,
            skip_special_tokens=True,
        )
        # BUG FIX: the original used dict(...).update(self.generation_kwargs),
        # which returns None, so Thread received kwargs=None. Merge instead.
        stream_generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            **self.generation_kwargs,
        )
        thread = Thread(target=self.model.generate, kwargs=stream_generation_kwargs)
        thread.start()
        for new_text in streamer:
            yield new_text
# Module-level singleton: instantiating loads the model at import time.
# NOTE(review): hard-coded local Windows checkpoint path — consider making
# this configurable (env var / CLI arg) instead of editing source.
bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")
# Alternative: load the checkpoint from the Hugging Face Hub instead:
# bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct")
#
# Example usage:
# history = [["hi, what your name", "rhino"]]
# generated_query = bot.generate_query(history)
# for char in generated_query:
#     print(char)
#
# bot.generate_response("1+2*3=", history)