ShreyaRao commited on
Commit
b9617b4
·
1 Parent(s): 9e866e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -38,15 +38,15 @@ def bart_summarize(text):
38
  return pp
39
 
40
  #Encoder-Decoder
41
- # def encoder_decoder(text):
42
- # model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
43
- # tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
44
- # # let's perform inference on a long piece of text
45
- # input_ids = tokenizer(text, return_tensors="pt").input_ids
46
- # # autoregressively generate summary (uses greedy decoding by default)
47
- # generated_ids = model.generate(input_ids)
48
- # generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
49
- # return generated_text
50
 
51
 
52
  #Input text
@@ -77,8 +77,9 @@ if button:
77
  elif model == "BART":
78
  st.write("You have selected BART model.")
79
  summary = bart_summarize(text)
80
- #elif model == "Encoder-Decoder":
81
- #st.write("You have selected Encoder-Decoder model.")
 
82
 
83
  #st.toast("Please wait while we summarize your text.")
84
  #with st.spinner("Summarizing..."):
 
38
  return pp
39
 
40
  #Encoder-Decoder
41
+ def encoder_decoder(text):
42
+ model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
43
+ tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
44
+ # let's perform inference on a long piece of text
45
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
46
+ # autoregressively generate summary (uses greedy decoding by default)
47
+ generated_ids = model.generate(input_ids)
48
+ generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
49
+ return generated_text
50
 
51
 
52
  #Input text
 
77
  elif model == "BART":
78
  st.write("You have selected BART model.")
79
  summary = bart_summarize(text)
80
+ elif model == "Encoder-Decoder":
81
+ st.write("You have selected Encoder-Decoder model.")
82
+ summary = encoder_decoder(text)
83
 
84
  #st.toast("Please wait while we summarize your text.")
85
  #with st.spinner("Summarizing..."):