# LLaMA-3.1-8B-EmotionBench / evaluation.py
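"""Evaluate a fine-tuned LLaMA-3.1-8B checkpoint on EmotionBench-style chats.

Reads a JSONL test file, renders each conversation with the LLaMA-3.1 chat
template, generates answers with the fine-tuned model, and writes the
generations and the formatted dataset to disk.
"""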
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
import json
from tqdm import tqdm
def read_raw_data(file_path):
    """Read a JSONL file: one JSON object per line."""
    raw_data = []
    with open(file_path, "r") as f:
        for json_str in f:
            raw_data.append(json.loads(json_str))
    return raw_data
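# Each record is assumed to hold a five-turn conversation (inferred from the
# asserts and indexing below), e.g.:
#   {"messages": [system, user, assistant, user, assistant]}
# where each turn is a {"role": ..., "content": ...} dict.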
def split_one_chat_as_two(raw_dataset):
    """Split each five-message chat into two samples: the first exchange and the full chat.

    Currently unused; see the commented-out call in __main__.
    """
    chat_mul2 = []
    for item in raw_dataset:
        chat_list = item["messages"]  # a list of 5 dicts
        assert len(chat_list) == 5, "length should be 5"
        chat_1 = chat_list[:3]  # system + first user + first assistant turn
        chat_2 = chat_list      # the full conversation
        chat_mul2.append({"messages": chat_1})
        chat_mul2.append({"messages": chat_2})
    assert len(chat_mul2) == 2 * len(raw_dataset)
    return chat_mul2
def format_dataset(raw_dataset, fmt_tokenizer):
    """Build a Dataset holding the full chat, a truncated chat, and the held-out label."""
    message_list_infer = []
    message_list_label = []
    message_list_trct = []
    for entry in raw_dataset:
        message_list_infer.append(entry["messages"][:-1])  # mask the last assistant response: it becomes the label
        message_list_label.append(entry["messages"][-1])
        truncate_chat = [entry["messages"][0], entry["messages"][3]]  # system prompt + last user turn
        message_list_trct.append(truncate_chat)
    dataset = Dataset.from_dict({
        "complete_chat": message_list_infer,
        "truncate_chat": message_list_trct,
        "label": message_list_label,
    })
    # Render the truncated chat with the chat template and append the
    # assistant generation prompt.
    dataset = dataset.map(lambda x: {
        "formatted_chat": fmt_tokenizer.apply_chat_template(
            x["truncate_chat"], tokenize=False, add_generation_prompt=True
        )
    })
    return dataset
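# With the LLaMA-3.1 chat template, "formatted_chat" should look roughly like
# the sketch below (the exact text depends on the tokenizer's template, which
# may prepend date lines to the system block):
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   {system prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   {last user turn}<|eot_id|><|start_header_id|>assistant<|end_header_id|>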
if __name__ == "__main__":
    finetuned_path = "checkpoint"
    test_dataset_path = "YOUR_PATH_TO_EMOTIONBENCH"
    llama3_fmt_tokenizer_path = "YOUR_PATH_TO_META_LLAMA-3.1-8B"

    llama3_ft_model = AutoModelForCausalLM.from_pretrained(finetuned_path, device_map="auto")
    llama3_tokenizer = AutoTokenizer.from_pretrained(finetuned_path)
    llama3_format_tokenizer = AutoTokenizer.from_pretrained(llama3_fmt_tokenizer_path)

    raw_data = read_raw_data(file_path=test_dataset_path)
    # eval_data = split_one_chat_as_two(raw_data)  # optional: score partial and full chats separately
    eval_data_formatted = format_dataset(raw_data, llama3_format_tokenizer)
    print(eval_data_formatted)
    ret_list = []
    for sample in tqdm(eval_data_formatted["formatted_chat"], desc="Inferring answers: "):
        # The templated string already starts with <|begin_of_text|>, so skip
        # the tokenizer's own special tokens to avoid a duplicated BOS.
        inputs = llama3_tokenizer(sample, return_tensors="pt", add_special_tokens=False).to("cuda")
        outputs = llama3_ft_model.generate(**inputs, max_new_tokens=256)[0]
        decoded_text = llama3_tokenizer.decode(outputs)
        # Keep only the assistant reply: the text after the generation header,
        # minus the trailing end-of-turn token.
        reply = decoded_text.split("<|start_header_id|>assistant<|end_header_id|>")[1]
        gen_text = reply.split("<|eot_id|>")[0].strip()
        ret_list.append(gen_text)

    with open("test_answers_with_context_trct.json", "w") as f:
        json.dump(ret_list, f, indent=2)
    eval_data_formatted.to_json("test_data_trct.jsonl")
    # Alternative run: sample 100 answers from the default situation prompt
    # alone, without the surrounding chat context.
    # default_sit = raw_data[0]["messages"][:2]
    # default_prompt = llama3_format_tokenizer.apply_chat_template(default_sit, add_generation_prompt=True, tokenize=False)
    # ret_list = []
    # for _ in tqdm(range(100)):
    #     inputs = llama3_tokenizer(default_prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
    #     outputs = llama3_ft_model.generate(**inputs, max_new_tokens=256)[0]
    #     decoded_text = llama3_tokenizer.decode(outputs)
    #     gen_text = decoded_text.split("<|start_header_id|>assistant<|end_header_id|>")[1].split("<|eot_id|>")[0].strip()
    #     ret_list.append(gen_text)
    # with open("test_answers_wo_context.json", "w") as f:
    #     json.dump(ret_list, f, indent=2)