|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from datasets import Dataset, load_dataset |
|
import json |
|
from tqdm import tqdm |
|
import copy |
|
|
|
|
|
def read_raw_data(file_path):
    """Read a JSON-Lines file (one JSON object per line).

    Args:
        file_path: path to a ``.jsonl`` file.

    Returns:
        list of parsed objects, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    raw_data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        # Stream line by line instead of materializing the whole file with
        # list(f); skip blank lines so stray empty lines don't crash json.loads.
        for line in f:
            line = line.strip()
            if line:
                raw_data.append(json.loads(line))
    return raw_data
|
|
|
def split_one_chat_as_two(raw_dataset):
    """Expand every 5-message chat into two chats: the first exchange and
    the full conversation.

    Args:
        raw_dataset: list of ``{"messages": [...]}`` dicts, each holding
            exactly 5 messages.

    Returns:
        A list twice as long as the input, alternating the 3-message prefix
        (system/user/assistant) and the complete 5-message chat.

    Raises:
        ValueError: if any chat does not contain exactly 5 messages.
    """
    chat_mul2 = []
    for item in raw_dataset:
        chat_list = item["messages"]
        # Validate with a real exception: `assert` is stripped under `-O`.
        if len(chat_list) != 5:
            raise ValueError("length should be 5")
        # First exchange only (system + first user turn + first answer)...
        chat_mul2.append({"messages": chat_list[:3]})
        # ...then the full conversation, copied so the output does not alias
        # the caller's list.
        chat_mul2.append({"messages": list(chat_list)})
    return chat_mul2
|
|
|
|
|
def format_dataset(raw_dataset, fmt_tokenizer):
    """Build an HF Dataset for inference.

    Columns:
        complete_chat: every message except the final (gold) one.
        truncate_chat: only the system message plus the last user turn
            (index 3 of the 5-message chat).
        label: the final gold message.
        formatted_chat: the truncated chat rendered through the tokenizer's
            chat template with a generation prompt appended.
    """
    complete_chats = [entry["messages"][:-1] for entry in raw_dataset]
    labels = [entry["messages"][-1] for entry in raw_dataset]
    truncated_chats = [
        [entry["messages"][0], entry["messages"][3]] for entry in raw_dataset
    ]

    dataset = Dataset.from_dict({
        "complete_chat": complete_chats,
        "truncate_chat": truncated_chats,
        "label": labels,
    })

    def _render(example):
        # Render the truncated chat into the raw prompt string the model sees.
        formatted = fmt_tokenizer.apply_chat_template(
            example["truncate_chat"], tokenize=False, add_generation_prompt=True
        )
        return {"formatted_chat": formatted}

    return dataset.map(_render)
|
|
|
|
|
if __name__ == "__main__":
    # Paths: fine-tuned checkpoint, EmotionBench-style eval set (jsonl), and
    # the base Llama-3.1 tokenizer used only for chat-template formatting.
    finetuned_path = "checkpoint"
    test_dataset_path = "YOUR_PATH_TO_EMOTIONBENCH"
    llama3_fmt_tokenizer_path = "YOUR_PATH_TO_META_LLAMA-3.1-8B"

    llama3_ft_model = AutoModelForCausalLM.from_pretrained(finetuned_path, device_map='auto')
    llama3_tokenizer = AutoTokenizer.from_pretrained(finetuned_path)
    llama3_format_tokenizer = AutoTokenizer.from_pretrained(llama3_fmt_tokenizer_path)

    raw_data = read_raw_data(file_path=test_dataset_path)

    eval_data_formatted = format_dataset(raw_data, llama3_format_tokenizer)
    print(eval_data_formatted)

    # Marker the Llama-3 chat template emits immediately before the answer.
    assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"

    ret_list = []
    for sample in tqdm(eval_data_formatted["formatted_chat"], desc='Infering answers: '):
        inputs = llama3_tokenizer(sample, return_tensors="pt")
        # Follow the model's actual placement instead of hard-coding "cuda":
        # with device_map='auto' the input embeddings may not sit on cuda:0.
        inputs = inputs.to(llama3_ft_model.device)

        outputs = llama3_ft_model.generate(**inputs, max_new_tokens=256)[0]
        decoded_text = llama3_tokenizer.decode(outputs)
        # Keep only the generated assistant turn; fall back to an empty answer
        # instead of crashing the whole run if the marker is missing.
        parts = decoded_text.split(assistant_marker)
        gen_text = parts[1].strip() if len(parts) > 1 else ""
        ret_list.append(gen_text)

    with open("test_answers_with_context_trct.json", "w") as f:
        json.dump(ret_list, f, indent=2)
    eval_data_formatted.to_json("test_data_trct.jsonl")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|