File size: 2,938 Bytes
9b6c439
4d21bee
 
 
dfcc660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d21bee
411567e
 
de45564
411567e
4d21bee
411567e
4d21bee
 
411567e
4d21bee
 
99d8161
 
f321fd2
 
2280fc9
 
dfcc660
2280fc9
dfcc660
2280fc9
 
f321fd2
 
 
 
f6b2292
99d8161
f321fd2
4d21bee
 
76f1e3b
 
dfcc660
 
 
76f1e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from utils import *
import gradio as gr


from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel

def download_model():
    """Pre-download and cache every Hugging Face checkpoint the app relies on.

    Each checkpoint's tokenizer and weights are loaded once so they end up in
    the local Hugging Face cache before the UI starts serving requests. The
    loaded objects are discarded; only the cache side effect matters.

    Returns:
        True once every checkpoint has been fetched.
    """
    # Use the Auto* factories: the abstract base classes PreTrainedModel /
    # PreTrainedTokenizer cannot build a concrete architecture from a
    # checkpoint on their own, so calling from_pretrained on them fails.
    checkpoints = (
        # Sentence-embedding model; presumably consumed by the sentence-feature
        # extraction in utils -- TODO confirm.
        ("sentence-transformers/all-MiniLM-L6-v2", AutoModel),
        # Abstractive model selected by main() for the "biomedical field" option.
        ("fuhsiao/BioBART-PMC-EXT-Section", AutoModel),
        # Abstractive model selected by main() for the "non-specialized field" option.
        ("fuhsiao/BART-PMC-EXT-Section", AutoModelForSeq2SeqLM),
    )
    for repo_id, model_cls in checkpoints:
        AutoTokenizer.from_pretrained(repo_id)
        model_cls.from_pretrained(repo_id)

    return True





def main(file, ext_threshold, article_type):
    """Summarize an uploaded paper into a structured (IMRD) abstract.

    Pipeline: parse the uploaded text file, pick salient sentences with the
    extractive model stored at ``model/LGB_model_F10_S.pkl``, then rewrite them
    section by section with the abstractive checkpoint chosen by
    ``article_type``.

    Args:
        file: The uploaded file. Either an object exposing ``.name`` (the path
            of the uploaded file) or the path itself as a string.
        ext_threshold: Forwarded as ``threshold=`` to ``extractive_method``.
        article_type: ``'non-specialized field'`` or ``'biomedical field'``;
            selects the abstractive model.

    Returns:
        The summary text: for each of Introduction, Methods, Results and
        Discussion/Conclusion, the section title, a newline, the generated text
        and a blank line. On bad input an error message is returned instead:
        ``'Please confirm that the file and settings are correct.'`` when an
        argument is missing or the field is unknown, and ``"invalid_format"``
        when the paper fails ``is_valid_format``.
    """
    settings_error = 'Please confirm that the file and settings are correct.'
    if file is None or ext_threshold is None or article_type is None:
        return settings_error

    # Resolve the abstractive checkpoint before any expensive work, so an
    # unrecognised field is reported instead of loading a model from ''.
    abstr_model_paths = {
        'non-specialized field': 'fuhsiao/BART-PMC-EXT-Section',
        'biomedical field': 'fuhsiao/BioBART-PMC-EXT-Section',
    }
    abstr_model_path = abstr_model_paths.get(article_type)
    if abstr_model_path is None:
        return settings_error

    # Older Gradio File components hand over a tempfile wrapper (path in
    # ``.name``); newer ones hand over the path string directly.
    file_path = getattr(file, 'name', file)
    paper = read_text_to_json(file_path)

    if not is_valid_format(paper):
        return "invalid_format"

    sent_json = convert_to_sentence_json(paper)
    sent_feat = extract_sentence_features(sent_json)

    ext_model = load_ExtModel('model/LGB_model_F10_S.pkl')
    ext = extractive_method(sent_json, sent_feat, ext_model, threshold=ext_threshold, TGB=False)

    tokenizer, abstr_model = load_AbstrModel(abstr_model_path)
    abstr = abstractive_method(ext, tokenizer=tokenizer, model=abstr_model)

    # (key in the abstractive output, human-readable section title)
    sections = (
        ('I', 'Introduction'),
        ('M', 'Methods'),
        ('R', 'Results'),
        ('D', 'Discussion/Conclusion'),
    )
    return ''.join(f"{title}\n{abstr[key]}\n\n" for key, title in sections)
    

if __name__ == '__main__':

    # Warm the Hugging Face cache before the UI starts serving requests.
    download_model()

    # NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
    # namespaces (deprecated in 3.x, removed in 4.x). Keep gradio pinned to a
    # version that still ships them, or migrate to gr.File / gr.Slider(value=...)
    # / gr.Dropdown(value=...) / gr.Textbox.
    paper_upload = gr.inputs.File()
    threshold_slider = gr.inputs.Slider(
        minimum=0.5,
        maximum=1,
        default=0.5,
        step=0.01,
        label="Extractive - Threshold",
    )
    field_dropdown = gr.inputs.Dropdown(
        ["non-specialized field", "biomedical field"],
        default="non-specialized field",
        label="Abstractive - Field",
    )
    summary_box = gr.outputs.Textbox(label="Output - Structured Abstract")

    # Wire the UI: uploaded paper + two settings -> one structured-abstract text.
    iface = gr.Interface(
        fn=main,
        inputs=[paper_upload, threshold_slider, field_dropdown],
        outputs=summary_box,
        title="Ext-Abs-StructuredSum",
        description="please upload a .txt file formatted in the form of the example.",
        allow_flagging='never',
    )

    # Start the server; share=False disables Gradio's public share link.
    iface.launch(share=False)