import torch

torch.cuda.empty_cache()

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "mistralai/Mistral-7B-Instruct-v0.1"
# new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.1"
new_model = "kmichiru/Nikaido-7B-mistral-instruct-v0.3-vn_v2"

# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token, tokenizer.pad_token_id)
tokenizer.padding_side = "right"

# Reload the base model
base_model_reload = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

# Attach the LoRA adapter to the base model
model = PeftModel.from_pretrained(base_model_reload, new_model)
# model = model.merge_and_unload()
model.config.use_cache = True
model.eval()


def dialogue(role, content):
    return {"role": role, "content": content}


import json, random

TRAIN_DSET = "iroseka_dataset.jsonl"
try:
    with open(TRAIN_DSET, "r", encoding="utf-8") as f:
        examples = [json.loads(line) for line in f]
except FileNotFoundError:
    print("Few-shot data not found, skipping...")
    examples = []


def format_chat_history(example, few_shot=0):
    user_msgs = []  # for inference each round, we only need the user messages
    for msg in example["messages"]:
        # if msg["role"] == "user":
        user_msgs.append(msg["content"])
    messages = [
        dialogue("user", "\n".join(user_msgs)),  # join user messages together
        # example["messages"][-1],  # the last message is the bot's response
    ]
    if few_shot > 0:
        # randomly sample a few examples from the training data
        few_shot_data = random.sample(examples, few_shot)
        for few_shot_example in few_shot_data:
            few_shot_msgs = []
            for msg in few_shot_example["messages"]:
                if msg["role"] == "user":
                    few_shot_msgs.append(msg["content"])
            messages = [
                dialogue("user", "\n".join(few_shot_msgs)),
                few_shot_example["messages"][-1],
            ] + messages
    encodeds = tokenizer.apply_chat_template(messages, tokenize=False)
    return encodeds


def format_chat_history_v2(example, few_shot):
    # TODO: implement few-shot learning
    user_msg = []
    user_msg.append("")
    for msg in example["messages"]:
        # Mistral instruct format, e.g. [INST] What is your favourite condiment? [/INST]
        user_msg.append(f"[INST] {msg['content']} [/INST]")
        # user_msg.append("")
    if "next_speaker" in example:
        # Prompt the model to continue as the desired speaker
        user_msg.append(f"[INST] {example['next_speaker']}: ")
    return " ".join(user_msg)


from transformers import StoppingCriteria, StoppingCriteriaList


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as any sequence in the batch ends with one of the stop token sequences
        for seq in input_ids:
            for stop in self.stops:
                if len(seq) >= len(stop) and torch.all(stop == seq[-len(stop):]).item():
                    return True
        return False


stop_words = ["[/INST]"]
stop_words_ids = [
    tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze()
    for stop_word in stop_words
]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])


def inference(chat_history):
    # chat_history: dict with a "messages" key storing the dialogue history in OpenAI format
    formatted = format_chat_history_v2(chat_history, few_shot=1)
    print(formatted)
    model_inputs = tokenizer(
        [formatted],
        return_tensors="pt",
    )
    print(model_inputs)
    model_inputs = model_inputs.to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            # max_length=1024,
            do_sample=True,
            top_p=1,
            # contrastive search
            # top_k=50,
            # penalty_alpha=0.6,
            # num_return_sequences=1,
            temperature=0.3,
            # num_return_sequences=3,
            use_cache=True,
            # pad_token_id=tokenizer.eos_token_id,  # eos_token_id is not available for some models
            pad_token_id=tokenizer.pad_token_id,  # eos_token_id is not available for some models
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            output_scores=True,
            output_attentions=False,
            output_hidden_states=False,
            max_new_tokens=256,
            # num_beams=9,
            # num_beam_groups=3,
            # repetition_penalty=1.0,
            # diversity_penalty=0.5,
            # num_beams=5,
            # stopping_criteria=stopping_criteria,
        )
    # print(outputs)
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def postprocess(t):
        # Keep only the last non-empty segment after the final [/INST]
        t = t.split("[/INST]")
        t = [x.replace("[INST]", "").strip() for x in t]
        t = [x for x in t if x != ""]
        return t[-1]

    # text = [postprocess(t) for t in text]
    return text


if __name__ == "__main__":
    chat_history = {
        "messages": [
            # dialogue("system", ""),
            dialogue("user", "傍白: 真紅の言葉が胸の中に滑り込んでくる。"),  # Aside: Shinku's words slip into my chest.
            dialogue("user", "悠馬: っ"),  # Yuuma: (a caught breath)
            dialogue("user", "傍白: 限界だった。"),  # Aside: I was at my limit.
            dialogue("user", "悠馬: 真紅,大好きです。これからもずっと一緒にいてください。"),  # Yuuma: Shinku, I love you. Please stay with me from now on.
        ],
        "next_speaker": "真紅",  # Shinku
    }
    print(inference(chat_history))
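# Sketch only: the few_shot argument of format_chat_history_v2 is currently unused (see the
# TODO above). One way to honour it, mirroring the sampling logic of format_chat_history,
# is to prepend sampled training examples in the same [INST] ... [/INST] layout before the
# live conversation. The helper name below is hypothetical and not part of the original script.
def format_chat_history_v2_fewshot(example, few_shot=1):
    parts = []
    if few_shot > 0 and examples:
        for shot in random.sample(examples, min(few_shot, len(examples))):
            # Replay each sampled example: earlier turns as [INST] blocks, final turn as the reply
            for msg in shot["messages"][:-1]:
                parts.append(f"[INST] {msg['content']} [/INST]")
            parts.append(shot["messages"][-1]["content"])
    # Then append the actual conversation exactly as format_chat_history_v2 builds it
    parts.append(format_chat_history_v2(example, few_shot=0))
    return " ".join(parts)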
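# Optional: the merge_and_unload() call near the top is left commented out, so generation runs
# through the PEFT adapter wrapper. If a standalone checkpoint is wanted (e.g. for serving
# without peft installed), a minimal sketch along these lines should work; the output
# directory name "nikaido-merged" is only illustrative.
def export_merged_model(output_dir="nikaido-merged"):
    # merge_and_unload() folds the LoRA weights into the base model and returns a plain
    # transformers model, so the saved checkpoint can be loaded without the adapter.
    merged = model.merge_and_unload()
    merged.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return merged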