Spaces:
Build error
Build error
File size: 4,935 Bytes
7f7285f 460c37e 7f7285f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/8/12 14:44
@Contact : [email protected]
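@Description: Seq2seq answer generation for cloze questions, plus a helper that fills the generated answers back into the masked questions.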
'''
import re
import time
from pathlib import Path
from typing import Dict, List
import torch
from logging import getLogger
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import ujson as json
import random
try:
from .seq2seq.seq2seq_utils import (
use_task_specific_params,
calculate_rouge,
chunks,
Seq2SeqDataset,
lmap,
load_json,
save_json,
)
except ImportError:
import cjjpy as cjj
import sys
sys.path.append(cjj.AbsParentDir(__file__, '.'))
from seq2seq.seq2seq_utils import (
use_task_specific_params,
calculate_rouge,
chunks,
Seq2SeqDataset,
lmap,
load_json,
save_json,
)
# Module-level logger, named after this module.
logger = getLogger(__name__)
# Default device for AnswerGenerator: the GPU when torch reports CUDA as available, else the CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Seed the global `random` state at import time; assemble_answers_to_one draws from it
# to decide whether an example keeps its answers or stays masked.
random.seed(1111)
def assemble_answers_to_one(js, k=5, mask_token='<mask>', mask_rate=0.):
    '''
    Fill every cloze question with its top-k answers and store the joined sentences.

    For each question in js['cloze_qs'] that contains mask_token, the first
    occurrence of the token is replaced in turn by each of the first k entries of
    the matching js['evidential'] list. The k resulting sentences are joined with a
    single space and appended to js['evidential_assembled']. Questions without the
    mask token contribute nothing. Any 'evidential_assembled' value already present
    in js is discarded and rebuilt.

    :param js: example dict with 'cloze_qs', 'evidential' and 'answers' keys, or a
        JSON string encoding such a dict. A dict is modified in place.
    :param k: number of answers substituted per question.
    :param mask_token: placeholder to replace inside each cloze question.
    :param mask_rate: a single random.random() draw is made per call; when the draw
        is not greater than mask_rate, the answers are ignored and mask_token is
        substituted instead (i.e. the question is repeated k times unchanged).
    :return: the example dict, with 'evidential_assembled' set.
    :raises IndexError: if answers are kept and a question has fewer than k answers.
    :raises AssertionError: if the number of assembled entries differs from
        len(js['answers']).
    '''
    if isinstance(js, str):
        js = json.loads(js)

    # Exactly one draw per example, so the keep/mask decision is shared by all of
    # its questions.
    should_keep = random.random() > mask_rate

    # Drop any stale value and always start from a fresh list, so the key exists
    # even when no question yields a sentence (previously a KeyError at the check).
    js.pop('evidential_assembled', None)
    assembled = js['evidential_assembled'] = []

    for question, answers in zip(js['cloze_qs'], js['evidential']):
        start = question.find(mask_token)
        if start < 0:
            continue
        head = question[:start]
        tail = question[start + len(mask_token):]
        if should_keep:
            # Indexing (not slicing) keeps the IndexError on fewer than k answers.
            spans = [answers[i] for i in range(k)]
        else:
            spans = [mask_token] * k
        assembled.append(' '.join(head + span + tail for span in spans))

    # Raised explicitly rather than via `assert` so the check survives `python -O`.
    if len(assembled) != len(js['answers']):
        raise AssertionError(
            f"assembled {len(assembled)} evidential entries "
            f"for {len(js['answers'])} answers"
        )
    return js
class AnswerGenerator():
    '''
    Answer generator backed by a pretrained seq2seq checkpoint.

    The tokenizer is loaded in the constructor; the model weights are loaded lazily
    by init_model(), which generate() calls on entry.
    '''

    def __init__(self, model_name, device=DEFAULT_DEVICE):
        '''
        :param model_name: checkpoint name or path; converted to str and passed to
            from_pretrained.
        :param device: device the model and the tokenized batches are moved to.
        '''
        self.model_name = str(model_name)
        self.device = device
        self.model = None  # populated on first init_model() call
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def init_model(self):
        '''Load the seq2seq model onto self.device unless it is already loaded.'''
        if self.model is None:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)

    def assemble(self, question, context):
        '''
        Join a question and its context into a single model input string.

        :return: '{question} {sep} {context}', where sep is a newline when the
            tokenizer's name_or_path contains 'unifiedqa', and the tokenizer's
            sep_token otherwise.
        '''
        # NOTE(review): a tokenizer without a sep_token would render the literal
        # "None" here — confirm every non-unifiedqa checkpoint in use defines one.
        sep = '\n' if 'unifiedqa' in self.tokenizer.name_or_path else self.tokenizer.sep_token
        return f'{question} {sep} {context}'

    def generate(self, examples, out_file=None, batch_size=16, verbose=True,
                 max_length=20, min_length=1, num_beams=4, num_return_sequences=4,
                 prefix=None, fp16=False, task='summarization', **generate_kwargs):
        '''
        Generate num_return_sequences outputs for every input text.

        :param examples: [N] input strings.
        :param out_file: optional path; when given, every stripped hypothesis is
            also written to it, one per line.
        :param batch_size: number of examples per forward pass.
        :param verbose: show a tqdm progress bar over the batches.
        :param max_length: forwarded to model.generate.
        :param min_length: forwarded to model.generate.
        :param num_beams: forwarded to model.generate.
        :param num_return_sequences: forwarded to model.generate; also the group
            size used to chunk the flat list of hypotheses.
        :param prefix: string prepended to every example; when None, falls back to
            the model config's `prefix` attribute, or '' if that is missing/empty.
        :param fp16: convert the model to half precision before generating.
        :param task: task name handed to use_task_specific_params.
        :param generate_kwargs: extra keyword arguments for model.generate.
            length_penalty and repetition_penalty default to 1.2 and may be
            overridden here.
        :return: [N x num_return_seq] stripped hypotheses, whether or not out_file
            is given.
        '''
        self.init_model()
        if fp16:
            self.model = self.model.half()
        # update config with task specific params
        # (helper is imported from seq2seq_utils; presumably it copies the task's
        # defaults from the model config — verify there)
        use_task_specific_params(self.model, task)

        if prefix is None:
            prefix = getattr(self.model.config, "prefix", "") or ""

        # Defaults rather than hard-coded keywords, so a caller passing either one
        # overrides it instead of hitting a duplicate-keyword TypeError.
        generate_kwargs.setdefault("length_penalty", 1.2)
        generate_kwargs.setdefault("repetition_penalty", 1.2)

        # Materialised so tqdm knows the total number of batches.
        batches = list(chunks(examples, batch_size))
        if verbose:
            batches = tqdm(batches, desc="MRC")

        generated = []
        fout = None if out_file is None else Path(out_file).open("w", encoding="utf-8")
        try:
            for examples_chunk in batches:
                inputs = [prefix + text for text in examples_chunk]
                batch = self.tokenizer(inputs, return_tensors="pt", truncation=True,
                                       padding="longest").to(self.device)
                summaries = self.model.generate(
                    input_ids=batch.input_ids,
                    attention_mask=batch.attention_mask,
                    max_length=max_length,
                    min_length=min_length,
                    num_beams=num_beams,
                    num_return_sequences=num_return_sequences,
                    **generate_kwargs,
                )
                decoded = self.tokenizer.batch_decode(summaries, skip_special_tokens=True,
                                                      clean_up_tokenization_spaces=False)
                hypotheses = [text.strip() for text in decoded]
                generated.extend(hypotheses)
                if fout is not None:
                    fout.writelines(hypothesis + "\n" for hypothesis in hypotheses)
                    # One flush per batch keeps the file current without a syscall per line.
                    fout.flush()
        finally:
            # Close the file even when tokenization or generation raises.
            if fout is not None:
                fout.close()

        return list(chunks(generated, num_return_sequences))
|