gaurav0026 committed on
Commit
70a0fff
·
verified ·
1 Parent(s): ea87c49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -176
app.py CHANGED
@@ -1,14 +1,6 @@
1
- from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModel, AutoTokenizer
2
- import torch
3
- from sklearn.metrics.pairwise import cosine_similarity
4
- import numpy as np
5
- import gradio as gr
6
- from collections import Counter
7
- import pandas as pd
8
-
9
  # Load paraphrase model and tokenizer
10
  model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser')
11
- tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False) # Explicitly set legacy=False
12
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model = model.to(device)
@@ -27,9 +19,6 @@ def get_sentence_embedding(sentence):
27
 
28
  # Paraphrasing function
29
  def paraphrase_sentence(sentence):
30
- if not sentence.strip():
31
- return "Please enter a valid sentence."
32
-
33
  # Updated prompt for statement-like output
34
  text = "rephrase as a statement: " + sentence
35
  encoding = tokenizer.encode_plus(text, padding=False, return_tensors="pt")
@@ -40,10 +29,10 @@ def paraphrase_sentence(sentence):
40
  attention_mask=attention_masks,
41
  do_sample=True,
42
  max_length=128,
43
- top_k=40,
44
- top_p=0.85,
45
  early_stopping=True,
46
- num_return_sequences=5
47
  )
48
 
49
  # Decode and format paraphrases with numbering
@@ -69,16 +58,14 @@ def calculate_precision_recall_accuracy(sentences):
69
  original_tokens = Counter(sentence.lower().split())
70
 
71
  for paraphrase in paraphrases:
72
- if not paraphrase.strip():
73
- continue
74
  # Remove numbering before evaluation
75
- paraphrase_text = paraphrase.split(". ", 1)[1] if ". " in paraphrase else paraphrase
76
- paraphrase_embedding = get_sentence_embedding(paraphrase_text)
77
  similarity = cosine_similarity(original_embedding.cpu(), paraphrase_embedding.cpu())[0][0]
78
  total_similarity += similarity
79
 
80
  # Calculate precision and recall based on token overlap
81
- paraphrase_tokens = Counter(paraphrase_text.lower().split())
82
  overlap = sum((paraphrase_tokens & original_tokens).values())
83
  precision = overlap / sum(paraphrase_tokens.values()) if paraphrase_tokens else 0
84
  recall = overlap / sum(original_tokens.values()) if original_tokens else 0
@@ -88,159 +75,34 @@ def calculate_precision_recall_accuracy(sentences):
88
  paraphrase_count += 1
89
 
90
  # Calculate averages for accuracy, precision, and recall
91
- overall_accuracy = (total_similarity / paraphrase_count) * 100 if paraphrase_count else 0
92
- avg_precision = (total_precision / paraphrase_count) * 100 if paraphrase_count else 0
93
- avg_recall = (total_recall / paraphrase_count) * 100 if paraphrase_count else 0
94
-
95
- return (f"**Overall Model Accuracy (Semantic Similarity):** {overall_accuracy:.2f}%\n"
96
- f"**Average Precision (Token Overlap):** {avg_precision:.2f}%\n"
97
- f"**Average Recall (Token Overlap):** {avg_recall:.2f}%")
98
-
99
- # Custom CSS for aesthetic design
100
- custom_css = """
101
- body {
102
- background: linear-gradient(135deg, #e0e7ff, #c3dafe, #e0e7ff);
103
- font-family: 'Inter', sans-serif;
104
- }
105
- .gradio-container {
106
- max-width: 800px !important;
107
- margin: auto;
108
- padding: 20px;
109
- background: white;
110
- border-radius: 20px;
111
- box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
112
- }
113
- h1 {
114
- font-size: 2.5rem;
115
- font-weight: 700;
116
- background: linear-gradient(to right, #4f46e5, #7c3aed);
117
- -webkit-background-clip: text;
118
- -webkit-text-fill-color: transparent;
119
- text-align: center;
120
- margin-bottom: 1rem;
121
- }
122
- textarea, input {
123
- border: 2px solid #e0e7ff !important;
124
- border-radius: 10px !important;
125
- padding: 15px !important;
126
- transition: all 0.3s ease !important;
127
- }
128
- textarea:hover, input:hover {
129
- border-color: #a5b4fc !important;
130
- box-shadow: 0 0 10px rgba(79, 70, 229, 0.2) !important;
131
- }
132
- textarea:focus, input:focus {
133
- border-color: #4f46e5 !important;
134
- box-shadow: 0 0 15px rgba(79, 70, 229, 0.3) !important;
135
- }
136
- button {
137
- background: linear-gradient(to right, #4f46e5, #7c3aed) !important;
138
- color: white !important;
139
- font-weight: 600 !important;
140
- padding: 12px 24px !important;
141
- border-radius: 10px !important;
142
- border: none !important;
143
- transition: all 0.3s ease !important;
144
- }
145
- button:hover {
146
- background: linear-gradient(to right, #4338ca, #6d28d9) !important;
147
- transform: scale(1.05) !important;
148
- box-shadow: 0 5px 15px rgba(79, 70, 229, 0.4) !important;
149
- }
150
- button:disabled {
151
- background: linear-gradient(to right, #a3a3a3, #d1d5db) !important;
152
- transform: none !important;
153
- box-shadow: none !important;
154
- }
155
- .output-text {
156
- background: #f9fafb !important;
157
- border-radius: 10px !important;
158
- padding: 15px !important;
159
- border: 1px solid #e5e7eb !important;
160
- transition: all 0.3s ease !important;
161
- }
162
- .output-text:hover {
163
- background: #eff6ff !important;
164
- border-color: #a5b4fc !important;
165
- }
166
- footer {
167
- display: none !important;
168
- }
169
- """
170
-
171
- # Custom JavaScript for additional interactivity
172
- custom_js = """
173
- <script>
174
- document.addEventListener('DOMContentLoaded', () => {
175
- const textarea = document.querySelector('textarea');
176
- const button = document.querySelector('button');
177
-
178
- // Add typing animation effect
179
- textarea.addEventListener('input', () => {
180
- textarea.style.transform = 'scale(1.02)';
181
- setTimeout(() => {
182
- textarea.style.transform = 'scale(1)';
183
- }, 200);
184
- });
185
-
186
- // Button click animation
187
- button.addEventListener('click', () => {
188
- if (!button.disabled) {
189
- button.style.transform = 'scale(0.95)';
190
- setTimeout(() => {
191
- button.style.transform = 'scale(1)';
192
- }, 200);
193
- }
194
- });
195
- });
196
- </script>
197
- """
198
-
199
- # Define Gradio UI with enhanced aesthetics
200
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, js=custom_js) as demo:
201
- gr.Markdown(
202
- """
203
- # PARA-GEN: Aesthetic Paraphraser
204
- Enter a sentence below to generate five beautifully rephrased statements.
205
- """
206
- )
207
-
208
- with gr.Row():
209
- with gr.Column(scale=3):
210
- input_text = gr.Textbox(
211
- label="Input Sentence",
212
- placeholder="Type your sentence here...",
213
- lines=4,
214
- max_lines=4
215
- )
216
- paraphrase_button = gr.Button("Generate Paraphrases")
217
-
218
- with gr.Column(scale=2):
219
- output_text = gr.Textbox(
220
- label="Paraphrased Results",
221
- lines=10,
222
- interactive=False
223
- )
224
-
225
- with gr.Accordion("Model Performance Metrics", open=False):
226
- metrics_output = gr.Markdown()
227
-
228
- # Define button click behavior
229
- paraphrase_button.click(
230
- fn=paraphrase_sentence,
231
- inputs=input_text,
232
- outputs=output_text
233
- )
234
-
235
- # Calculate and display metrics on load without _js
236
- test_sentences = [
237
- "The quick brown fox jumps over the lazy dog.",
238
- "Artificial intelligence is transforming industries.",
239
- "The weather is sunny and warm today.",
240
- "He enjoys reading books on machine learning.",
241
- "The stock market fluctuates daily due to various factors."
242
- ]
243
- metrics_output.value = calculate_precision_recall_accuracy(test_sentences)
244
-
245
- # Launch Gradio app
246
- demo.launch(share=False)
 
 
 
 
 
 
 
 
 
1
  # Load paraphrase model and tokenizer
2
  model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser')
3
+ tokenizer = T5Tokenizer.from_pretrained('t5-base')
4
 
5
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
  model = model.to(device)
 
19
 
20
  # Paraphrasing function
21
  def paraphrase_sentence(sentence):
 
 
 
22
  # Updated prompt for statement-like output
23
  text = "rephrase as a statement: " + sentence
24
  encoding = tokenizer.encode_plus(text, padding=False, return_tensors="pt")
 
29
  attention_mask=attention_masks,
30
  do_sample=True,
31
  max_length=128,
32
+ top_k=40, # Reduced top_k for less randomness
33
+ top_p=0.85, # Reduced top_p for focused sampling
34
  early_stopping=True,
35
+ num_return_sequences=5 # Generate 5 paraphrases
36
  )
37
 
38
  # Decode and format paraphrases with numbering
 
58
  original_tokens = Counter(sentence.lower().split())
59
 
60
  for paraphrase in paraphrases:
 
 
61
  # Remove numbering before evaluation
62
+ paraphrase = paraphrase.split(". ", 1)[1]
63
+ paraphrase_embedding = get_sentence_embedding(paraphrase)
64
  similarity = cosine_similarity(original_embedding.cpu(), paraphrase_embedding.cpu())[0][0]
65
  total_similarity += similarity
66
 
67
  # Calculate precision and recall based on token overlap
68
+ paraphrase_tokens = Counter(paraphrase.lower().split())
69
  overlap = sum((paraphrase_tokens & original_tokens).values())
70
  precision = overlap / sum(paraphrase_tokens.values()) if paraphrase_tokens else 0
71
  recall = overlap / sum(original_tokens.values()) if original_tokens else 0
 
75
  paraphrase_count += 1
76
 
77
  # Calculate averages for accuracy, precision, and recall
78
+ overall_accuracy = (total_similarity / paraphrase_count) * 100
79
+ avg_precision = (total_precision / paraphrase_count) * 100
80
+ avg_recall = (total_recall / paraphrase_count) * 100
81
+
82
+ print(f"Overall Model Accuracy (Semantic Similarity): {overall_accuracy:.2f}%")
83
+ print(f"Average Precision (Token Overlap): {avg_precision:.2f}%")
84
+ print(f"Average Recall (Token Overlap): {avg_recall:.2f}%")
85
+
86
+ # Define Gradio UI
87
+ iface = gr.Interface(
88
+ fn=paraphrase_sentence,
89
+ inputs="text",
90
+ outputs="text",
91
+ title="PARA-GEN (T5 Paraphraser)",
92
+ description="Enter a sentence, and the model will generate five numbered paraphrases in statement form."
93
+ )
94
+
95
+ # List of test sentences to evaluate metrics
96
+ test_sentences = [
97
+ "The quick brown fox jumps over the lazy dog.",
98
+ "Artificial intelligence is transforming industries.",
99
+ "The weather is sunny and warm today.",
100
+ "He enjoys reading books on machine learning.",
101
+ "The stock market fluctuates daily due to various factors."
102
+ ]
103
+
104
+ # Calculate overall accuracy, precision, and recall for the list of test sentences
105
+ calculate_precision_recall_accuracy(test_sentences)
106
+
107
+ # Launch Gradio app (Gradio UI will not show metrics)
108
+ iface.launch(share=False)