# TTaiyi / model.py
# Uploaded by DUTwangzhijun ("Upload 6 files", commit 9f8f34e, 2.65 kB).
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel
# Hugging Face Hub id of the Taiyi checkpoint. Both the model and tokenizer are
# loaded with trust_remote_code=True, i.e. using code shipped in that repo.
model_name = "DUTIRbionlp/Taiyi-LLM"
# Load the causal-LM weights in float16. No device/device_map is given here,
# so the model stays where from_pretrained puts it by default (presumably CPU —
# confirm if GPU inference is expected).
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
low_cpu_mem_usage=True,
torch_dtype=torch.float16
)
# Inference only: put the model in evaluation mode.
model.eval()
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True
)
import logging
# Suppress all logging calls at WARNING severity and below from this point on
# (process-wide). Messages emitted during the loads above are not affected.
logging.disable(logging.WARNING)
# The remote-code tokenizer exposes `eod_id` (presumably its end-of-document
# token). It is reused as the pad, BOS and EOS id, so turn delimiters in the
# prompt and the generation stop token are all the same id.
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.eod_id
tokenizer.eos_token_id = tokenizer.eod_id
# Start the conversation
# Maximum number of prompt tokens passed to the model; run() keeps only the
# last history_max_len tokens of history + current message.
history_max_len = 1000
# Turn counter; not read or updated by the code visible in this file chunk.
utterance_id = 0
def _encode_text(text: str) -> torch.Tensor:
    """Tokenize ``text`` without special tokens and return a (1, n) int64 id tensor."""
    ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids
    # An empty string tokenizes to an empty tensor that torch may type as
    # float32; force int64 so later concatenation never promotes to float.
    return ids.long()


def run(message: str,
        history: list,
        max_new_tokens: int = 500,
        temperature: float = 0.10,
        top_p: float = 0.9,
        repetition_penalty: float = 1.0):
    """Generate one chat reply from the Taiyi model.

    The prompt is the token sequence
    ``[bos q1 eos r1 eos] ... [bos qN eos rN eos] [bos message eos]``
    (every turn framed the same way), left-truncated to the last
    ``history_max_len`` tokens.

    Args:
        message: The new user utterance.
        history: Previous turns as ``(question, response)`` pairs, e.g. the
            history list a chat UI passes in. ``None`` or empty means no history.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Returns:
        The decoded reply with the EOS marker removed, surrounding whitespace
        stripped, and every ``"\\n"`` replaced by ``"\\n\\n"``.
    """
    bos = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    eos = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)

    # Past turns use the same framing as the current turn: bos, user text, eos,
    # then the model's reply followed by eos.
    segments = []
    for question, response in history or []:
        segments.extend((bos, _encode_text(question), eos, _encode_text(response), eos))
    segments.extend((bos, _encode_text(message), eos))

    # A single concatenation (instead of growing a tensor turn by turn); all
    # pieces are int64, so no float round-trip is involved.
    input_token_ids = torch.cat(segments, dim=1)
    # Keep only the most recent tokens; this may cut into the oldest turn.
    model_input_ids = input_token_ids[:, -history_max_len:].to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    response_ids = outputs[:, model_input_ids.size(1):]
    text = tokenizer.batch_decode(response_ids)[0]
    # Drop the EOS marker *before* stripping so whitespace that preceded it is
    # removed too; newlines are doubled (presumably so a Markdown-rendering UI
    # shows them as breaks).
    return text.replace(tokenizer.eos_token, "").strip().replace("\n", "\n\n")