|
import argparse |
|
import json |
|
import numpy as np |
|
import tqdm |
|
from pathlib import Path |
|
from pprint import pprint |
|
from collections import defaultdict, Counter |
|
|
|
from transformers import AutoTokenizer |
|
import sys |
|
# Make the local SCRL package importable; this MUST run before the `scrl`
# imports below.
# NOTE(review): hard-coded absolute path — consider an environment variable
# or installing the package instead.
sys.path.append("/home/hdd/lijinyi/CompressionInAvalon/promptcompressor/SCRL_new")

# Debug output of the import search path.
print(sys.path)

import scrl.utils as utils

from scrl.model import load_checkpoint, load_model

from scrl.eval_metrics import compute_token_f1, rouge_scorer, ROUGE_TYPES

from nltk import word_tokenize

import nltk



# Fetch the NLTK "punkt" tokenizer models needed by word_tokenize
# (no-op if already downloaded).
nltk.download('punkt')

print("punkt done!")
|
|
|
|
|
def _tokens(text, pretokenized, lower=False):
    """Tokenize *text*.

    Whitespace-split when *pretokenized* is true, otherwise use NLTK's
    ``word_tokenize``. Optionally lowercase every token.
    """
    tokens = text.split() if pretokenized else word_tokenize(text)
    return [t.lower() for t in tokens] if lower else tokens


def _print_results(all_scores):
    """Print aggregated length / compression-ratio / token-F1 / ROUGE stats."""
    print("=" * 100)
    print("RESULTS:")

    print("=" * 20, "Length (#tokens):", "=" * 20)
    for metric in ("src-len", "tgt-len", "pred-len"):
        print(f"{metric}: {np.mean(all_scores[metric]):.2f}")
    print()

    print("=" * 20, "Compression ratio:", "=" * 20)
    for metric in ("tgt-cr", "pred-cr"):
        print(f"{metric}: {np.mean(all_scores[metric]):.2f}")
    print()

    print("=" * 20, "Token F1-Score:", "=" * 20)
    print(f"f1-score: {np.mean(all_scores['token-f1']):.3f}")
    print()

    print("=" * 20, "ROUGE F1-Scores:", "=" * 20)
    for rouge_type in ROUGE_TYPES:
        print(f"{rouge_type}: {np.mean(all_scores[f'{rouge_type}-f']):.4f}")
    print()

    print("=" * 20, "ROUGE Recall:", "=" * 20)
    for rouge_type in ROUGE_TYPES:
        print(f"{rouge_type}: {np.mean(all_scores[f'{rouge_type}-r']):.4f}")
    print()


def main(args):
    """Evaluate a compression model on a JSONL dataset and print metrics.

    Loads the model from either ``--model-dir`` (best checkpoint) or a single
    ``--checkpoint`` file, predicts a compression for each item's ``text``,
    and scores it against every reference in ``summaries`` with token-F1 and
    ROUGE, plus length / compression-ratio statistics.

    Raises:
        Exception: if neither or both of ``--model-dir`` / ``--checkpoint``
            are provided.
    """
    # Exactly one of --model-dir / --checkpoint selects the loading path.
    if args.model_dir is not None and args.checkpoint is None:
        model = load_model(
            Path(args.model_dir), device=args.device, prefix="best"
        )
    elif args.model_dir is None and args.checkpoint is not None:
        model = load_checkpoint(Path(args.checkpoint), device=args.device)
    else:
        raise Exception("Provide either a model directory or checkpoint.")
    # BUG FIX: the original unconditionally re-loaded the model here with
    # `load_model(Path(args.model_dir), ...)`, which crashed with
    # Path(None) in checkpoint-only mode and silently replaced the
    # prefix="best" model in model-dir mode. Removed.

    tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
    dataset = list(utils.read_jsonl(args.dataset))

    # metric name -> per-item values (item-level scores are means over refs)
    all_scores = defaultdict(list)

    for item in tqdm.tqdm(dataset):
        src = item["text"]
        if args.lower_src:
            src = src.lower()
        tgts = item["summaries"]

        pred = model.predict([src], tokenizer, args.device)[0]
        if args.max_chars > 0:
            pred = pred[:args.max_chars]

        # Predictions are always NLTK-tokenized; only sources/targets honour
        # --pretokenized (matches the original evaluation protocol).
        pred_tokens = _tokens(pred, pretokenized=False, lower=args.lower_summary)
        src_tokens = _tokens(src, pretokenized=args.pretokenized)

        item_scores = defaultdict(list)
        for tgt in tgts:
            tgt_tokens = _tokens(
                tgt, pretokenized=args.pretokenized, lower=args.lower_summary
            )

            token_fscore = compute_token_f1(
                tgt_tokens, pred_tokens, use_counts=True
            )

            # ROUGE is computed on the raw (untokenized) strings.
            rouge_scores = rouge_scorer.score(tgt, pred)
            for rouge_type, scores in rouge_scores.items():
                item_scores[f"{rouge_type}-p"].append(scores.precision)
                item_scores[f"{rouge_type}-r"].append(scores.recall)
                item_scores[f"{rouge_type}-f"].append(scores.fmeasure)

            item_scores["token-f1"].append(token_fscore)
            item_scores["tgt-len"].append(len(tgt_tokens))
            item_scores["tgt-cr"].append(len(tgt_tokens) / len(src_tokens))

        # Average over the (possibly multiple) reference summaries.
        for k, values in item_scores.items():
            all_scores[k].append(np.mean(values))

        all_scores["pred-len"].append(len(pred_tokens))
        all_scores["src-len"].append(len(src_tokens))
        all_scores["pred-cr"].append(len(pred_tokens) / len(src_tokens))

        if args.verbose:
            print("SRC:", src)
            print("TGT:", tgts[0])
            print("PRED:", pred)
            print("=" * 100)

    _print_results(all_scores)
|
|
|
def parse_args():
    """Build and parse the command-line arguments for the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', required=True)        # JSONL evaluation file
    parser.add_argument('--model-dir', required=False)     # trained model directory
    parser.add_argument('--checkpoint', required=False)    # single checkpoint file
    parser.add_argument('--device', default="cpu")         # e.g. "cpu" or "cuda"
    parser.add_argument('--pretokenized', action="store_true")
    parser.add_argument('--max-chars', type=int, default=-1)  # -1 = no truncation
    parser.add_argument('--verbose', action="store_true")
    parser.add_argument('--lower-src', action="store_true")
    parser.add_argument('--lower-summary', action="store_true")
    return parser.parse_args()
|
|
|
|
|
# Script entry point: parse CLI arguments and run the evaluation.
if __name__ == '__main__':

    main(parse_args())
