"""Simulates both sides of a Qwen chat (user queries and assistant replies)
using llama-cpp-python.

References:
https://github.com/abetlen/llama-cpp-python/blob/main/examples/gradio_chat/local.py
https://github.com/awinml/llama-cpp-python-bindings
"""
import llama_cpp.llama_tokenizer
from llama_cpp import Llama
from transformers import AutoTokenizer

from simulator import Simulator


class Qwen2Simulator(Simulator):

    def __init__(self, model_name_or_path=None):
        # Directory with the Hugging Face tokenizer files; fall back to the
        # hard-coded local checkout when no path is given.
        tokenizer_dir = model_name_or_path or "/workspace/czy/model_weights/Qwen1.5-0.5B-Chat/"
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        self.llm = Llama(
            model_path="Qwen/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf",
            # Wrap the HF tokenizer so llama.cpp tokenizes exactly like Qwen does.
            tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer(self.tokenizer),
            verbose=False,
        )
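        # Alternative (a sketch, following the gradio_chat example linked above):
        # let llama-cpp-python fetch the GGUF from the Hugging Face Hub instead
        # of pointing at a local file. Requires the huggingface_hub package.
        #
        # self.llm = Llama.from_pretrained(
        #     repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
        #     filename="*q8_0.gguf",
        #     tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer(self.tokenizer),
        #     verbose=False,
        # )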
    def generate_query(self, messages):
        """Generate the next *user* turn of the conversation.

        :param messages: chat history; must not already end with a user turn.
        :return: the generated user query text.
        """
        assert messages[-1]["role"] != "user"
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Prime the model to continue with a user turn instead of an assistant one.
        inputs = inputs + "<|im_start|>user\n"
        return self._generate(inputs)

    def generate_response(self, messages):
        """Generate the next *assistant* turn; the history must end with a user turn."""
        assert messages[-1]["role"] == "user"
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,  # append the assistant prefix for the reply
        )
        return self._generate(inputs)

    def _generate(self, inputs):
        output = self.llm(
            inputs,
            max_tokens=20,
            temperature=0.7,
            stop=["<|im_end|>"],  # stop at Qwen's end-of-turn marker
        )
        return output["choices"][0]["text"]
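
    # Streaming variant (a sketch, not part of the original class): passing
    # stream=True makes the Llama call yield completion chunks one at a time.
    def _stream_generate(self, inputs):
        for chunk in self.llm(
            inputs,
            max_tokens=20,
            temperature=0.7,
            stop=["<|im_end|>"],
            stream=True,
        ):
            yield chunk["choices"][0]["text"]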


# Module-level instance; built once at import time.
bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")

if __name__ == "__main__":
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    output = bot.generate_response(messages)
    print(output)

    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "hi, what's your name?"},
        {"role": "assistant", "content": "My name is Jordan"},
    ]
    output = bot.generate_query(messages)
    print(output)
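
    # Self-play sketch (hypothetical usage, not in the original script):
    # alternate the two generators to simulate a short dialogue.
    history = [{"role": "system", "content": "you are a helpful assistant"}]
    history.append({"role": "user", "content": bot.generate_query(history)})
    for _ in range(2):
        history.append({"role": "assistant", "content": bot.generate_response(history)})
        history.append({"role": "user", "content": bot.generate_query(history)})
    print(history)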