# off-topic-demo/app.py: Gradio demo for off-topic prompt classification

import gradio as gr

from utils import (
    device,
    jina_tokenizer,
    jina_model,
    embeddings_predict_relevance,
    stsb_model,
    stsb_tokenizer,
    ms_model,
    ms_tokenizer,
    cross_encoder_predict_relevance,
)

def predict(system_prompt, user_prompt, selected_model):
    # Route the prompt pair to the model chosen in the dropdown.
    if selected_model == "jinaai/jina-embeddings-v2-small-en":
        predicted_label, probabilities = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
    elif selected_model == "cross-encoder/stsb-roberta-base":
        predicted_label, probabilities = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
    elif selected_model == "cross-encoder/ms-marco-MiniLM-L-6-v2":
        predicted_label, probabilities = cross_encoder_predict_relevance(system_prompt, user_prompt, ms_model, ms_tokenizer, device)
    else:
        # Guard against an empty selection so the variables below are never unbound.
        return "**Error**: please select a model from the dropdown."

    # probabilities[0][1] is the score of the off-topic class (label 1).
    probability_off_topic = probabilities[0][1] * 100
    label = "Off-topic" if predicted_label == 1 else "On-topic"
    result = f"""
    **Prediction Summary**:
    - **Predicted Label**: {label}
    - **Probability of Off-topic**: {probability_off_topic:.3f}%
    """
    return result
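
# Note: both utils helpers are assumed to return (predicted_label, probabilities),
# with probabilities behaving like a 1x2 row of class scores, i.e.
# probabilities[0] ~ [p_on_topic, p_off_topic]. This is inferred from the indexing
# above, not from utils itself. A direct call (prompts are illustrative only)
# would look roughly like:
#
#   label, probs = cross_encoder_predict_relevance(
#       "You are a cooking assistant.",     # system prompt
#       "How do I rotate my car tyres?",    # off-topic user prompt
#       ms_model, ms_tokenizer, device,
#   )
#   # probs[0][1] -> score of the off-topic class (label 1)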


with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:
    gr.Markdown("# Off-Topic Classification using Fine-tuned Embeddings and Cross-Encoder Models")

    with gr.Row():
        system_prompt = gr.Textbox(label="System Prompt")
        user_prompt = gr.Textbox(label="User Prompt")

    with gr.Row():
        selected_model = gr.Dropdown(
            ["jinaai/jina-embeddings-v2-small-en",
             "cross-encoder/stsb-roberta-base",
             "cross-encoder/ms-marco-MiniLM-L-6-v2"],
            label="Select a model",
        )

    # Button to run the prediction
    get_classification = gr.Button("Check Content")
    output_result = gr.Markdown(label="Classification and Probabilities")

    get_classification.click(
        fn=predict,
        inputs=[system_prompt, user_prompt, selected_model],
        outputs=output_result,
    )


if __name__ == "__main__":
    app.launch()
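
# Deployment note: Gradio serves on localhost by default. If the demo must be
# reachable from outside (e.g. when run inside a container), a call along the
# lines of the one below is a common option; 7860 is just Gradio's default port
# and may need adjusting for your environment.
#
#   app.launch(server_name="0.0.0.0", server_port=7860)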