|
import argparse |
|
import json |
|
import numpy as np |
|
import tqdm |
|
from pathlib import Path |
|
from pprint import pprint |
|
from collections import defaultdict, Counter |
|
|
|
from transformers import AutoTokenizer |
|
import scrl.utils as utils |
|
from scrl.model import load_checkpoint |
|
from scrl.metrics import compute_token_f1, rouge_scorer, ROUGE_TYPES |
|
from nltk import word_tokenize |
|
|
|
from scrl.rewards import load_rewards |
|
from scrl.config import load_config |
|
import time |
|
|
|
|
|
def main(args):
    """Predict summaries for a JSONL dataset and optionally save them.

    Loads the model from ``args.checkpoint``, reads every record of
    ``args.dataset``, runs ``model.predict`` batch by batch, prints the
    elapsed prediction time in seconds, and, when ``args.output`` is set,
    writes one ``{"id", "pred-summary"}`` record per input item to it.

    Args:
        args: Parsed command-line namespace providing ``checkpoint``,
            ``device``, ``dataset``, ``batch_size`` and ``output``.
    """
    model = load_checkpoint(Path(args.checkpoint), device=args.device)
    tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")

    # The dataset is fully read into memory before the timer starts.
    dataset = list(utils.read_jsonl(args.dataset))
    batches = utils.batchify(dataset, args.batch_size)

    outputs = []
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    for items in tqdm.tqdm(batches):
        sources = [item["text"] for item in items]
        # NOTE(review): presumably predict returns one summary per source,
        # in input order; zip relies on that pairing. Confirm against model.
        summaries = model.predict(sources, tokenizer, args.device)
        outputs.extend(
            {"id": item["id"], "pred-summary": summary}
            for item, summary in zip(items, summaries)
        )
    elapsed = time.perf_counter() - start
    print("Seconds:", elapsed)

    # Without --output the predictions are only timed, not persisted.
    if args.output:
        utils.write_jsonl(outputs, args.output, "w")
|
|
|
|
|
def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Options:
        --dataset (required): dataset path.
        --output (optional): output path; ``None`` when omitted.
        --checkpoint (required): checkpoint path.
        --device: defaults to ``"cpu"``.
        --batch-size: integer, defaults to ``4`` (exposed as ``batch_size``).

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    option_specs = (
        ('--dataset', {"required": True}),
        ('--output', {"required": False}),
        ('--checkpoint', {"required": True}),
        ('--device', {"default": "cpu"}),
        ('--batch-size', {"type": int, "default": 4}),
    )
    cli = argparse.ArgumentParser()
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
|
|
|
|
# Script entry point: parse the command line, then run the prediction loop.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)
|
|