import torch

torch.cuda.empty_cache()

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Base checkpoint and the fine-tuned PEFT adapter that will be loaded on top of it.
base_model = "mistralai/Mistral-7B-Instruct-v0.1"
new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.3-vn_v2"
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # Mistral defines no pad token, so reuse EOS
print(tokenizer.pad_token, tokenizer.pad_token_id)
tokenizer.padding_side = "right"
# Reload the base model in bfloat16 on GPU 0, then attach the fine-tuned adapter.
base_model_reload = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = PeftModel.from_pretrained(base_model_reload, new_model)
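# Optional (not part of the original script): for somewhat faster inference the adapter
# weights could be folded back into the base model with PEFT's merge_and_unload, e.g.
#   model = model.merge_and_unload()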
model.config.use_cache = True
model.eval()


def dialogue(role, content):
    """Wrap a single chat turn in the {role, content} message format."""
    return {
        "role": role,
        "content": content,
    }
import json, random

# Pool of training examples that can be reused as in-context (few-shot) demonstrations.
TRAIN_DSET = "iroseka_dataset.jsonl"
try:
    with open(TRAIN_DSET, "r", encoding="utf-8") as f:
        examples = [json.loads(line) for line in f]
except FileNotFoundError:
    print("Few-shot data not found, skipping...")
    examples = []
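# Assumed JSONL schema, inferred from how the examples are used below (not taken from the
# actual dataset): one JSON object per line, e.g.
#   {"messages": [{"role": "user", "content": "悠馬: ..."}, ...], "next_speaker": "真紅"}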
def format_chat_history(example, few_shot=0):
    """Render one example with the tokenizer's chat template, optionally prepending
    `few_shot` randomly sampled demonstrations from the training pool."""
    user_msgs = []
    for msg in example["messages"]:
        user_msgs.append(msg["content"])
    messages = [
        dialogue("user", "\n".join(user_msgs)),
    ]

    if few_shot > 0:
        few_shot_data = random.sample(examples, few_shot)
        for few_shot_example in few_shot_data:
            few_shot_msgs = []
            for msg in few_shot_example["messages"]:
                if msg["role"] == "user":
                    few_shot_msgs.append(msg["content"])
            # Prepend the demonstration: its user turns plus its final (assistant) message.
            messages = [
                dialogue("user", "\n".join(few_shot_msgs)),
                few_shot_example["messages"][-1],
            ] + messages

    encodeds = tokenizer.apply_chat_template(messages, tokenize=False)
    return encodeds
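# With the Mistral-Instruct chat template, apply_chat_template(..., tokenize=False)
# produces a "<s>[INST] ... [/INST] ..." string. format_chat_history_v2 below builds a
# similar [INST]-tagged prompt by hand and leaves the final turn open for the model.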
def format_chat_history_v2(example, few_shot):
    """Build the prompt manually in Mistral's [INST] format: each message becomes its own
    [INST] ... [/INST] block. (`few_shot` is currently unused here.)"""
    user_msg = []
    user_msg.append("<s>")
    for msg in example["messages"]:
        user_msg.append(f"[INST] {msg['content']} [/INST]")
    if "next_speaker" in example:
        # Leave the last instruction block open so the model continues as that speaker.
        user_msg.append(f"[INST] {example['next_speaker']}: ")
    return " ".join(user_msg)
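# For an example with messages m1..mN and next_speaker S, the resulting prompt has the form:
#   <s> [INST] m1 [/INST] [INST] m2 [/INST] ... [INST] mN [/INST] [INST] S: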
from transformers import StoppingCriteria, StoppingCriteriaList


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation as soon as any of the given token sequences appears at the end of a
    generated sequence. (`encounters` is accepted but unused.)"""

    def __init__(self, stops=(), encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for seq in input_ids:
            for stop in self.stops:
                if len(seq) >= len(stop) and torch.all(stop == seq[-len(stop):]).item():
                    return True
        return False


stop_words = ["[/INST]"]
stop_words_ids = [
    tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze()
    for stop_word in stop_words
]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
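# Note: this criteria list is built but never passed to model.generate() below; to actually
# stop at "[/INST]", add stopping_criteria=stopping_criteria to the generate() call.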
def inference(chat_history):
    formatted = format_chat_history_v2(chat_history, few_shot=1)
    print(formatted)
    model_inputs = tokenizer(
        [formatted],
        return_tensors="pt",
        add_special_tokens=False,  # the prompt already starts with an explicit <s>
    )
    print(model_inputs)
    model_inputs = model_inputs.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            do_sample=True,
            top_p=1,
            temperature=0.3,
            use_cache=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            output_scores=True,  # no effect unless return_dict_in_generate=True is also set
            output_attentions=False,
            output_hidden_states=False,
            max_new_tokens=256,
        )

    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    def postprocess(t):
        # Split on the [/INST] markers and keep the last non-empty segment,
        # i.e. the newly generated reply.
        t = t.split("[/INST]")
        t = [x.replace("[INST]", "").strip() for x in t]
        t = [x for x in t if x != ""]
        return t[-1]

    return [postprocess(t) for t in text]
if __name__ == "__main__":
    chat_history = {
        "messages": [
            dialogue("user", "傍白: 真紅の言葉が胸の中に滑り込んでいく。"),  # narration: Shinku's words slip into my chest
            dialogue("user", "悠馬: っ"),  # Yuuma: (a choked gasp)
            dialogue("user", "傍白: 限界だった。"),  # narration: it was my limit
            dialogue("user", "悠馬: 真紅！大好きです。もう、ずっと一緒にいてください。"),  # Yuuma: "Shinku! I love you. Please stay with me forever."
        ],
        "next_speaker": "真紅",  # ask the model to reply as Shinku
    }
    print(inference(chat_history))