# Interactive chat with a local Llama 3 checkpoint via Hugging Face transformers.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Path to the local checkpoint in Hugging Face format.
model_path = "./llama3-5b/hf"

# Optional quantization settings; None loads the model unquantized.
quantization_config = None
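# Illustrative alternative (assumed values, not part of the original script):
# the model could instead be loaded in 4-bit by supplying a BitsAndBytesConfig:
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )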

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # load weights in bfloat16; the fp32 default roughly doubles memory use
    quantization_config=quantization_config,
    output_hidden_states=True,   # expose hidden states on forward passes (unused by the chat loop below)
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

messages = [
    {"role": "system", "content": "You are a helpful assistant."}
]


def generate_response(messages):
    # Build the prompt: apply the tokenizer's chat template to the running
    # message history and append the assistant generation prompt.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Stop generation at the model's EOS token or Llama 3's end-of-turn token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
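    # Note (not in the original script): do_sample=False would switch generate()
    # to deterministic greedy decoding instead of temperature/top_p sampling.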
    # Slice off the prompt tokens and decode only the newly generated ones.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


# Simple interactive chat loop; type 'q' to quit.
while True:
    user_input = input("User: ")

    if user_input.lower() == 'q':
        break

    # Record the user turn, generate a reply, and append it to the history so
    # the model sees the full conversation on the next turn.
    messages.append({"role": "user", "content": user_input})
    response = generate_response(messages)
    print("Assistant:", response)
    messages.append({"role": "assistant", "content": response})