import re
import os
from collections import defaultdict

import datasets
from sklearn.metrics import accuracy_score, mean_squared_error
from rouge_score import rouge_scorer


# Target modules for LoRA injection, keyed by base-model family.
lora_module_dict = {
    'chatglm2': ['query_key_value'],
    'llama2': [
        'q_proj', 'k_proj', 'v_proj',
        'o_proj', 'gate_proj', 'up_proj', 'down_proj',
    ],
}
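

# Minimal sketch (not part of the original file) of how lora_module_dict is
# typically consumed: the module list for the chosen base model is passed to
# peft's LoraConfig as `target_modules`. The hyperparameter values below are
# illustrative assumptions, not this project's settings.
#
#   from peft import LoraConfig, TaskType, get_peft_model
#
#   lora_config = LoraConfig(
#       task_type=TaskType.CAUSAL_LM,
#       target_modules=lora_module_dict['llama2'],
#       r=8, lora_alpha=16, lora_dropout=0.05,
#   )
#   model = get_peft_model(model, lora_config)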


def tokenize(args, tokenizer, feature):
    """Tokenize one prompt/answer pair into input_ids and prompt-masked labels."""
    prompt_ids = tokenizer.encode(
        feature['prompt'].strip(), padding=False,
        max_length=args.max_length, truncation=True
    )
    target_ids = tokenizer.encode(
        feature['answer'].strip(), padding=False,
        max_length=args.max_length, truncation=True, add_special_tokens=False
    )
    input_ids = prompt_ids + target_ids
    exceed_max_length = len(input_ids) >= args.max_length

    # Append EOS only if it is missing and there is still room.
    if input_ids[-1] != tokenizer.eos_token_id and not exceed_max_length:
        input_ids.append(tokenizer.eos_token_id)

    # Mask the prompt portion of the labels with pad tokens so the loss is
    # computed on the answer tokens only.
    label_ids = [tokenizer.pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]

    return {
        "input_ids": input_ids,
        "labels": label_ids,
        "exceed_max_length": exceed_max_length
    }
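

# Usage sketch (an assumption, not taken from this file): tokenize is usually
# mapped over a datasets.Dataset whose rows contain 'prompt' and 'answer'
# fields, with `args` being any namespace exposing `max_length`.
#
#   from functools import partial
#
#   tokenized = raw_dataset.map(partial(tokenize, args, tokenizer))
#   tokenized = tokenized.filter(lambda x: not x['exceed_max_length'])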


def parse_model_name(name, from_remote=False):
    """Map a short model alias to its HF Hub repo id or local base_models path."""
    if name == 'chatglm2':
        return 'THUDM/chatglm2-6b' if from_remote else 'base_models/chatglm2-6b'
    elif name == 'llama2':
        return 'meta-llama/Llama-2-7b-chat-hf' if from_remote else 'base_models/Llama-2-7b-chat-hf'
    else:
        raise ValueError(f"Undefined base model {name}")


def load_dataset(names, from_remote=False):
    """Load one or more fingpt-forecaster datasets.

    `names` is a comma-separated list; an entry may end with '*N' to repeat
    that dataset N times in the returned list (simple oversampling).
    """
    dataset_names = names.split(',')
    dataset_list = []

    for name in dataset_names:
        rep = 1
        if not os.path.exists(name):
            rep = int(name.split('*')[1]) if '*' in name else 1
            name = ('FinGPT/fingpt-forecaster-' if from_remote else 'data/fingpt-forecaster-') + name.split('*')[0]
        tmp_dataset = datasets.load_dataset(name) if from_remote else datasets.load_from_disk(name)

        # Create a held-out split when the dataset only ships a train split.
        if 'test' not in tmp_dataset:
            tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)
        dataset_list.extend([tmp_dataset] * rep)

    return dataset_list
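

# Usage sketch (dataset names here are hypothetical): load a local dataset
# twice alongside a second one, so the first is oversampled 2:1 downstream.
#
#   dataset_list = load_dataset('dow30*2,nasdaq', from_remote=False)
#   # -> [dow30, dow30, nasdaq], each a DatasetDict with 'train'/'test' splits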


def parse_answer(answer):
    """Parse a model answer into its sections plus a numeric price-movement prediction.

    Returns None if the answer does not follow the expected template.
    """
    match_res = re.match(
        r"^\s*\[Positive Developments\]:\s*(.*)\s*\[Potential Concerns\]:\s*(.*)\s*\[Prediction & Analysis\]:\s*(.*)\s*$",
        answer, flags=re.DOTALL)
    if not match_res:
        return None

    pros, cons, pna = match_res.group(1), match_res.group(2), match_res.group(3)

    match_res = re.match(r'^Prediction:\s*(.*)\s*Analysis:\s*(.*)\s*$', pna, flags=re.DOTALL)
    if not match_res:
        return None

    pred, anal = match_res.group(1), match_res.group(2)

    # Direction of the predicted move: +1 up, -1 down, 0 unclear.
    if re.search(r'up|increase', pred.lower()):
        pred_bin = 1
    elif re.search(r'down|decrease|decline', pred.lower()):
        pred_bin = -1
    else:
        pred_bin = 0

    # Magnitude: prefer an 'x-y%' range, otherwise fall back to a single 'x%' figure.
    match_res = re.search(r'(\d)-(\d)%', pred)
    if not match_res:
        match_res = re.search(r'(?:more than )?(\d)+?%', pred)

    # Point estimate: lower bound plus 0.5, signed by the predicted direction.
    pred_margin = pred_bin * (int(match_res.group(1)) + 0.5) if match_res else 0.

    return {
        "positive developments": pros,
        "potential concerns": cons,
        "prediction": pred_margin,
        "prediction_binary": pred_bin,
        "analysis": anal
    }
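

# Illustrative input (reconstructed from the regexes above, not quoted from data):
#
#   [Positive Developments]: 1. Strong earnings ...
#   [Potential Concerns]: 1. Rising costs ...
#   [Prediction & Analysis]: Prediction: Up by 2-3%
#   Analysis: The stock is likely to ...
#
# would parse to prediction_binary=1 and prediction=2.5 (lower bound 2 plus 0.5).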


def calc_rouge_score(references, answers):
    """Average ROUGE-1/2/L F1 scores over paired reference/answer texts."""
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    scores_per_pair = [scorer.score(ref, ans) for ref, ans in zip(references, answers)]

    rouge1 = sum(score['rouge1'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
    rouge2 = sum(score['rouge2'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
    rougeL = sum(score['rougeL'].fmeasure for score in scores_per_pair) / len(scores_per_pair)

    return {'rouge1': rouge1, 'rouge2': rouge2, 'rougeL': rougeL}


def calc_metrics(answers, gts):
    """Compute binary accuracy, MSE, and ROUGE scores over parseable answer/ground-truth pairs."""
    answers_dict = defaultdict(list)
    gts_dict = defaultdict(list)

    for answer, gt in zip(answers, gts):
        answer_dict = parse_answer(answer)
        gt_dict = parse_answer(gt)

        # Only score pairs where both sides follow the expected template.
        if answer_dict and gt_dict:
            for k in answer_dict.keys():
                answers_dict[k].append(answer_dict[k])
                gts_dict[k].append(gt_dict[k])

    if not answers_dict['prediction']:
        return {}

    bin_acc = accuracy_score(gts_dict['prediction_binary'], answers_dict['prediction_binary'])
    mse = mean_squared_error(gts_dict['prediction'], answers_dict['prediction'])

    pros_rouge_scores = calc_rouge_score(gts_dict['positive developments'], answers_dict['positive developments'])
    cons_rouge_scores = calc_rouge_score(gts_dict['potential concerns'], answers_dict['potential concerns'])
    anal_rouge_scores = calc_rouge_score(gts_dict['analysis'], answers_dict['analysis'])

    print(f"\nBinary Accuracy: {bin_acc:.2f} | Mean Square Error: {mse:.2f}")
    print(f"\nRouge Score of Positive Developments: {pros_rouge_scores}")
    print(f"\nRouge Score of Potential Concerns: {cons_rouge_scores}")
    print(f"\nRouge Score of Summary Analysis: {anal_rouge_scores}")

    return {
        "valid_count": len(answers_dict['prediction']),
        "bin_acc": bin_acc,
        "mse": mse,
        "pros_rouge_scores": pros_rouge_scores,
        "cons_rouge_scores": cons_rouge_scores,
        "anal_rouge_scores": anal_rouge_scores
    }
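

# Usage sketch (an assumption about the surrounding evaluation script):
# calc_metrics is typically fed generated answers and ground-truth answers as
# parallel lists of strings.
#
#   metrics = calc_metrics(generated_answers, reference_answers)
#   # -> {} if no pair was parseable, otherwise a dict with 'valid_count',
#   #    'bin_acc', 'mse' and the three ROUGE score dicts.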