wiraindrak's picture
Update app.py
34b3fd1
raw
history blame
3.37 kB
from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, BertTokenizer, EncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
from gradio.mix import Parallel
# Hugging Face Hub checkpoints used by this demo.
_T5_SUMMARY_CKPT = "panggi/t5-base-indonesian-summarization-cased"
_BERT2BERT_SUMMARY_CKPT = "cahya/bert2bert-indonesian-summarization"
_T5_PARAPHRASE_CKPT = "Wikidepia/IndoT5-base-paraphrase"

# T5-base abstractive summarizer (used by summ_t5).
tokenizer_t5 = T5Tokenizer.from_pretrained(_T5_SUMMARY_CKPT)
model_t5 = T5ForConditionalGeneration.from_pretrained(_T5_SUMMARY_CKPT)

# BERT2BERT encoder-decoder summarizer (used by summ_bert). BertTokenizer
# defines no BOS/EOS tokens of its own, so [CLS]/[SEP] are aliased into
# those roles.
tokenizer_bert = BertTokenizer.from_pretrained(_BERT2BERT_SUMMARY_CKPT)
tokenizer_bert.bos_token = tokenizer_bert.cls_token
tokenizer_bert.eos_token = tokenizer_bert.sep_token
model_bert = EncoderDecoderModel.from_pretrained(_BERT2BERT_SUMMARY_CKPT)

# IndoT5 paraphraser (used by para_t5 on the T5 summary).
t5_para_tokenizer = AutoTokenizer.from_pretrained(_T5_PARAPHRASE_CKPT)
t5_para_model = AutoModelForSeq2SeqLM.from_pretrained(_T5_PARAPHRASE_CKPT)
def summ_t5(text, max_input_length=512):
    """Summarize Indonesian ``text`` with the fine-tuned T5-base checkpoint.

    Args:
        text: Source article as a plain string.
        max_input_length: Maximum number of input tokens passed to the
            encoder; longer inputs are truncated. 512 is the input length
            T5-base was pre-trained with. Pass ``None`` to disable truncation.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # Truncate so an arbitrarily long paste from the web UI cannot make the
    # encoder's attention (quadratic in input length) exhaust memory.
    input_ids = tokenizer_t5.encode(
        text,
        truncation=max_input_length is not None,
        max_length=max_input_length,
        return_tensors="pt",
    )
    summary_ids = model_t5.generate(
        input_ids,
        max_length=100,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=2,
        use_cache=True,
    )
    return tokenizer_t5.decode(summary_ids[0], skip_special_tokens=True)
def summ_bert(text, max_input_length=512):
    """Summarize Indonesian ``text`` with the BERT2BERT encoder-decoder.

    Args:
        text: Source article as a plain string.
        max_input_length: Maximum number of input tokens passed to the BERT
            encoder; longer inputs are truncated. BERT's learned position
            embeddings only cover 512 positions for BERT-base, so anything
            longer would fail at inference. ``None`` disables truncation.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # Without truncation, inputs over the position-embedding size index past
    # the end of the embedding table and raise instead of summarizing.
    input_ids = tokenizer_bert.encode(
        text,
        truncation=max_input_length is not None,
        max_length=max_input_length,
        return_tensors="pt",
    )
    summary_ids = model_bert.generate(
        input_ids,
        max_length=100,
        num_beams=10,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=2,
        use_cache=True,
    )
    return tokenizer_bert.decode(summary_ids[0], skip_special_tokens=True)
def para_t5(text, num_return_sequences=5, prefix="paraphrase: "):
    """Generate sampled paraphrases of ``text`` with IndoT5-base-paraphrase.

    Args:
        text: Sentence(s) to paraphrase.
        num_return_sequences: Number of independent samples to return.
        prefix: Task prefix prepended to ``text``. The checkpoint is prompted
            with ``"paraphrase: "`` in its model card; pass ``""`` to send the
            raw text unchanged.

    Returns:
        A list of ``num_return_sequences`` paraphrase strings. Generation is
        sampled (``do_sample=True``), so repeated calls generally differ.
    """
    encoding = t5_para_tokenizer(prefix + text, padding="longest", return_tensors="pt")
    # No num_beams is passed: this is top-k/top-p sampling, so the
    # beam-search-only ``early_stopping`` flag is deliberately not set.
    outputs = t5_para_model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=100,
        do_sample=True,
        top_k=120,
        top_p=0.95,
        num_return_sequences=num_return_sequences,
    )
    return [
        t5_para_tokenizer.decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for output in outputs
    ]
def summarize(text):
    """Run all three models on ``text`` for the Gradio callback.

    Returns:
        A 3-tuple ``(t5_summary, bert_summary, paraphrases)``. ``paraphrases``
        is the list returned by ``para_t5`` applied to the T5 summary, not to
        the original text.
    """
    t5_summary = summ_t5(text)
    bert_summary = summ_bert(text)
    return t5_summary, bert_summary, para_t5(t5_summary)
if __name__ == "__main__":

    def _summarize_for_display(text):
        """Adapt ``summarize`` output to three Textbox components.

        ``para_t5`` yields a list of paraphrases; a Textbox would ``str()`` it
        into a Python list repr, so the candidates are joined one per line.
        """
        t5_summary, bert_summary, paraphrases = summarize(text)
        return t5_summary, bert_summary, "\n".join(paraphrases)

    with gr.Blocks() as demo:
        gr.Markdown("""<h1 style="text-align:center">Summary of Summarizer - Indonesia</h1>""")
        gr.Markdown("\nCreator: wiraindrak\n")
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input Text")
                # A Button's caption is its ``value``; ``label=`` is not shown
                # and left the button reading the default "Run".
                analyze_button = gr.Button("Analyze")
            with gr.Column():
                t5_output = gr.Textbox(label="T5 Base Output")
                bert_output = gr.Textbox(label="Bert2Bert Base Output")
                # Several lines tall: it shows one paraphrase per line.
                para_output = gr.Textbox(label="T5 Paraphrase Output", lines=5)
        analyze_button.click(
            _summarize_for_display,
            inputs=input_text,
            outputs=[t5_output, bert_output, para_output],
        )
    # Launch only after the Blocks context has closed and the layout is final.
    demo.launch()