Jaane committed on
Commit 8ec3ac3 · verified · 1 Parent(s): 01313f4

adding changes

Files changed (1): app.py (+127 -57)

app.py CHANGED
@@ -1,63 +1,133 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
         ),
     ],
 )
 
-
 if __name__ == "__main__":
-    demo.launch()
 
+import torch
+from transformers import PegasusForConditionalGeneration, PegasusTokenizer
 import gradio as gr
+
+# Load the tokenizer and model once when the app starts
+model_name = 'tuner007/pegasus_paraphrase'
+torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# Initialize tokenizer and model
+tokenizer = PegasusTokenizer.from_pretrained(model_name)
+model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
+
+def get_response(input_text, num_return_sequences=1, num_beams=3):
+    """
+    Generate paraphrased text for a given input_text using the Pegasus model.
+
+    Args:
+        input_text (str): The text to paraphrase.
+        num_return_sequences (int): Number of paraphrased sequences to return.
+        num_beams (int): Number of beams for beam search.
+
+    Returns:
+        list: A list containing paraphrased text strings.
+    """
+    # Tokenize the input text
+    batch = tokenizer(
+        [input_text],
+        truncation=True,
+        padding='longest',
+        max_length=60,
+        return_tensors="pt"
+    ).to(torch_device)
+
+    # Generate paraphrased outputs
+    translated = model.generate(
+        **batch,
+        max_length=60,
+        num_beams=num_beams,
+        num_return_sequences=num_return_sequences,
+        temperature=0.7  # note: ignored under beam search unless do_sample=True
+    )
+
+    # Decode the generated tokens
+    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
+    return tgt_text
+
+def split_text_by_fullstop(text):
+    """
+    Split the input text into sentences based on full stops.
+
+    Args:
+        text (str): The text to split.
+
+    Returns:
+        list: A list of sentences.
+    """
+    sentences = [sentence.strip() for sentence in text.split('.') if sentence.strip()]
+    return sentences
+
+def process_text_by_fullstop(text, num_return_sequences=1, num_beams=3):
+    """
+    Process the input text by splitting it into sentences and paraphrasing each sentence.
+
+    Args:
+        text (str): The text to paraphrase.
+        num_return_sequences (int): Number of paraphrased sequences per sentence.
+        num_beams (int): Number of beams for beam search.
+
+    Returns:
+        str: The paraphrased text.
+    """
+    sentences = split_text_by_fullstop(text)
+    paraphrased_sentences = []
+
+    for sentence in sentences:
+        # Ensure each sentence ends with a period
+        sentence = sentence + '.' if not sentence.endswith('.') else sentence
+        paraphrases = get_response(sentence, num_return_sequences, num_beams)
+        paraphrased_sentences.extend(paraphrases)
+
+    # Join all paraphrased sentences into a single string
+    return ' '.join(paraphrased_sentences)
+
+def paraphrase(text, num_beams, num_return_sequences):
+    """
+    Interface function to paraphrase input text based on user parameters.
+
+    Args:
+        text (str): The input text to paraphrase.
+        num_beams (int): Number of beams for beam search.
+        num_return_sequences (int): Number of paraphrased sequences to return.
+
+    Returns:
+        str: The paraphrased text.
+    """
+    return process_text_by_fullstop(text, num_return_sequences, num_beams)
+
+# Define the Gradio interface
+iface = gr.Interface(
+    fn=paraphrase,
+    inputs=[
+        gr.components.Textbox(
+            lines=10,
+            placeholder="Enter text here...",
+            label="Input Text"
+        ),
+        gr.components.Slider(
+            minimum=1,
+            maximum=10,
+            step=1,
+            value=3,
+            label="Number of Beams"
+        ),
+        gr.components.Slider(
+            minimum=1,
+            maximum=5,
+            step=1,
+            value=1,
+            label="Number of Return Sequences"
+        )
+    ],
+    outputs=gr.components.Textbox(
+        lines=10,
+        label="Paraphrased Text"
+    ),
+    title="Text Paraphrasing App",
+    description="Enter your text and adjust the parameters to receive paraphrased versions using the Pegasus model.",
+    allow_flagging="never"
+)
 
+# Launch the app
 if __name__ == "__main__":
+    iface.launch()
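
A quick way to sanity-check the new pipeline without launching the Gradio UI is to call `paraphrase` directly. A minimal sketch, assuming this file is saved as `app.py` and that `torch`, `transformers`, and `sentencepiece` are installed; the sample text is illustrative, and the first run downloads tuner007/pegasus_paraphrase:

    # Hypothetical smoke test; not part of this commit.
    # Importing app loads the tokenizer and model at module level.
    from app import paraphrase

    sample = "The weather was cold. We decided to stay indoors."
    # num_beams and num_return_sequences mirror the slider defaults in the UI.
    print(paraphrase(sample, num_beams=3, num_return_sequences=1))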