|
import torch, sys
|
|
import transformers
|
|
|
|
try: model_path = sys.argv[1]
|
|
except: model_path = "e1.0"
|
|
|
|
print(f"Loading {model_path} ...")
|
|
|
|
model = transformers.AutoModelForCausalLM.from_pretrained(
|
|
model_path,
|
|
device_map = "auto",
|
|
torch_dtype = torch.bfloat16,
|
|
)
|
|
tokenizer = transformers.AutoTokenizer.from_pretrained(".")
|
|
|
|
from qwen_vocab import old2new, new2old
|
|
STOP_WORDS = "<|im_end|> <|endoftext|>".split()
|
|
|
|
|
|
def map_tids(map_dict, tids):
|
|
return [ map_dict[x] for x in tids if x in map_dict ]
|
|
|
|
|
|
class KeywordsStoppingCriteria(transformers.StoppingCriteria):
|
|
def __init__(self, str):
|
|
self.keyword_ids = tokenizer.encode(str)
|
|
self.keyword_ids = map_tids(old2new, self.keyword_ids)
|
|
self.keyword_len = len(self.keyword_ids)
|
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
last_token_ids = input_ids[0][-self.keyword_len:]
|
|
return last_token_ids.tolist() == self.keyword_ids
|
|
|
|
stop_criteria_list = transformers.StoppingCriteriaList(
|
|
[ KeywordsStoppingCriteria(x) for x in STOP_WORDS ]
|
|
)
|
|
|
|
|
|
def get_answer(q):
|
|
if len(q) < 3: return "..."
|
|
|
|
prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant"
|
|
old_tids = tokenizer.encode(prompt)
|
|
|
|
new_tids = map_tids(old2new, old_tids)
|
|
new_old_tids = map_tids(new2old, new_tids)
|
|
|
|
new_prompt = tokenizer.decode(new_old_tids)
|
|
|
|
if new_old_tids != old_tids:
|
|
print(f"!!! Cảnh báo sự trimm vocab làm mất thông tin !!!")
|
|
print(f"!!! old prompt: {prompt}")
|
|
print(f"!!! new prompt: {new_prompt}")
|
|
|
|
inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device)
|
|
|
|
assert inputs["input_ids"][0].tolist() == new_old_tids
|
|
|
|
for i, x in enumerate(new_tids):
|
|
inputs["input_ids"][0][i] = x
|
|
|
|
with torch.no_grad():
|
|
output_ids = model.generate(
|
|
**inputs,
|
|
max_new_tokens=512,
|
|
temperature=0.3,
|
|
top_p=1.0, top_k=30, do_sample=True,
|
|
repetition_penalty=1.1,
|
|
stopping_criteria=stop_criteria_list,
|
|
pad_token_id=tokenizer.pad_token_id,
|
|
)
|
|
|
|
answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ]
|
|
old_tids = map_tids(new2old, answer_tids.tolist())
|
|
|
|
|
|
return tokenizer.decode(old_tids)\
|
|
.split("<|im_end|>")[0].split("<end_of_turn>")[0].strip()
|
|
|
|
|
|
from utils import *
|
|
while True:
|
|
|
|
try: q = input(f"Bạn: {GREEN}").encode('utf-8', 'ignore').decode('utf-8', 'ignore')
|
|
except Exception as e: print(f"{RESET}{e}"); q = ""
|
|
|
|
reset_timer(timer=model_path)
|
|
a = get_answer(q).strip()
|
|
print(f"{RESET}Bot: {RED}{a}{RESET}")
|
|
measure_time("timespent", timer=model_path)
|
|
|
|
'''
|
|
python3 model_chat.py ../Qwen2.5-1.5B-Instruct__trimm_vocab
|
|
|
|
python3 model_chat.py ../Qwen2.5-1.5B-Instruct
|
|
|
|
số tuổi của An trừ đi số tuổi của Lan là 3, An 10 tuổi hỏi Lan mấy tuổi?
|
|
|
|
ai tạo ra bạn
|
|
|
|
Bạn: tạo ra một câu hoàn chỉnh với từ "thực hiện"
|
|
Bot: Thì ra, việc thực hiện kế hoạch của chúng ta cần được lên lịch cụ thể.
|
|
'''
|
|
|