datasciencedojo committed on
Commit
8c9a81a
·
1 Parent(s): eca26d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -1,15 +1,20 @@
1
  import gradio as gr
2
- from parrot import Parrot
3
  import torch
4
- import warnings
5
- warnings.filterwarnings("ignore")
6
 
7
- parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=False)
 
 
 
8
 
9
- def paraphrase_text(input_text):
10
- para_phrases = parrot.augment(input_phrase=input_text,
11
- max_return_phrases = 3)
12
- return para_phrases[0][0], para_phrases[1][0] if len(para_phrases) > 1 else '' , para_phrases[2][0] if len(para_phrases) > 2 else ''
 
 
 
13
 
14
  examples = [["Uploading a video to YouTube can help exposure for your business.", "45"], ["Niagara Falls is viewed by thousands of tourists every year.", "30"]]
15
 
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
4
+ from sentence_splitter import SentenceSplitter, split_text_into_sentences
5
 
6
+ model_name = 'tuner007/pegasus_paraphrase'
7
+ torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
+ tokenizer = PegasusTokenizer.from_pretrained(model_name)
9
+ model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
10
 
11
+ def paraphrase_text(input_text, num_return_sequences = 3):
12
+ batch = tokenizer.prepare_seq2seq_batch([input_text], truncation=True, padding='longest', max_length=60,
13
+ return_tensors="pt").to(torch_device)
14
+ translated = model.generate(**batch, max_length=60, num_beams=10, num_return_sequences=num_return_sequences,
15
+ temperature=1.5)
16
+ paraphrased_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
17
+ return paraphrased_text[0], paraphrased_text[1], paraphrased_text[2]
18
 
19
  examples = [["Uploading a video to YouTube can help exposure for your business.", "45"], ["Niagara Falls is viewed by thousands of tourists every year.", "30"]]
20