import torch
torch.cuda.empty_cache()
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
base_model = "mistralai/Mistral-7B-Instruct-v0.1"
# new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.1"
new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.3-vn_v2"
# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token, tokenizer.pad_token_id)
tokenizer.padding_side = "right"
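# Note: Mistral's tokenizer ships without a dedicated pad token, which is why the EOS
# token is reused for padding above.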
# Reload the base model
base_model_reload = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = PeftModel.from_pretrained(base_model_reload, new_model)
# model = model.merge_and_unload()
model.config.use_cache = True
model.eval()
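# merge_and_unload() (commented out above) would bake the LoRA weights into the base model
# and drop the PEFT wrapper; leaving the adapter unmerged keeps it swappable at a small
# extra cost per forward pass.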
def dialogue(role, content):
    return {
        "role": role,
        "content": content,
    }
import json, random
TRAIN_DSET = "iroseka_dataset.jsonl"
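# The training set doubles as a pool of few-shot examples for format_chat_history below.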
try:
    with open(TRAIN_DSET, "r", encoding="utf-8") as f:
        examples = [json.loads(line) for line in f]
except FileNotFoundError:
    print("Few-shot data not found, skipping...")
    examples = []
def format_chat_history(example, few_shot=0):
    user_msgs = []
    # for inference each round, we only need the user messages
    for msg in example["messages"]:
        # if msg["role"] == "user":
        user_msgs.append(msg["content"])
    messages = [
        dialogue("user", "\n".join(user_msgs)),  # join user messages together
        # example["messages"][-1],  # the last message is the bot's response
    ]
    if few_shot > 0:
        # randomly sample a few full examples from the training data
        few_shot_data = random.sample(examples, few_shot)
        for few_shot_example in few_shot_data:
            few_shot_msgs = []
            for msg in few_shot_example["messages"]:
                if msg["role"] == "user":
                    few_shot_msgs.append(msg["content"])
            messages = [
                dialogue("user", "\n".join(few_shot_msgs)),
                few_shot_example["messages"][-1],
            ] + messages
    encodeds = tokenizer.apply_chat_template(messages, tokenize=False)
    return encodeds
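# format_chat_history_v2 builds the Mistral-instruct prompt by hand ("<s> [INST] ... [/INST]")
# instead of relying on tokenizer.apply_chat_template, and leaves a trailing
# "[INST] <next_speaker>: " so the model completes that character's next line.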
def format_chat_history_v2(example, few_shot):
    # TODO: implement few-shot learning
    user_msg = []
    user_msg.append("<s>")
    for msg in example["messages"]:
        # [INST] What is your favourite condiment? [/INST]
        user_msg.append(f"[INST] {msg['content']} [/INST]")
        # user_msg.append("</s>")
    if "next_speaker" in example:
        user_msg.append(f"[INST] {example['next_speaker']}: ")
    return " ".join(user_msg)
from transformers import StoppingCriteria, StoppingCriteriaList
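# Custom criterion: stop generation once the generated ids end with any of the
# stop-word token sequences.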
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for seq in input_ids:
            for stop in self.stops:
                if len(seq) >= len(stop) and torch.all((stop == seq[-len(stop):])).item():
                    return True
        return False
stop_words = ["[/INST]"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
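# Note: stopping_criteria is built here, but the corresponding argument in model.generate()
# below is currently commented out, so generation stops only on EOS or max_new_tokens.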
def inference(chat_history):
    # chat_history: dict, with "messages" key storing dialogue history, in OpenAI format
    formatted = format_chat_history_v2(chat_history, few_shot=1)
    print(formatted)
    model_inputs = tokenizer(
        [formatted],
        return_tensors="pt",
    )
    print(model_inputs)
    model_inputs = model_inputs.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            # max_length=1024,
            do_sample=True,
            top_p=1,
            # contrastive search
            # top_k=50,
            # penalty_alpha=0.6,
            # num_return_sequences=1,
            temperature=0.3,
            # num_return_sequences=3,
            use_cache=True,
            # pad_token_id=tokenizer.eos_token_id,  # eos_token_id is not available for some models
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            output_scores=True,
            output_attentions=False,
            output_hidden_states=False,
            max_new_tokens=256,
            # num_beams=9,
            # num_beam_groups=3,
            # repetition_penalty=1.0,
            # diversity_penalty=0.5,
            # num_beams=5,
            # stopping_criteria=stopping_criteria,
        )
    # print(outputs)
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def postprocess(t):
        t = t.split("[/INST]")
        t = [x.replace("[INST]", "").strip() for x in t]
        t = [x for x in t if x != ""]
        return t[-1]

    # text = [postprocess(t) for t in text]
    return text
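# inference() currently returns the full decoded text (prompt plus continuation); uncomment
# the postprocess call above to keep only the last generated turn.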
if __name__ == "__main__":
    # Sample prompt: a confession scene between 悠馬 (Yuuma) and 真紅 (Shinku),
    # written as "speaker: line" turns; the model is asked to continue as 真紅.
    chat_history = {
        "messages": [
            # dialogue("system", ""),
            dialogue("user", "告白: 真紅の言葉が胸の中に滑り込んでくる。"),  # narration: Shinku's words slip into my chest.
            dialogue("user", "悠馬: っ"),  # Yuuma, at a loss for words
            dialogue("user", "告白: 限界だった。"),  # narration: I was at my limit.
            dialogue("user", "悠馬: 真紅！大好きです。これからもずっと一緒にいてください。"),  # Yuuma: Shinku! I love you. Please stay with me from now on, forever.
        ],
        "next_speaker": "真紅",
    }
    print(inference(chat_history))