Wootang01 commited on
Commit
9119fdc
·
1 Parent(s): 9882197

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
3
+
4
+ st.title(Paraphrase with Pegasus")
5
+
6
+ model_name = "tuner007/pegasus_paraphrase"
7
+ torch_device = "cpu"
8
+ tokenizer = PegasusTokenizer.from_pretrained(model_name)
9
+
10
+
11
+ @st.cache(allow_output_mutation=True)
12
+ def load_model():
13
+ model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
14
+ return model
15
+
16
+ def get_response(
17
+ input_text, num_return_sequences, num_beams, max_length=60, temperature=1.5
18
+ ):
19
+
20
+ model = load_model()
21
+ batch = tokenizer([input_text], truncation=True, padding="longest", max_length=max_length, return_tensors="pt").to(torch_device)
22
+ translated = model.generate(**batch, max_length=max_length, num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=temperature)
23
+ tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
24
+ return tgt_text
25
+
26
+ num_beams = 10
27
+ num_return_sequences = st.slider("Number of paraphrases", 1, 10, 5, 1)
28
+ context = st.text_area(label="Enter a sentence to paraphrase", max_chars=384)
29
+
30
+ with st.expander("Advanced"):
31
+ temperature = st.slider("Temperature", 0.1, 5.0, 1.5, 0.1)
32
+ max_length = st.slider("Max length", 10, 100, 60, 10)
33
+
34
+ if context:
35
+ response = get_response(context, num_return_sequences, num_beams, max_length, temperature)
36
+ for paraphrase in response:
37
+ st.write(paraphrase)
38
+