# Para-gen / app.py
# Hugging Face file view: uploaded by gaurav0026, commit "upload model" (4a9b381, verified), 4.89 kB.
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModel, AutoTokenizer
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import gradio as gr
from collections import Counter
import pandas as pd
# Choose the compute device once; both models below are moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paraphrase generator: T5 weights from the paraphraser checkpoint, tokenizer
# from the 't5-base' checkpoint.
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser').to(device)

# Sentence-embedding encoder used for the semantic similarity calculation.
embed_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
embed_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device)
# Function to get sentence embeddings
def get_sentence_embedding(sentence):
    """Return mean-pooled sentence embeddings from the embedding model.

    Args:
        sentence: A string (or a list of strings) to embed.

    Returns:
        A tensor on ``device`` with one embedding row per input sentence.
    """
    # truncation=True keeps over-long inputs within the tokenizer's configured
    # maximum length instead of letting the encoder fail on them.
    inputs = embed_tokenizer(
        sentence, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        token_embeddings = embed_model(**inputs).last_hidden_state
        # Weight every token by its attention mask so that padding positions
        # (present when several sentences of different length are batched)
        # do not dilute the average.
        mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        # clamp guards against a division by zero for an all-padding row.
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        embeddings = summed / token_counts
    return embeddings
# Paraphrasing function
def paraphrase_sentence(sentence):
    """Generate five sampled paraphrases of ``sentence``.

    Args:
        sentence: The input text to paraphrase.

    Returns:
        A single string with one paraphrase per line, each prefixed with its
        1-based number (``"1. ..."``, ``"2. ..."``, ...).
    """
    # Updated prompt for statement-like output
    prompt = "rephrase as a statement: " + sentence
    encoded = tokenizer.encode_plus(prompt, padding=False, return_tensors="pt")

    generation_kwargs = dict(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        do_sample=True,
        max_length=128,
        top_k=40,  # Reduced top_k for less randomness
        top_p=0.85,  # Reduced top_p for focused sampling
        early_stopping=True,
        num_return_sequences=5,  # Generate 5 paraphrases
    )
    generated = model.generate(**generation_kwargs)

    # Decode each generated sequence and prefix it with its position.
    numbered_lines = []
    for position, token_ids in enumerate(generated, 1):
        decoded = tokenizer.decode(
            token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        numbered_lines.append(f"{position}. {decoded}")
    return "\n".join(numbered_lines)
# Precision, Recall, and Overall Accuracy Calculation
def calculate_precision_recall_accuracy(sentences):
    """Paraphrase each sentence and print averaged quality metrics.

    For every generated paraphrase this computes:
      * the cosine similarity between the embeddings of the original sentence
        and of the paraphrase (reported as "accuracy"), and
      * token-overlap precision and recall on lower-cased, whitespace-split
        tokens.

    The three averages over all paraphrases are printed as percentages.

    Args:
        sentences: An iterable of sentences to evaluate.

    Returns:
        None. Results are written to standard output. When no paraphrase was
        evaluated (for example, ``sentences`` is empty) a notice is printed
        instead of the metrics.
    """
    total_similarity = 0
    total_precision = 0
    total_recall = 0
    paraphrase_count = 0

    for sentence in sentences:
        paraphrase_lines = paraphrase_sentence(sentence).split("\n")

        # Per-sentence values that do not change across its paraphrases.
        original_embedding = get_sentence_embedding(sentence).cpu()
        original_tokens = Counter(sentence.lower().split())
        original_length = sum(original_tokens.values())

        for line in paraphrase_lines:
            # Remove the "N. " numbering before evaluation, but only when it
            # is really there: a paraphrase that itself contains a newline
            # yields continuation lines without a number.
            number, separator, remainder = line.partition(". ")
            paraphrase = remainder if separator and number.isdigit() else line

            paraphrase_embedding = get_sentence_embedding(paraphrase)
            similarity = cosine_similarity(
                original_embedding, paraphrase_embedding.cpu()
            )[0][0]
            total_similarity += similarity

            # Calculate precision and recall based on token overlap
            paraphrase_tokens = Counter(paraphrase.lower().split())
            overlap = sum((paraphrase_tokens & original_tokens).values())
            paraphrase_length = sum(paraphrase_tokens.values())
            total_precision += overlap / paraphrase_length if paraphrase_length else 0
            total_recall += overlap / original_length if original_length else 0
            paraphrase_count += 1

    # Nothing to average: avoid dividing by zero.
    if paraphrase_count == 0:
        print("No paraphrases were evaluated; metrics are undefined.")
        return

    # Calculate averages for accuracy, precision, and recall
    overall_accuracy = (total_similarity / paraphrase_count) * 100
    avg_precision = (total_precision / paraphrase_count) * 100
    avg_recall = (total_recall / paraphrase_count) * 100
    print(f"Overall Model Accuracy (Semantic Similarity): {overall_accuracy:.2f}%")
    print(f"Average Precision (Token Overlap): {avg_precision:.2f}%")
    print(f"Average Recall (Token Overlap): {avg_recall:.2f}%")
# Define Gradio UI
# A single text-in / text-out interface whose handler is paraphrase_sentence;
# the returned text is the numbered, newline-separated paraphrase list.
iface = gr.Interface(
    fn=paraphrase_sentence,
    inputs="text",
    outputs="text",
    title="PARA-GEN (T5 Paraphraser)",
    description="Enter a sentence, and the model will generate five numbered paraphrases in statement form."
)
# List of test sentences to evaluate metrics
test_sentences = [
    "The quick brown fox jumps over the lazy dog.",
    "Artificial intelligence is transforming industries.",
    "The weather is sunny and warm today.",
    "He enjoys reading books on machine learning.",
    "The stock market fluctuates daily due to various factors."
]
# Calculate overall accuracy, precision, and recall for the list of test sentences
# NOTE: this runs at start-up, before the UI is launched, and prints the
# metrics to standard output.
calculate_precision_recall_accuracy(test_sentences)
# Launch Gradio app (Gradio UI will not show metrics)
iface.launch(share=False)