File size: 3,785 Bytes
99d8161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffe67a5
 
 
ea748ba
 
ffe67a5
 
 
ea748ba
 
 
ffe67a5
99d8161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb10343
 
 
 
 
 
 
 
 
 
 
99d8161
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import nltk
import pickle
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM



class TrigramBlock:
    """Trigram blocking to reduce redundancy among selected sentences.

    Sentences are offered one at a time via ``check_overlap`` (callers
    typically offer them in descending score order). A sentence is reported
    as redundant when it shares at least one word trigram with a sentence
    that was previously accepted, i.e. one for which ``check_overlap``
    returned False.
    """

    def __init__(self, tokenizer=None):
        """Create an empty blocker.

        Args:
            tokenizer: optional callable mapping a string to a list of
                tokens. When None, ``nltk.word_tokenize`` is used.
        """
        # Trigrams of every accepted (non-blocked) sentence seen so far.
        self.trigrams = set()
        self._tokenizer = tokenizer

    def check_overlap(self, text):
        """Return True if *text* repeats a trigram of an accepted sentence.

        A sentence without overlap is accepted and its trigrams are recorded.
        A blocked sentence contributes nothing: it will not be part of the
        summary, so it must not block later sentences.
        """
        trigrams = set(self._get_trigrams(self._preprocess(text)))
        if not self.trigrams.isdisjoint(trigrams):
            return True
        self.trigrams |= trigrams
        return False

    def _preprocess(self, text):
        """Lower-case *text*, keep only letters and whitespace, and tokenize.

        Other characters are dropped without inserting a space, so e.g.
        "well-known" becomes "wellknown" and digits disappear entirely.
        """
        cleaned = ''.join(c for c in text.lower() if c.isalpha() or c.isspace())
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = nltk.word_tokenize
        return tokenizer(cleaned)

    def _get_trigrams(self, tokens):
        """Return each run of three consecutive tokens, space-joined, in order."""
        return [' '.join(tokens[i:i + 3]) for i in range(len(tokens) - 2)]



def convert_sentence_df(sentJson, pred, true_proba, set_trigram_blocking):
    """Build a per-sentence DataFrame with predictions, optionally trigram-blocked.

    Args:
        sentJson: parsed document; ``sentJson['body'][s]`` is a list of dicts
            with a ``'text'`` key for each section ``s`` in 'I', 'M', 'R', 'D'.
        pred: per-sentence predictions in the same row order as the
            sentences (all of I, then M, R, D); truthy marks a selection.
        true_proba: per-sentence positive-class probabilities, same order.
        set_trigram_blocking: if True, within each section deselect every
            selected sentence that shares a trigram with a higher-ranked
            selected sentence of that section.

    Returns:
        DataFrame with columns ``section`` (category), ``text`` (string,
        whitespace-stripped), ``predict`` (bool) and ``proba`` (float16).
    """
    rows = [
        (section, sent['text'].strip())
        for section in 'IMRD'
        for sent in sentJson['body'][section]
    ]
    body = pd.DataFrame(rows, columns=['section', 'text']).astype(
        {'section': 'category', 'text': 'string'}
    )
    # Attach the predictions and probabilities.
    body['predict'] = pred.astype('bool')
    body['proba'] = true_proba.astype('float16')

    if set_trigram_blocking:
        # Rank on full-precision scores: the stored float16 column can round
        # distinct probabilities into ties. The stable sort resolves any
        # remaining tie in favour of the earlier sentence.
        rank_proba = pd.Series(true_proba.astype('float64'), index=body.index)
        for section in 'IMRD':
            block = TrigramBlock()
            in_section = (body['section'] == section) & body['predict']
            selected = body.index[in_section.to_numpy()]
            ranked = rank_proba.loc[selected].sort_values(
                ascending=False, kind='mergesort'
            )
            for i in ranked.index:
                if block.check_overlap(body.at[i, 'text']):
                    body.at[i, 'predict'] = False
    return body

# 提取式方法
def extractive_method(sentJson, sentFeat, model, threshold=0.5, TGB=False):
    """Select summary sentences per section with a probabilistic classifier.

    Args:
        sentJson: parsed document; ``sentJson['body'][s]`` is the sentence
            list of section ``s`` for each of 'I', 'M', 'R', 'D'.
        sentFeat: feature DataFrame, one row per sentence, grouped by its
            ``'section'`` key. Rows are assumed to follow the same sentence
            order as ``sentJson`` (I, M, R, D) — TODO confirm with caller.
        model: object exposing ``predict_proba``; column 1 is used as the
            probability that a sentence belongs in the summary.
        threshold: probability a sentence must exceed to be selected.
        TGB: apply per-section trigram blocking when True.

    Returns:
        dict mapping each of 'I', 'M', 'R', 'D' to that section's selected
        sentences joined by single spaces ('' when none were selected).
    """
    def predict(x):
        """Return (0/1 predictions, positive-class probabilities) for rows of x."""
        true_proba = model.predict_proba(x)[:, 1]
        pred = true_proba > threshold
        # If no sentence clears the threshold, force the most probable
        # sentence(s) in. Their probability is raised to 1 so they also rank
        # first during trigram blocking.
        if not pred.any():
            top = true_proba == np.max(true_proba)
            true_proba[top] = 1
            pred[top] = True
        return pred.astype('int'), true_proba

    n_rows = len(sentFeat)
    pred = np.zeros(n_rows)
    true_proba = np.zeros(n_rows)
    # Write each section's results back to the positions of its rows, so the
    # output keeps sentFeat's row order whatever order groupby visits groups
    # in (sorted keys would put 'D' before 'I').
    for positions in sentFeat.groupby('section').indices.values():
        if len(positions) == 0:
            continue
        pred_sec, true_proba_sec = predict(sentFeat.iloc[positions])
        pred[positions] = pred_sec
        true_proba[positions] = true_proba_sec

    body = convert_sentence_df(sentJson, pred, true_proba, TGB)
    selected = body[body['predict']]
    return {
        section: ' '.join(selected.loc[selected['section'] == section, 'text'])
        for section in 'IMRD'
    }

def abstractive_method(ext, tokenizer, model, device='cpu'):
    """Rewrite each section's extractive summary with a seq2seq model.

    Args:
        ext: dict mapping 'I', 'M', 'R', 'D' to that section's extracted text.
        tokenizer: tokenizer called as
            ``tokenizer(text, truncation=True, return_tensors='pt')`` and
            used to decode the generated ids.
        model: model exposing ``generate``.
        device: device the encoded input ids are moved to (should match the
            model's device).

    Returns:
        dict mapping each section letter to its generated summary. Sections
        whose extracted text is empty or whitespace-only map to ''.
    """
    abstr = {}
    for section in 'IMRD':
        text = ext[section]
        # Nothing to summarize: skip generation instead of letting the model
        # produce text from an empty prompt.
        if not text.strip():
            abstr[section] = ''
            continue
        input_ids = tokenizer(text, truncation=True, return_tensors='pt').input_ids
        outputs = model.generate(input_ids.to(device))
        abstr[section] = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return abstr

# extractive summarizer
def load_ExtModel(path):
    """Load and return the pickled extractive summarizer stored at *path*.

    Security: ``pickle.load`` can execute arbitrary code while unpickling;
    only load model files from a trusted source.
    """
    # Context manager closes the file handle even if unpickling fails.
    with open(path, 'rb') as f:
        return pickle.load(f)

# abstractive summarizer
def load_AbstrModel(path, device='cpu'):
    """Load a seq2seq tokenizer/model pair and set its decoding defaults.

    Args:
        path: checkpoint name or directory accepted by ``from_pretrained``.
        device: device the model is moved to.

    Returns:
        ``(tokenizer, model)``; the tokenizer is created with
        ``model_max_length=1024`` and the model carries the beam-search
        settings below as its generation defaults.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, model_max_length=1024)
    abstrModel = AutoModelForSeq2SeqLM.from_pretrained(path)
    abstrModel = abstrModel.to(device)

    generation_config = {
        'num_beams': 5,
        'max_length': 512,
        'min_length': 64,
        'length_penalty': 2.0,
        'early_stopping': True,
        'no_repeat_ngram_size': 3
    }

    # Older transformers releases read decoding settings from model.config.
    abstrModel.config.update(generation_config)
    # Newer releases read them from model.generation_config and may ignore
    # edits made to model.config after loading, so set them there as well.
    gen_cfg = getattr(abstrModel, 'generation_config', None)
    if gen_cfg is not None:
        gen_cfg.update(**generation_config)
    return tokenizer, abstrModel