from parrot import Parrot | |
import torch | |
import warnings | |
warnings.filterwarnings("ignore") | |
#import nltk #next stage, when executing multiple sentences | |
import streamlit as st | |
# Uncomment to get reproducible paraphrase generations:
# def random_state(seed):
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(seed)
# random_state(1234)
# Init models (make sure you init ONLY once if you integrate this into your code).
@st.cache_resource
def load_model():
    """Fetch and load the Parrot paraphraser, reusing it across Streamlit reruns.

    Streamlit re-executes this entire script on every widget interaction.
    Without caching, the T5 model would be downloaded/loaded again on every
    click. ``st.cache_resource`` keeps one instance alive and hands the same
    object back on later calls.

    NOTE(review): the cached instance is shared by all sessions of the app;
    confirm Parrot inference is safe to call concurrently if the app is
    served to multiple users.

    Returns:
        The ``Parrot`` instance built from the
        ``prithivida/parrot_paraphraser_on_T5`` model tag.
    """
    return Parrot(model_tag="prithivida/parrot_paraphraser_on_T5")
parrot = load_model()

st.title("Let's Rewrite your sentence!")
input_phrase = st.text_input("Input your text here:")
option = st.selectbox('Do you want to preserve some of the original words?',
                      ('Yes', 'No'))
# Preserving the original words means *not* asking Parrot for diverse output:
# 'Yes' -> do_diverse=False, 'No' -> do_diverse=True.
oc = option != 'Yes'

if st.button('Submit Text!'):
    if not input_phrase.strip():
        # Don't send an empty/whitespace-only prompt to the model.
        st.warning("Please enter some text before submitting.")
    else:
        st.header(':blue[Input]')
        st.write(f" {input_phrase}")
        st.text("--" * 30)
        st.header(':blue[Output]')
        output_phrases = parrot.augment(input_phrase=input_phrase, do_diverse=oc)
        # Each result is indexed as (sentence, score); keep only positive scores.
        # augment() may return None, which is treated as "no results".
        good_sentences = [
            phrase[0] for phrase in (output_phrases or []) if phrase[1] > 0
        ]
        if good_sentences:
            for sentence in good_sentences:
                st.write(sentence)
        else:
            # Show the apology once, only when nothing usable came back,
            # instead of once per low-scoring candidate.
            st.write("Sorry! No sentences were found with a good score!")

st.header("Feedback")
st.write("[Kindly consider providing feedback!](https://forms.gle/97st7g2n9NNpqnXw5)")