Spaces:
Runtime error
Runtime error
File size: 2,109 Bytes
8c2d966 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# Source of the Space entry point, kept as a string so the script below can
# write it to app.py. The code inside must keep its real indentation,
# otherwise the generated app.py is not valid Python.
app_code = """\
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the fine-tuned SBERT classifier and its tokenizer from the Hugging Face Hub.
model_name = "Steph974/SBERT-FineTuned-Classifier"  # Your uploaded model
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Use the GPU when one is available, otherwise the CPU; inference mode only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def predict_similarity(sentence1, sentence2):
    \"\"\"
    Score a sentence pair with the fine-tuned classifier.

    Args:
        sentence1: First sentence of the pair.
        sentence2: Second sentence of the pair.

    Returns:
        A dict mapping each outcome to its probability in [0, 1] (the format
        gr.Label expects): label index 1 is reported as "Same Class
        Probability" and label index 0 as "Different Class Probability".
    \"\"\"
    # Encode the pair as one sequence. A single example needs no padding, so
    # only truncate to the 512-token limit instead of padding every request to it.
    inputs = tokenizer(sentence1, sentence2, return_tensors="pt", truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}  # Move tensors to the model device

    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over the label dimension; row 0 is the only pair in the batch.
    # tolist() yields plain Python floats, which gr.Label can serialize.
    probabilities = F.softmax(outputs.logits, dim=1)[0].tolist()
    proba_same = probabilities[1]  # label 1: sentences are in the same class
    proba_diff = probabilities[0]  # label 0: sentences are in different classes

    # gr.Label expects confidences in [0, 1] and renders them as percentages
    # itself, so the values are not pre-multiplied by 100.
    return {
        "Same Class Probability": proba_same,
        "Different Class Probability": proba_diff,
    }


# Gradio UI
# NOTE: theme="huggingface" was dropped; it is a legacy theme name that is not
# a built-in theme in recent Gradio releases (confirm against the pinned version).
interface = gr.Interface(
    fn=predict_similarity,
    inputs=[
        gr.Textbox(label="Sentence 1", placeholder="Enter the first sentence..."),
        gr.Textbox(label="Sentence 2", placeholder="Enter the second sentence..."),
    ],
    outputs=gr.Label(label="Prediction Probabilities"),
    title="SBERT Sentence-Pair Similarity",
    description="Enter two sentences and see how similar they are according to the fine-tuned SBERT model.",
)

# Launch the Gradio app when run as a script (how Spaces starts app.py).
if __name__ == "__main__":
    interface.launch()
"""
# Save to app.py, the file the Space runs. The encoding is explicit so the
# written file does not depend on the platform's locale default.
with open("app.py", "w", encoding="utf-8") as f:
    f.write(app_code)
|