"""Summarization model loaders and predictors (T5, Pegasus, BART) for a Streamlit app.

Each pipeline is cached in st.session_state so a model is loaded only once per session.
"""
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    PegasusForConditionalGeneration,
    T5ForConditionalGeneration,
    pipeline,
)

import streamlit as st

# T5
def get_tidy_tab_t5():
    # Cache the pipeline in the session state so the model is loaded only once.
    if 'tidy_tab_t5' not in st.session_state:
        st.session_state.tidy_tab_t5 = load_model_t5()
    return st.session_state.tidy_tab_t5

def load_model_t5():
    model_name = "wgcv/tidy-tab-model-t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return pipeline('summarization', model=model, tokenizer=tokenizer)


def predict_model_t5(text):
    tidy_tab_t5 = get_tidy_tab_t5()
    if not tidy_tab_t5:
        return None
    # T5 summarization checkpoints expect the "summarize:" task prefix.
    text = "summarize: " + text
    result = tidy_tab_t5(text, max_length=8, min_length=1)
    if result:
        return result[0]['summary_text']
    return None


# pegasus-xsum
def get_tidy_tab_pegasus():
    # Cache the pipeline in the session state so the model is loaded only once.
    if 'tidy_tab_pegasus' not in st.session_state:
        st.session_state.tidy_tab_pegasus = load_model_pegasus()
    return st.session_state.tidy_tab_pegasus

def load_model_pegasus():
    model_name = "wgcv/tidy-tab-model-pegasus-xsum"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name)
    return pipeline('summarization', model=model, tokenizer=tokenizer)


def predict_model_pegasus(text):
    tidy_tab_pegasus = get_tidy_tab_pegasus()
    if not tidy_tab_pegasus:
        return None
    # Pegasus needs no task prefix; the text is passed through unchanged.
    result = tidy_tab_pegasus(text, max_length=8, min_length=1)
    if result:
        return result[0]['summary_text']
    return None

# Bart-Large
def get_tidy_tab_bart():
    # Cache the pipeline in the session state so the model is loaded only once.
    if 'tidy_tab_bart' not in st.session_state:
        st.session_state.tidy_tab_bart = load_model_bart()
    return st.session_state.tidy_tab_bart

def load_model_bart():
    model_name = "wgcv/tidy-tab-model-bart-large-cnn"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    return pipeline('summarization', model=model, tokenizer=tokenizer)


def predict_model_bart(text):
    tidy_tab_bart = get_tidy_tab_bart()
    if not tidy_tab_bart:
        return None
    # Beam search (num_beams=4) combined with do_sample=True uses beam-sample
    # generation, which keeps the short summaries varied across calls.
    result = tidy_tab_bart(text, num_beams=4, max_length=12, min_length=1, do_sample=True)
    if result:
        return result[0]['summary_text']
    return None
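

# A minimal usage sketch, not part of the original module: how a Streamlit page
# might call the three predictors side by side. The sample text is illustrative.
# Under `streamlit run`, the script executes with __name__ == "__main__", so
# st.session_state is available and the guarded block below runs.
if __name__ == "__main__":
    sample = (
        "Streamlit turns data scripts into shareable web apps in minutes, "
        "all in pure Python, with no front-end experience required."
    )
    st.title("Tidy Tab summarizers")
    st.write("T5:", predict_model_t5(sample))
    st.write("Pegasus:", predict_model_pegasus(sample))
    st.write("BART:", predict_model_bart(sample))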