from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModel, AutoTokenizer
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import gradio as gr
from collections import Counter
import pandas as pd
# Load paraphrase model and tokenizer
model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser')
tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)  # Explicitly set legacy=False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load Sentence-BERT model for semantic similarity calculation
embed_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
embed_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
embed_model = embed_model.to(device)
# Function to get sentence embeddings
def get_sentence_embedding(sentence):
    inputs = embed_tokenizer(sentence, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        embeddings = embed_model(**inputs).last_hidden_state.mean(dim=1)
    return embeddings
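
# Sketch (assumption, not wired into the app above): the all-MiniLM-L6-v2 model card
# recommends mean pooling weighted by the attention mask, which matters once sentences
# of different lengths are padded into a single batch. For one sentence at a time the
# simple mean above is equivalent. The helper name below is illustrative only.
def get_sentence_embeddings_batched(sentences):
    inputs = embed_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        token_embeddings = embed_model(**inputs).last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    # Average only over real tokens, not over padding positions
    return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)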
# Paraphrasing function
def paraphrase_sentence(sentence):
    if not sentence.strip():
        return "Please enter a valid sentence."

    # Updated prompt for statement-like output
    text = "rephrase as a statement: " + sentence
    encoding = tokenizer.encode_plus(text, padding=False, return_tensors="pt")
    input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)

    beam_outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        do_sample=True,
        max_length=128,
        top_k=40,
        top_p=0.85,
        early_stopping=True,
        num_return_sequences=5
    )

    # Decode and format paraphrases with numbering
    paraphrases = []
    for i, line in enumerate(beam_outputs, 1):
        paraphrase = tokenizer.decode(line, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        paraphrases.append(f"{i}. {paraphrase}")
    return "\n".join(paraphrases)
# Precision, Recall, and Overall Accuracy Calculation
def calculate_precision_recall_accuracy(sentences):
    total_similarity = 0
    paraphrase_count = 0
    total_precision = 0
    total_recall = 0

    for sentence in sentences:
        paraphrases = paraphrase_sentence(sentence).split("\n")

        # Get the original embedding and token counts
        original_embedding = get_sentence_embedding(sentence)
        original_tokens = Counter(sentence.lower().split())

        for paraphrase in paraphrases:
            if not paraphrase.strip():
                continue

            # Remove numbering before evaluation
            paraphrase_text = paraphrase.split(". ", 1)[1] if ". " in paraphrase else paraphrase

            paraphrase_embedding = get_sentence_embedding(paraphrase_text)
            similarity = cosine_similarity(original_embedding.cpu(), paraphrase_embedding.cpu())[0][0]
            total_similarity += similarity

            # Calculate precision and recall based on token overlap
            paraphrase_tokens = Counter(paraphrase_text.lower().split())
            overlap = sum((paraphrase_tokens & original_tokens).values())
            precision = overlap / sum(paraphrase_tokens.values()) if paraphrase_tokens else 0
            recall = overlap / sum(original_tokens.values()) if original_tokens else 0

            total_precision += precision
            total_recall += recall
            paraphrase_count += 1

    # Calculate averages for accuracy, precision, and recall
    overall_accuracy = (total_similarity / paraphrase_count) * 100 if paraphrase_count else 0
    avg_precision = (total_precision / paraphrase_count) * 100 if paraphrase_count else 0
    avg_recall = (total_recall / paraphrase_count) * 100 if paraphrase_count else 0

    # Blank lines between metrics so gr.Markdown renders them on separate lines
    return (f"**Overall Model Accuracy (Semantic Similarity):** {overall_accuracy:.2f}%\n\n"
            f"**Average Precision (Token Overlap):** {avg_precision:.2f}%\n\n"
            f"**Average Recall (Token Overlap):** {avg_recall:.2f}%")
# Custom CSS for aesthetic design
custom_css = """
body {
    background: linear-gradient(135deg, #e0e7ff, #c3dafe, #e0e7ff);
    font-family: 'Inter', sans-serif;
}
.gradio-container {
    max-width: 800px !important;
    margin: auto;
    padding: 20px;
    background: white;
    border-radius: 20px;
    box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
}
h1 {
    font-size: 2.5rem;
    font-weight: 700;
    background: linear-gradient(to right, #4f46e5, #7c3aed);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    text-align: center;
    margin-bottom: 1rem;
}
textarea, input {
    border: 2px solid #e0e7ff !important;
    border-radius: 10px !important;
    padding: 15px !important;
    transition: all 0.3s ease !important;
}
textarea:hover, input:hover {
    border-color: #a5b4fc !important;
    box-shadow: 0 0 10px rgba(79, 70, 229, 0.2) !important;
}
textarea:focus, input:focus {
    border-color: #4f46e5 !important;
    box-shadow: 0 0 15px rgba(79, 70, 229, 0.3) !important;
}
button {
    background: linear-gradient(to right, #4f46e5, #7c3aed) !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    border-radius: 10px !important;
    border: none !important;
    transition: all 0.3s ease !important;
}
button:hover {
    background: linear-gradient(to right, #4338ca, #6d28d9) !important;
    transform: scale(1.05) !important;
    box-shadow: 0 5px 15px rgba(79, 70, 229, 0.4) !important;
}
button:disabled {
    background: linear-gradient(to right, #a3a3a3, #d1d5db) !important;
    transform: none !important;
    box-shadow: none !important;
}
.output-text {
    background: #f9fafb !important;
    border-radius: 10px !important;
    padding: 15px !important;
    border: 1px solid #e5e7eb !important;
    transition: all 0.3s ease !important;
}
.output-text:hover {
    background: #eff6ff !important;
    border-color: #a5b4fc !important;
}
footer {
    display: none !important;
}
"""
# Custom JavaScript for additional interactivity
custom_js = """
<script>
document.addEventListener('DOMContentLoaded', () => {
    const textarea = document.querySelector('textarea');
    const button = document.querySelector('button');

    // Add typing animation effect
    textarea.addEventListener('input', () => {
        textarea.style.transform = 'scale(1.02)';
        setTimeout(() => {
            textarea.style.transform = 'scale(1)';
        }, 200);
    });

    // Button click animation
    button.addEventListener('click', () => {
        if (!button.disabled) {
            button.style.transform = 'scale(0.95)';
            setTimeout(() => {
                button.style.transform = 'scale(1)';
            }, 200);
        }
    });
});
</script>
"""
# Define Gradio UI with enhanced aesthetics.
# The <script> block is injected via `head`; Blocks' `js` parameter expects a bare
# JavaScript function string, not an HTML <script> element.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, head=custom_js) as demo:
    gr.Markdown(
        """
        # PARA-GEN: Aesthetic Paraphraser
        Enter a sentence below to generate five beautifully rephrased statements.
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Input Sentence",
                placeholder="Type your sentence here...",
                lines=4,
                max_lines=4
            )
            paraphrase_button = gr.Button("Generate Paraphrases")
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Paraphrased Results",
                lines=10,
                interactive=False
            )

    with gr.Accordion("Model Performance Metrics", open=False):
        metrics_output = gr.Markdown()

    # Define button click behavior
    paraphrase_button.click(
        fn=paraphrase_sentence,
        inputs=input_text,
        outputs=output_text
    )
    # Compute metrics on a small test set once at startup and use the result as the
    # accordion's initial content
    test_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming industries.",
        "The weather is sunny and warm today.",
        "He enjoys reading books on machine learning.",
        "The stock market fluctuates daily due to various factors."
    ]
    metrics_output.value = calculate_precision_recall_accuracy(test_sentences)
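
    # Alternative sketch (uses Gradio's Blocks.load event): defer the metric
    # computation until the page loads so app startup is not blocked by the five
    # paraphrase/embedding passes.
    # demo.load(fn=lambda: calculate_precision_recall_accuracy(test_sentences),
    #           outputs=metrics_output)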
# Launch Gradio app
demo.launch(share=False)