Spaces:
Sleeping
Sleeping
File size: 2,638 Bytes
b7a1a13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
##
from transformers import AutoTokenizer, pipeline
from transformers import T5ForConditionalGeneration
from transformers import PegasusForConditionalGeneration
from transformers import BartForConditionalGeneration
import streamlit as st
# T5
def get_tidy_tab_t5():
    """Return the session-cached T5 summarization pipeline, loading it on first access.

    Streamlit reruns the script on every interaction; caching the pipeline in
    ``st.session_state`` ensures the model is loaded only once per session.
    """
    cached = st.session_state.get('tidy_tab_t5')
    if cached is None:
        cached = load_model_t5()
        st.session_state['tidy_tab_t5'] = cached
    return cached
def load_model_t5():
    """Build a summarization pipeline around the fine-tuned t5-small checkpoint."""
    checkpoint = "wgcv/tidy-tab-model-t5-small"
    return pipeline(
        'summarization',
        model=T5ForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_t5(text):
    """Summarize ``text`` with the fine-tuned T5 model.

    Args:
        text: The raw input text to condense.

    Returns:
        The generated summary string, or ``None`` when the pipeline is
        unavailable or produced no output.
    """
    tidy_tab_t5 = get_tidy_tab_t5()
    if not tidy_tab_t5:
        return None
    # T5 checkpoints are trained with an explicit task prefix on the input.
    result = tidy_tab_t5("summarize: " + text, max_length=8, min_length=1)
    if not result:
        return None
    return result[0]['summary_text']
# pegasus-xsum
def get_tidy_tab_pegasus():
    """Return the session-cached Pegasus summarization pipeline, loading it on first access.

    The pipeline is stored in ``st.session_state`` so Streamlit's
    rerun-on-interaction model does not reload the weights each time.
    """
    cached = st.session_state.get('tidy_tab_pegasus')
    if cached is None:
        cached = load_model_pegasus()
        st.session_state['tidy_tab_pegasus'] = cached
    return cached
def load_model_pegasus():
    """Build a summarization pipeline around the fine-tuned pegasus-xsum checkpoint."""
    checkpoint = "wgcv/tidy-tab-model-pegasus-xsum"
    return pipeline(
        'summarization',
        model=PegasusForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_pegasus(text):
    """Summarize ``text`` with the fine-tuned Pegasus model.

    Unlike the T5 variant, Pegasus needs no task prefix; the input is passed
    to the pipeline unchanged. (The original code carried a no-op
    ``text = text`` assignment here, removed.)

    Args:
        text: The raw input text to condense.

    Returns:
        The generated summary string, or ``None`` when the pipeline is
        unavailable or produced no output.
    """
    tidy_tab_pegasus = get_tidy_tab_pegasus()
    if not tidy_tab_pegasus:
        return None
    result = tidy_tab_pegasus(text, max_length=8, min_length=1)
    if not result:
        return None
    return result[0]['summary_text']
# Bart-Large
def get_tidy_tab_bart():
    """Return the session-cached BART summarization pipeline, loading it on first access.

    Caching in ``st.session_state`` keeps the model alive across Streamlit
    reruns so the weights are loaded only once per session.
    """
    cached = st.session_state.get('tidy_tab_bart')
    if cached is None:
        cached = load_model_bart()
        st.session_state['tidy_tab_bart'] = cached
    return cached
def load_model_bart():
    """Build a summarization pipeline around the fine-tuned bart-large-cnn checkpoint."""
    checkpoint = "wgcv/tidy-tab-model-bart-large-cnn"
    return pipeline(
        'summarization',
        model=BartForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_bart(text):
    """Summarize ``text`` with the fine-tuned BART model.

    Args:
        text: The raw input text to condense.

    Returns:
        The generated summary string, or ``None`` when the pipeline is
        unavailable or produced no output.
    """
    tidy_tab_bart = get_tidy_tab_bart()
    if not tidy_tab_bart:
        return None
    # NOTE(review): do_sample=True makes generation non-deterministic even
    # with num_beams=4 — presumably intentional for output variety, but
    # confirm; drop do_sample if reproducible titles are wanted.
    result = tidy_tab_bart(text, num_beams=4, max_length=12, min_length=1, do_sample=True)
    if not result:
        return None
    return result[0]['summary_text']