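# Inference script for the kmichiru/Nikaido-7B-mistral-instruct LoRA adapter on top of
# mistralai/Mistral-7B-Instruct-v0.1: load the PEFT adapter, format visual-novel dialogue
# history into Mistral's [INST] prompt format, and generate the next speaker's line.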
import torch
torch.cuda.empty_cache()
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
base_model = "mistralai/Mistral-7B-Instruct-v0.1"
# new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.1"
new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.3-vn_v2"
# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token, tokenizer.pad_token_id)
tokenizer.padding_side = "right"
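# Mistral's tokenizer ships without a dedicated pad token, so padding is aliased to the EOS token above.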
# Reload the base model
base_model_reload = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = PeftModel.from_pretrained(base_model_reload, new_model)
# model = model.merge_and_unload()
model.config.use_cache = True
model.eval()
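# The adapter is kept separate here; calling model.merge_and_unload() instead would fold the
# LoRA weights into the base model, which avoids the PEFT wrapper overhead at inference time.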
def dialogue(role, content):
    return {
        "role": role,
        "content": content,
    }
import json, random

TRAIN_DSET = "iroseka_dataset.jsonl"
try:
    with open(TRAIN_DSET, "r", encoding="utf-8") as f:
        examples = [json.loads(line) for line in f]
except FileNotFoundError:
    print("Few-shot data not found, skipping...")
    examples = []
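# Each line of the JSONL file is expected to be an OpenAI-style {"messages": [...]} record;
# these records are sampled below when few-shot prompting is requested.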
def format_chat_history(example, few_shot=0):
    user_msgs = []
    # for inference each round, we only need the user messages
    for msg in example["messages"]:
        # if msg["role"] == "user":
        user_msgs.append(msg["content"])
    messages = [
        dialogue("user", "\n".join(user_msgs)),  # join user messages together
        # example["messages"][-1],  # the last message is the bot's response
    ]
    if few_shot > 0:
        # randomly sample a few messages from the dialogue history
        few_shot_data = random.sample(examples, few_shot)
        for few_shot_example in few_shot_data:
            few_shot_msgs = []
            for msg in few_shot_example["messages"]:
                if msg["role"] == "user":
                    few_shot_msgs.append(msg["content"])
            messages = [
                dialogue("user", "\n".join(few_shot_msgs)),
                few_shot_example["messages"][-1],
            ] + messages
    encodeds = tokenizer.apply_chat_template(messages, tokenize=False)
    return encodeds
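# format_chat_history_v2 builds the Mistral instruct prompt by hand rather than via
# tokenizer.apply_chat_template, and leaves a trailing "[INST] <speaker>: " open so the model
# continues the dialogue as the requested next speaker.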
def format_chat_history_v2(example, few_shot):
    # TODO: implement few-shot learning
    user_msg = []
    user_msg.append("<s>")
    for msg in example["messages"]:
        # [INST] What is your favourite condiment? [/INST]
        user_msg.append(f"[INST] {msg['content']} [/INST]")
        # user_msg.append("</s>")
    if "next_speaker" in example:
        user_msg.append(f"[INST] {example['next_speaker']}: ")
    return " ".join(user_msg)
from transformers import StoppingCriteria, StoppingCriteriaList

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for seq in input_ids:
            for stop in self.stops:
                if len(seq) >= len(stop) and torch.all(stop == seq[-len(stop):]).item():
                    return True
        return False
stop_words = ["[/INST]"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
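# NOTE: stopping_criteria is built here, but the stopping_criteria= argument is still commented
# out in model.generate() below, so generation currently stops only on EOS or max_new_tokens.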
def inference(chat_history):
    # chat_history: dict with a "messages" key storing dialogue history in OpenAI format
    formatted = format_chat_history_v2(chat_history, few_shot=1)
    print(formatted)
    model_inputs = tokenizer(
        [formatted],
        return_tensors="pt",
    )
    print(model_inputs)
    model_inputs = model_inputs.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            # max_length=1024,
            do_sample=True,
            top_p=1,
            # contrastive search
            # top_k=50,
            # penalty_alpha=0.6,
            # num_return_sequences=1,
            temperature=0.3,
            # num_return_sequences=3,
            use_cache=True,
            # pad_token_id=tokenizer.eos_token_id,  # eos_token_id is not available for some models
            pad_token_id=tokenizer.pad_token_id,  # pad_token was aliased to eos_token above
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            output_scores=True,
            output_attentions=False,
            output_hidden_states=False,
            max_new_tokens=256,
            # num_beams=9,
            # num_beam_groups=3,
            # repetition_penalty=1.0,
            # diversity_penalty=0.5,
            # num_beams=5,
            # stopping_criteria=stopping_criteria,
        )
    # print(outputs)
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def postprocess(t):
        # keep only the last non-empty segment after splitting on [/INST]
        t = t.split("[/INST]")
        t = [x.replace("[INST]", "").strip() for x in t]
        t = [x for x in t if x != ""]
        return t[-1]

    # text = [postprocess(t) for t in text]
    return text
if __name__ == "__main__":
    chat_history = {
        "messages": [
            # dialogue("system", ""),
            dialogue("user", "ๅ‚็™ฝ: ็œŸ็ด…ใฎ่จ€่‘‰ใŒ่ƒธใฎไธญใซๆป‘ใ‚Š่พผใ‚“ใงใใ‚‹ใ€‚"),
            dialogue("user", "ๆ‚ ้ฆฌ: ใฃ"),
            dialogue("user", "ๅ‚็™ฝ: ้™็•Œใ ใฃใŸใ€‚"),
            dialogue("user", "ๆ‚ ้ฆฌ: ็œŸ็ด…๏ผŒๅคงๅฅฝใใงใ™ใ€‚ใ“ใ‚Œใ‹ใ‚‰ใ‚‚ใšใฃใจไธ€็ท’ใซใ„ใฆใใ ใ•ใ„ใ€‚"),
        ],
        "next_speaker": "็œŸ็ด…",
    }
    print(inference(chat_history))