|
import torch, sys |
|
import transformers |
|
|
|
try: model_path = sys.argv[1] |
|
except: model_path = "e2.0" |
|
|
|
print(f"Loading {model_path} ...") |
|
|
|
model = transformers.AutoModelForCausalLM.from_pretrained( |
|
model_path, |
|
device_map = "auto", |
|
torch_dtype = torch.bfloat16, |
|
) |
|
tokenizer = transformers.AutoTokenizer.from_pretrained(".") |
|
|
|
from qwen_vocab import old2new, new2old |
|
STOP_WORDS = "<|im_end|> <|endoftext|>".split() |
|
|
|
|
|
def map_tids(map_dict, tids): |
|
return [ map_dict[x] for x in tids if x in map_dict ] |
|
|
|
|
|
class KeywordsStoppingCriteria(transformers.StoppingCriteria): |
|
def __init__(self, str): |
|
self.keyword_ids = tokenizer.encode(str) |
|
self.keyword_ids = map_tids(old2new, self.keyword_ids) |
|
self.keyword_len = len(self.keyword_ids) |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
|
last_token_ids = input_ids[0][-self.keyword_len:] |
|
return last_token_ids.tolist() == self.keyword_ids |
|
|
|
stop_criteria_list = transformers.StoppingCriteriaList( |
|
[ KeywordsStoppingCriteria(x) for x in STOP_WORDS ] |
|
) |
|
|
|
|
|
def get_answer(q): |
|
if len(q) < 3: return "..." |
|
|
|
prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant" |
|
old_tids = tokenizer.encode(prompt) |
|
|
|
new_tids = map_tids(old2new, old_tids) |
|
new_old_tids = map_tids(new2old, new_tids) |
|
|
|
new_prompt = tokenizer.decode(new_old_tids) |
|
|
|
if new_old_tids != old_tids: |
|
print(f"!!! Cảnh báo sự trimm vocab làm mất thông tin !!!") |
|
print(f"!!! old prompt: {prompt}") |
|
print(f"!!! new prompt: {new_prompt}") |
|
|
|
inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device) |
|
|
|
assert inputs["input_ids"][0].tolist() == new_old_tids |
|
|
|
for i, x in enumerate(new_tids): |
|
inputs["input_ids"][0][i] = x |
|
|
|
with torch.no_grad(): |
|
output_ids = model.generate( |
|
**inputs, |
|
max_new_tokens=512, |
|
temperature=0.3, |
|
top_p=1.0, top_k=30, do_sample=True, |
|
repetition_penalty=1.1, |
|
stopping_criteria=stop_criteria_list, |
|
pad_token_id=tokenizer.pad_token_id, |
|
) |
|
|
|
answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ] |
|
old_tids = map_tids(new2old, answer_tids.tolist()) |
|
|
|
|
|
return tokenizer.decode(old_tids)\ |
|
.split("<|im_end|>")[0].split("<end_of_turn>")[0].strip() |
|
|
|
|
|
from utils import * |
|
while True: |
|
|
|
try: q = input(f"Bạn: {GREEN}").encode('utf-8', 'ignore').decode('utf-8', 'ignore') |
|
except Exception as e: print(f"{RESET}{e}"); q = "" |
|
|
|
reset_timer(timer=model_path) |
|
a = get_answer(q).strip() |
|
print(f"{RESET}Bot: {RED}{a}{RESET}") |
|
measure_time("timespent", timer=model_path) |
|
|
|
''' |
|
python3 model_chat.py ../Qwen2.5-1.5B-Instruct__trimm_vocab |
|
|
|
python3 model_chat.py ../Qwen2.5-1.5B-Instruct |
|
|
|
số tuổi của An trừ đi số tuổi của Lan là 3, An 10 tuổi hỏi Lan mấy tuổi? |
|
|
|
ai tạo ra bạn |
|
|
|
Bạn: tạo ra một câu hoàn chỉnh với từ "thực hiện" |
|
Bot: Thì ra, việc thực hiện kế hoạch của chúng ta cần được lên lịch cụ thể. |
|
''' |
|
|