gaurav0026 committed
Commit 30ca17d · verified · 1 Parent(s): 36b234c

updated ui in gradio

Files changed (1):
  1. app.py +251 -119
app.py CHANGED
@@ -1,119 +1,251 @@
- from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModel, AutoTokenizer
- import torch
- from sklearn.metrics.pairwise import cosine_similarity
- import numpy as np
- import gradio as gr
- from collections import Counter
- import pandas as pd
-
-
-
-
- # Load paraphrase model and tokenizer
- model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser')
- tokenizer = T5Tokenizer.from_pretrained('t5-base')
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = model.to(device)
-
- # Load Sentence-BERT model for semantic similarity calculation
- embed_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
- embed_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
- embed_model = embed_model.to(device)
-
- # Function to get sentence embeddings
- def get_sentence_embedding(sentence):
-     inputs = embed_tokenizer(sentence, return_tensors="pt", padding=True).to(device)
-     with torch.no_grad():
-         embeddings = embed_model(**inputs).last_hidden_state.mean(dim=1)
-     return embeddings
-
- # Paraphrasing function
- def paraphrase_sentence(sentence):
-     # Updated prompt for statement-like output
-     text = "rephrase as a statement: " + sentence
-     encoding = tokenizer.encode_plus(text, padding=False, return_tensors="pt")
-     input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
-
-     beam_outputs = model.generate(
-         input_ids=input_ids,
-         attention_mask=attention_masks,
-         do_sample=True,
-         max_length=128,
-         top_k=40,  # Reduced top_k for less randomness
-         top_p=0.85,  # Reduced top_p for focused sampling
-         early_stopping=True,
-         num_return_sequences=5  # Generate 5 paraphrases
-     )
-
-     # Decode and format paraphrases with numbering
-     paraphrases = []
-     for i, line in enumerate(beam_outputs, 1):
-         paraphrase = tokenizer.decode(line, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-         paraphrases.append(f"{i}. {paraphrase}")
-
-     return "\n".join(paraphrases)
-
- # Precision, Recall, and Overall Accuracy Calculation
- def calculate_precision_recall_accuracy(sentences):
-     total_similarity = 0
-     paraphrase_count = 0
-     total_precision = 0
-     total_recall = 0
-
-     for sentence in sentences:
-         paraphrases = paraphrase_sentence(sentence).split("\n")
-
-         # Get the original embedding and token counts
-         original_embedding = get_sentence_embedding(sentence)
-         original_tokens = Counter(sentence.lower().split())
-
-         for paraphrase in paraphrases:
-             # Remove numbering before evaluation
-             paraphrase = paraphrase.split(". ", 1)[1]
-             paraphrase_embedding = get_sentence_embedding(paraphrase)
-             similarity = cosine_similarity(original_embedding.cpu(), paraphrase_embedding.cpu())[0][0]
-             total_similarity += similarity
-
-             # Calculate precision and recall based on token overlap
-             paraphrase_tokens = Counter(paraphrase.lower().split())
-             overlap = sum((paraphrase_tokens & original_tokens).values())
-             precision = overlap / sum(paraphrase_tokens.values()) if paraphrase_tokens else 0
-             recall = overlap / sum(original_tokens.values()) if original_tokens else 0
-
-             total_precision += precision
-             total_recall += recall
-             paraphrase_count += 1
-
-     # Calculate averages for accuracy, precision, and recall
-     overall_accuracy = (total_similarity / paraphrase_count) * 100
-     avg_precision = (total_precision / paraphrase_count) * 100
-     avg_recall = (total_recall / paraphrase_count) * 100
-
-     print(f"Overall Model Accuracy (Semantic Similarity): {overall_accuracy:.2f}%")
-     print(f"Average Precision (Token Overlap): {avg_precision:.2f}%")
-     print(f"Average Recall (Token Overlap): {avg_recall:.2f}%")
-
- # Define Gradio UI
- iface = gr.Interface(
-     fn=paraphrase_sentence,
-     inputs="text",
-     outputs="text",
-     title="PARA-GEN (T5 Paraphraser)",
-     description="Enter a sentence, and the model will generate five numbered paraphrases in statement form."
- )
-
- # List of test sentences to evaluate metrics
- test_sentences = [
-     "The quick brown fox jumps over the lazy dog.",
-     "Artificial intelligence is transforming industries.",
-     "The weather is sunny and warm today.",
-     "He enjoys reading books on machine learning.",
-     "The stock market fluctuates daily due to various factors."
- ]
-
- # Calculate overall accuracy, precision, and recall for the list of test sentences
- calculate_precision_recall_accuracy(test_sentences)
-
- # Launch Gradio app (Gradio UI will not show metrics)
- iface.launch(share=False)
+ from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModel, AutoTokenizer
+ import torch
+ from sklearn.metrics.pairwise import cosine_similarity
+ import numpy as np
+ import gradio as gr
+ from collections import Counter
+ import pandas as pd
+
+ # Load paraphrase model and tokenizer
+ model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser')
+ tokenizer = T5Tokenizer.from_pretrained('t5-base')
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ # Load Sentence-BERT model for semantic similarity calculation
+ embed_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+ embed_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+ embed_model = embed_model.to(device)
+
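+ # The helper below mean-pools MiniLM's last hidden states into one vector per sentence;
+ # these embeddings are used only for the evaluation metrics, not for paraphrase generation.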
+ # Function to get sentence embeddings
+ def get_sentence_embedding(sentence):
+     inputs = embed_tokenizer(sentence, return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         embeddings = embed_model(**inputs).last_hidden_state.mean(dim=1)
+     return embeddings
+
+ # Paraphrasing function
+ def paraphrase_sentence(sentence):
+     if not sentence.strip():
+         return "Please enter a valid sentence."
+
+     # Updated prompt for statement-like output
+     text = "rephrase as a statement: " + sentence
+     encoding = tokenizer.encode_plus(text, padding=False, return_tensors="pt")
+     input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
+
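+     # Sample five candidates; top_k/top_p restrict sampling to higher-probability tokens
+     # so the rewrites vary without drifting off-topic.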
+     beam_outputs = model.generate(
+         input_ids=input_ids,
+         attention_mask=attention_masks,
+         do_sample=True,
+         max_length=128,
+         top_k=40,
+         top_p=0.85,
+         early_stopping=True,
+         num_return_sequences=5
+     )
+
+     # Decode and format paraphrases with numbering
+     paraphrases = []
+     for i, line in enumerate(beam_outputs, 1):
+         paraphrase = tokenizer.decode(line, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         paraphrases.append(f"{i}. {paraphrase}")
+
+     return "\n".join(paraphrases)
+
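+ # "Accuracy" below is the mean cosine similarity between MiniLM embeddings of the original and each paraphrase;
+ # precision and recall are word-overlap ratios relative to the paraphrase and the original, respectively.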
+ # Precision, Recall, and Overall Accuracy Calculation
+ def calculate_precision_recall_accuracy(sentences):
+     total_similarity = 0
+     paraphrase_count = 0
+     total_precision = 0
+     total_recall = 0
+
+     for sentence in sentences:
+         paraphrases = paraphrase_sentence(sentence).split("\n")
+
+         # Get the original embedding and token counts
+         original_embedding = get_sentence_embedding(sentence)
+         original_tokens = Counter(sentence.lower().split())
+
+         for paraphrase in paraphrases:
+             if not paraphrase.strip():
+                 continue
+             # Remove numbering before evaluation
+             paraphrase_text = paraphrase.split(". ", 1)[1] if ". " in paraphrase else paraphrase
+             paraphrase_embedding = get_sentence_embedding(paraphrase_text)
+             similarity = cosine_similarity(original_embedding.cpu(), paraphrase_embedding.cpu())[0][0]
+             total_similarity += similarity
+
+             # Calculate precision and recall based on token overlap
+             paraphrase_tokens = Counter(paraphrase_text.lower().split())
+             overlap = sum((paraphrase_tokens & original_tokens).values())
+             precision = overlap / sum(paraphrase_tokens.values()) if paraphrase_tokens else 0
+             recall = overlap / sum(original_tokens.values()) if original_tokens else 0
+
+             total_precision += precision
+             total_recall += recall
+             paraphrase_count += 1
+
+     # Calculate averages for accuracy, precision, and recall
+     overall_accuracy = (total_similarity / paraphrase_count) * 100 if paraphrase_count else 0
+     avg_precision = (total_precision / paraphrase_count) * 100 if paraphrase_count else 0
+     avg_recall = (total_recall / paraphrase_count) * 100 if paraphrase_count else 0
+
+     return (f"**Overall Model Accuracy (Semantic Similarity):** {overall_accuracy:.2f}%\n"
+             f"**Average Precision (Token Overlap):** {avg_precision:.2f}%\n"
+             f"**Average Recall (Token Overlap):** {avg_recall:.2f}%")
+
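+ # The CSS string below is passed to gr.Blocks(css=...) to style the page container, inputs, buttons, and output box.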
+ # Custom CSS for aesthetic design
+ custom_css = """
+ body {
+     background: linear-gradient(135deg, #e0e7ff, #c3dafe, #e0e7ff);
+     font-family: 'Inter', sans-serif;
+ }
+ .gradio-container {
+     max-width: 800px !important;
+     margin: auto;
+     padding: 20px;
+     background: white;
+     border-radius: 20px;
+     box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
+ }
+ h1 {
+     font-size: 2.5rem;
+     font-weight: 700;
+     background: linear-gradient(to right, #4f46e5, #7c3aed);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     text-align: center;
+     margin-bottom: 1rem;
+ }
+ textarea, input {
+     border: 2px solid #e0e7ff !important;
+     border-radius: 10px !important;
+     padding: 15px !important;
+     transition: all 0.3s ease !important;
+ }
+ textarea:hover, input:hover {
+     border-color: #a5b4fc !important;
+     box-shadow: 0 0 10px rgba(79, 70, 229, 0.2) !important;
+ }
+ textarea:focus, input:focus {
+     border-color: #4f46e5 !important;
+     box-shadow: 0 0 15px rgba(79, 70, 229, 0.3) !important;
+ }
+ button {
+     background: linear-gradient(to right, #4f46e5, #7c3aed) !important;
+     color: white !important;
+     font-weight: 600 !important;
+     padding: 12px 24px !important;
+     border-radius: 10px !important;
+     border: none !important;
+     transition: all 0.3s ease !important;
+ }
+ button:hover {
+     background: linear-gradient(to right, #4338ca, #6d28d9) !important;
+     transform: scale(1.05) !important;
+     box-shadow: 0 5px 15px rgba(79, 70, 229, 0.4) !important;
+ }
+ button:disabled {
+     background: linear-gradient(to right, #a3a3a3, #d1d5db) !important;
+     transform: none !important;
+     box-shadow: none !important;
+ }
+ .output-text {
+     background: #f9fafb !important;
+     border-radius: 10px !important;
+     padding: 15px !important;
+     border: 1px solid #e5e7eb !important;
+     transition: all 0.3s ease !important;
+ }
+ .output-text:hover {
+     background: #eff6ff !important;
+     border-color: #a5b4fc !important;
+ }
+ footer {
+     display: none !important;
+ }
+ """
+
+ # Custom JavaScript for additional interactivity
+ custom_js = """
+ <script>
+ document.addEventListener('DOMContentLoaded', () => {
+     const textarea = document.querySelector('textarea');
+     const button = document.querySelector('button');
+
+     // Add typing animation effect
+     textarea.addEventListener('input', () => {
+         textarea.style.transform = 'scale(1.02)';
+         setTimeout(() => {
+             textarea.style.transform = 'scale(1)';
+         }, 200);
+     });
+
+     // Button click animation
+     button.addEventListener('click', () => {
+         if (!button.disabled) {
+             button.style.transform = 'scale(0.95)';
+             setTimeout(() => {
+                 button.style.transform = 'scale(1)';
+             }, 200);
+         }
+     });
+ });
+ </script>
+ """
+
+ # Define Gradio UI with enhanced aesthetics
+ # The <script>-wrapped JS is injected via head= (Blocks' js= expects a bare JS function, not markup)
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, head=custom_js) as demo:
+     gr.Markdown(
+         """
+         # PARA-GEN: Aesthetic Paraphraser
+         Enter a sentence below to generate five beautifully rephrased statements.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             input_text = gr.Textbox(
+                 label="Input Sentence",
+                 placeholder="Type your sentence here...",
+                 lines=4,
+                 max_lines=4
+             )
+             paraphrase_button = gr.Button("Generate Paraphrases")
+
+         with gr.Column(scale=2):
+             output_text = gr.Textbox(
+                 label="Paraphrased Results",
+                 lines=10,
+                 interactive=False
+             )
+
+     with gr.Accordion("Model Performance Metrics", open=False):
+         metrics_output = gr.Markdown()
+
+     # Define button click behavior
+     paraphrase_button.click(
+         fn=paraphrase_sentence,
+         inputs=input_text,
+         outputs=output_text
+     )
+
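+     # Note: computing the metrics runs T5 generation for every test sentence (five paraphrases each),
+     # so the first page load can take a while on CPU.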
+     # Calculate and display metrics on load
+     test_sentences = [
+         "The quick brown fox jumps over the lazy dog.",
+         "Artificial intelligence is transforming industries.",
+         "The weather is sunny and warm today.",
+         "He enjoys reading books on machine learning.",
+         "The stock market fluctuates daily due to various factors."
+     ]
+     # Bind the Python-side test list directly via a lambda; no JS callback is needed here
+     # (the Gradio 3 `_js` kwarg is not accepted by Gradio 4 event listeners)
+     demo.load(
+         fn=lambda: calculate_precision_recall_accuracy(test_sentences),
+         inputs=None,
+         outputs=metrics_output
+     )
+
+ # Launch Gradio app
+ demo.launch(share=False)