fuhsiao418 committed
Commit 99d8161 · 1 Parent(s): 9b6c439
app.py CHANGED
@@ -8,7 +8,11 @@ def main(file, ext_threshold, article_type):
         return "invalid_format"
     sentJson = convert_to_sentence_json(paper)
     sentFeat = extract_sentence_features(sentJson)
-    return 'done'
+
+    ExtModel = load_ExtModel('model/LGB_model_F10_S.pkl')
+    ext = extractive_method(sentJson, sentFeat, ExtModel, TGB=False)
+
+    return ext
 
 
 
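For orientation, this hunk replaces the placeholder return 'done' with the extractive pipeline. A minimal sketch of the resulting main, assuming the unshown opening lines parse and validate the upload via read_text_to_json and is_valid_format from utils.preprocess (that guard is inferred, not shown in the hunk; note that ext_threshold is not yet forwarded to extractive_method in this commit):

    from utils import (read_text_to_json, convert_to_sentence_json,
                       extract_sentence_features, is_valid_format, load_ExtModel)
    from utils.methods import extractive_method  # not re-exported in utils/__init__.py below

    def main(file, ext_threshold, article_type):
        paper = read_text_to_json(file)               # assumed opening; the hunk starts mid-function
        if not is_valid_format(paper, article_type):  # inferred guard before "invalid_format"
            return "invalid_format"
        sentJson = convert_to_sentence_json(paper)      # sentence-level JSON of the paper body
        sentFeat = extract_sentence_features(sentJson)  # feature matrix for the classifier
        ExtModel = load_ExtModel('model/LGB_model_F10_S.pkl')
        # TGB=False skips trigram blocking; ext maps each IMRD section to its extracted sentences
        ext = extractive_method(sentJson, sentFeat, ExtModel, TGB=False)
        return ext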
model/LGB_model_F10_S.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c0f2b490f03417f065af6f3419b32c30f73af78f2aa9a846b1c55723d75fae3
+size 1837716
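This three-line file is a Git LFS pointer, not the model itself: the oid and size (1,837,716 bytes, about 1.8 MB) identify the binary held in LFS storage, presumably a pickled LightGBM classifier given the LGB_ prefix and the predict_proba call in utils/methods.py.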
requirements.txt CHANGED
@@ -1,7 +1,9 @@
 numpy==1.23.3
 pandas==1.5.3
+nltk==3.7
 torch==1.13.1
 scikit-learn==1.2.1
+transformers==4.27.2
 sentence-transformers==2.2.2
 spacy==3.4.4
 https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz
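Both new pins correspond to utils/methods.py below: nltk backs word_tokenize in TrigramBlock, and transformers provides AutoTokenizer and AutoModelForSeq2SeqLM for the abstractive summarizer.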
utils/__init__.py CHANGED
@@ -1,2 +1,2 @@
 from utils.preprocess import read_text_to_json, convert_to_sentence_json, extract_sentence_features, is_valid_format
-
+from utils.methods import load_ExtModel, load_AbstrModel
utils/methods.py ADDED
@@ -0,0 +1,87 @@
+import nltk
+import pickle
+import numpy as np
+import pandas as pd
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+
+
+class TrigramBlock:
+    def __init__(self):
+        self.trigrams = set()
+
+    def check_overlap(self, text):
+        tokens = self._preprocess(text)
+        trigrams = set(self._get_trigrams(tokens))
+        overlap = bool(self.trigrams & trigrams)
+        self.trigrams |= trigrams
+        return overlap
+
+    def _preprocess(self, text):
+        text = text.lower()
+        text = ''.join([c for c in text if c.isalpha() or c.isspace()])
+        tokens = nltk.word_tokenize(text)
+        return tokens
+
+    def _get_trigrams(self, tokens):
+        trigrams = [' '.join(tokens[i:i+3]) for i in range(len(tokens)-2)]
+        return trigrams
+
+
+
+def convert_sentence_df(sentJson, pred, true_proba, set_trigram_blocking):
+
+    body = pd.DataFrame([(section, sent['text'].strip()) for section in 'IMRD' for sent in sentJson['body'][section]],
+                        columns=['section', 'text']).astype({'section': 'category', 'text': 'string'})
+    # attach the predicted labels and probabilities
+    body['predict'] = pred.astype('bool')
+    body['proba'] = true_proba.astype('float16')
+    # apply trigram blocking to the extracted sentences of each section
+    if set_trigram_blocking:
+        for section in 'IMRD':
+            block = TrigramBlock()
+            temp = body.loc[(body['section'] == section) & (body['predict'] == True)].sort_values(by='proba', ascending=False)
+            for i, row in temp.iterrows():
+                if block.check_overlap(row['text']):
+                    body.at[i, 'predict'] = False
+    return body
+
+# extractive method
+def extractive_method(sentJson, sentFeat, model, threshold=0.5, TGB=False):
+    # predict
+    def predict(x):
+        true_proba = model.predict_proba(x)[:, 1]
+        # if no sentence's probability exceeds the threshold, take the highest-probability sentence as the summary sentence
+        if not np.any(true_proba > threshold):
+            true_proba[true_proba == np.max(true_proba)] = 1
+        pred = (true_proba > threshold).astype('int')
+        return pred, true_proba
+
+    pred, true_proba = predict(sentFeat)
+    body = convert_sentence_df(sentJson, pred, true_proba, TGB)
+    res = body[body['predict'] == True]
+    ext = {i: ' '.join(res.groupby('section').get_group(i)['text']) for i in 'IMRD'}
+    return ext
+
+def abstractive_method(ext, tokenizer, model, device='cpu'):
+    abstr = {key: '' for key in 'IMRD'}
+    for section in 'IMRD':
+        text = ext[section]
+        model_inputs = tokenizer(text, truncation=True, return_tensors='pt').input_ids
+        outputs = model.generate(model_inputs.to(device))
+        abstr_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        abstr[section] = abstr_text
+    return abstr
+
+# extractive summarizer
+def load_ExtModel(path):
+    return pickle.load(open(path, 'rb'))
+
+# abstractive summarizer
+def load_AbstrModel(path, device='cpu'):
+    model_checkpoint = path
+    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)
+    abstrModel = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
+    abstrModel = abstrModel.to(device)
+    return tokenizer, abstrModel
+
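Taken together, a minimal usage sketch of the new helpers. The inputs sentJson and sentFeat would come from utils.preprocess as in app.py above, and 'model/abstr_ckpt' is a placeholder path, since this commit ships no abstractive checkpoint:

    import nltk
    from utils import load_ExtModel, load_AbstrModel
    from utils.methods import extractive_method, abstractive_method

    nltk.download('punkt')  # TrigramBlock relies on nltk.word_tokenize, which needs the punkt data

    def summarize(sentJson, sentFeat):
        # sentJson / sentFeat as produced by convert_to_sentence_json / extract_sentence_features
        ext_model = load_ExtModel('model/LGB_model_F10_S.pkl')
        # TGB=True enables the trigram-blocking dedup pass in convert_sentence_df
        ext = extractive_method(sentJson, sentFeat, ext_model, threshold=0.5, TGB=True)
        # placeholder checkpoint path; any seq2seq model loadable by AutoModelForSeq2SeqLM would fit
        tokenizer, abstr_model = load_AbstrModel('model/abstr_ckpt')
        abstr = abstractive_method(ext, tokenizer, abstr_model)
        return ext, abstr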