File size: 3,592 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-

'''
@Author     : Jiangjie Chen
@Time       : 2020/7/29 21:50
@Contact    : [email protected]
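@Description: Build (regular question, cloze question, answer) triples from text and answer spans using a T5 question-generation model (valhalla/t5-base-qg-hl).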
'''

import random
import cjjpy as cjj
import sys, os
import torch
from tqdm import tqdm

try:
    from .t5_qg.generator import Generator
except:
    sys.path.append(cjj.AbsParentDir(__file__, '.'))
    from t5_qg.generator import Generator


def chunks(lst, n):
    '''
    Lazily split a sliceable sequence into consecutive pieces of length n.

    :param lst: any sequence supporting len() and slicing (list, tuple, str, ...)
    :param n: piece length; the last piece may be shorter
    :return: generator yielding slices of lst in order
    '''
    offsets = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in offsets)


class QuestionGenerator:
    '''
    Produce (regular question, cloze question, answer) triples for answer spans in text.

    Regular questions come from a T5 highlight-style question-generation model
    ('valhalla/t5-base-qg-hl') wrapped by `Generator`. Cloze questions are built
    locally by replacing the answer span in the source text with a mask token.
    '''

    def __init__(self, model, prefix=None, verbose=True):
        '''
        :param model: model family; only 't5' is supported.
        :param prefix: optional root directory. When given,
            '{prefix}/models/question_generation/t5-base-qg-hl/' is passed to Generator
            (presumably a local checkpoint directory — confirm); otherwise None is passed.
        :param verbose: show tqdm progress bars; also forwarded to Generator.
        '''
        assert model in ['t5']
        self.verbose = verbose
        prefix = f'{prefix}/models/question_generation/t5-base-qg-hl/' if prefix else None
        self.qg = Generator('valhalla/t5-base-qg-hl', prefix,
                            device='cuda' if torch.cuda.is_available() else 'cpu',
                            verbose=self.verbose)

    @staticmethod
    def _is_single_span(options):
        '''Return True when `options` is one bare answer rather than a collection of options.'''
        if isinstance(options, str):
            return True
        # A bare ('answer', start, end) span; a tuple of three spans or three strings is not.
        return (isinstance(options, tuple) and len(options) == 3
                and isinstance(options[0], str)
                and isinstance(options[1], int) and isinstance(options[2], int))

    def _clean_input_lines(self, input_lines):
        '''
        Normalize each (text, options) line so that options is a collection of options.

        A bare span such as ('born', 6, 10), or a bare answer string, is wrapped in a
        one-element list. Each line is checked individually; lines whose options are
        already a collection are returned unchanged.
        '''
        return [
            (line[0], [line[1]]) if self._is_single_span(line[1]) else line
            for line in input_lines
        ]

    @staticmethod
    def _mask_answer_in_question(question, answer, mask_token):
        '''
        Replace every occurrence of the answer text in `question` with `mask_token`.

        :param question: generated regular question
        :param answer: either a ('text', start, end) span or a plain answer string
        :param mask_token: replacement string
        :return: the masked question, or `question` unchanged if the answer text is empty
        '''
        # Plain-string answers must not be indexed: answer[0] would be their first character.
        answer_text = answer if isinstance(answer, str) else answer[0]
        if not answer_text:
            # str.replace('', x) would insert x between every character.
            return question
        return question.replace(answer_text, mask_token)

    def generate(self, input_lines: list, sample_num=1, batch_size=128, mask_token='<mask>'):
        '''
        Generate a regular and a cloze question for sampled answer options of each text.

        :param input_lines: List([text, options=[('answer', 0, 1), (x, y, z), ...]]);
            options may also be a single span or a list of plain answer strings.
        :param sample_num: options sampled per text; default 1, as usually only one
            option is provided.
        :param batch_size: forwarded to the underlying Generator call.
        :param mask_token: token that replaces the answer in both questions.
        :return: List((regular_q, cloze_q, a)), where `a` is the 'answer' field returned
            by Generator (presumably the option that was passed in — confirm).
        '''
        qa_pairs = []
        if len(input_lines) == 0:
            return qa_pairs
        input_lines = self._clean_input_lines(input_lines)

        ques_chunk = []
        for text, options in input_lines:
            masked_qa = self.mask_text(text, options, sample_num=sample_num, mask_token=mask_token)
            for cloze_q, span in masked_qa:
                ques_chunk.append({'context': text, 'answer': span, 'cloze_q': cloze_q})

        ques_pairs = self.qg(ques_chunk, batch_size=batch_size)
        pairs = zip(ques_pairs, ques_chunk)
        if self.verbose:
            pairs = tqdm(pairs, desc='Replacing')
        for qa, mq in pairs:
            answer = qa['answer']
            # Only the first generated question is used for each option.
            regular_q = self._mask_answer_in_question(qa['questions'][0], answer, mask_token)
            qa_pairs.append((regular_q, mq['cloze_q'], answer))

        return qa_pairs

    def _sample(self, options, sample_num=1):
        '''Return all options if there are at most `sample_num`, else a random sample of that size.'''
        if len(options) <= sample_num:
            return options
        else:
            return random.sample(options, sample_num)

    def mask_text(self, text: str, options, sample_num=1, mask_token='<mask>'):
        '''
        Build cloze versions of `text` by masking sampled answer options.

        :param text: source text
        :param options: [('xx', 1, 2), ...] character-offset spans, or plain answer strings
        :param sample_num: number of options to sample (all of them if there are fewer)
        :param mask_token: replacement for the masked answer
        :return: [(masked_text, option), ...], one entry per sampled option
        :raises ValueError: if an option is neither a string nor a 3-element span
        '''
        masked = []
        for span in self._sample(options, sample_num):
            if isinstance(span, str):
                # Plain answer string: every occurrence in the text is masked.
                ntext = text.replace(span, mask_token)
            elif len(span) == 3:
                # Offsets must point exactly at the answer text.
                assert text[span[1]:span[2]] == span[0], (text[span[1]:span[2]], span[0])
                ntext = text[:span[1]] + mask_token + text[span[2]:]
            else:
                raise ValueError(span)
            masked.append((ntext, span))
        return masked

    def assemble_question(self, regular_q, cloze_q):
        '''Combine the two question forms into a single "<regular> or <cloze>" string.'''
        return f'{regular_q} or {cloze_q}'


if __name__ == '__main__':
    # Smoke test: generate questions for one sentence with two candidate answer spans.
    demo_lines = [['I was born yesterday.', [('born', 6, 10), ('yesterday', 11, 20)]]]
    generator = QuestionGenerator('t5')
    results = generator.generate(demo_lines, sample_num=1)
    print(results)