wiraindrak commited on
Commit
b020e81
·
1 Parent(s): ac78664

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -6
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, BertTokenizer, EncoderDecoderModel
2
 
3
  import gradio as gr
4
  from gradio.mix import Parallel
@@ -11,6 +11,10 @@ tokenizer_bert.bos_token = tokenizer_bert.cls_token
11
  tokenizer_bert.eos_token = tokenizer_bert.sep_token
12
  model_bert = EncoderDecoderModel.from_pretrained("cahya/bert2bert-indonesian-summarization")
13
 
 
 
 
 
14
  def summ_t5(text):
15
  input_ids = tokenizer_t5.encode(text, return_tensors='pt')
16
  summary_ids = model_t5.generate(input_ids,
@@ -25,8 +29,21 @@ def summ_t5(text):
25
  return summary_text
26
 
27
  def summ_bert(text):
28
- input_ids = tokenizer_bert.encode(text, return_tensors='pt')
29
- summary_ids = model_bert.generate(input_ids,
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  min_length=20,
31
  max_length=100,
32
  num_beams=10,
@@ -39,8 +56,12 @@ def summ_bert(text):
39
  temperature = 0.8,
40
  top_k = 50,
41
  top_p = 0.95)
42
- summary_text = tokenizer_bert.decode(summary_ids[0], skip_special_tokens=True)
43
- return summary_text
 
 
 
 
44
 
45
  t5_demo = gr.Interface(
46
  fn=summ_t5,
@@ -52,8 +73,13 @@ bert_demo = gr.Interface(
52
  inputs="text",
53
  outputs=gr.Textbox(lines=10, label="Bert2Bert Base Output")
54
  )
 
 
 
 
 
55
 
56
  if __name__ == "__main__":
57
- Parallel(t5_demo, bert_demo,
58
  inputs=gr.Textbox(lines=10, label="Input Text", placeholder="Enter article here..."),
59
  title="Summary of Summarizer - Indonesia").launch()
 
1
+ from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, BertTokenizer, EncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
2
 
3
  import gradio as gr
4
  from gradio.mix import Parallel
 
11
  tokenizer_bert.eos_token = tokenizer_bert.sep_token
12
  model_bert = EncoderDecoderModel.from_pretrained("cahya/bert2bert-indonesian-summarization")
13
 
14
+ t5_para_tokenizer = AutoTokenizer.from_pretrained("Wikidepia/IndoT5-base-paraphrase")
15
+ t5_para_model = AutoModelForSeq2SeqLM.from_pretrained("Wikidepia/IndoT5-base-paraphrase")
16
+
17
+
18
  def summ_t5(text):
19
  input_ids = tokenizer_t5.encode(text, return_tensors='pt')
20
  summary_ids = model_t5.generate(input_ids,
 
29
  return summary_text
30
 
31
  def summ_bert(text):
32
+ encoding = tokenizer(text, padding='longest', return_tensors="pt")
33
+ outputs = model.generate(
34
+ input_ids=encoding["input_ids"], attention_mask=encoding["attention_mask"],
35
+ max_length=512,
36
+ do_sample=True,
37
+ top_k=200,
38
+ top_p=0.95,
39
+ early_stopping=True,
40
+ num_return_sequences=5)
41
+ summary_text = tokenizer_bert.decode(summary_ids[0], skip_special_tokens=True)
42
+ return summary_text
43
+
44
+ def para_t5(text):
45
+ input_ids = t5_para_tokenizer.encode(text, return_tensors='pt')
46
+ outputs = t5_para_model .generate(input_ids,
47
  min_length=20,
48
  max_length=100,
49
  num_beams=10,
 
56
  temperature = 0.8,
57
  top_k = 50,
58
  top_p = 0.95)
59
+ return [
60
+ t5_para_tokenizer.decode(
61
+ output, skip_special_tokens=True, clean_up_tokenization_spaces=True
62
+ )
63
+ for output in outputs
64
+ ]
65
 
66
  t5_demo = gr.Interface(
67
  fn=summ_t5,
 
73
  inputs="text",
74
  outputs=gr.Textbox(lines=10, label="Bert2Bert Base Output")
75
  )
76
+ para_demo = gr.Interface(
77
+ fn=para_t5,
78
+ inputs="text",
79
+ outputs=gr.Textbox(lines=10, label="T5 Paraphrase Output")
80
+ )
81
 
82
  if __name__ == "__main__":
83
+ Parallel(t5_demo, bert_demo, para_demo,
84
  inputs=gr.Textbox(lines=10, label="Input Text", placeholder="Enter article here..."),
85
  title="Summary of Summarizer - Indonesia").launch()