Spaces:
Running
Running
File size: 1,759 Bytes
3c2639a a402b79 147774f a402b79 147774f 3c2639a ce41854 3c2639a a402b79 3c2639a 147774f 3c2639a a402b79 3c2639a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import gradio as gr
from utils import (
device,
jina_tokenizer,
jina_model,
embeddings_predict_relevance,
stsb_model,
stsb_tokenizer,
cross_encoder_predict_relevance
)
def predict(system_prompt, user_prompt):
    """Classify whether *user_prompt* is off-topic relative to *system_prompt*.

    Runs both classifiers (jina embeddings model and stsb cross-encoder) and
    returns a Markdown-formatted summary string with each model's label and
    its probability of the off-topic class.
    """
    predicted_label_jina, probabilities_jina = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
    predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
    # Label 1 == off-topic; probabilities[0][1] is the off-topic class probability.
    # NOTE: restored the 🟥/🟩 emoji, which had been mangled by an encoding
    # round-trip (ISO-8859-7 mojibake: "π₯"/"π©").
    result = f"""
    **Prediction Summary**
    **1. Model: jinaai/jina-embeddings-v2-small-en**
    - **Prediction**: {"🟥 Off-topic" if predicted_label_jina == 1 else "🟩 On-topic"}
    - **Probability of being off-topic**: {probabilities_jina[0][1]:.2%}
    **2. Model: cross-encoder/stsb-roberta-base**
    - **Prediction**: {"🟥 Off-topic" if predicted_label_stsb == 1 else "🟩 On-topic"}
    - **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
    """
    return result
# Build the Gradio UI: two text areas (system + user prompt), a button that
# triggers `predict`, and a Markdown area for the formatted result.
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:
    gr.Markdown("# Off-Topic Detection")
    gr.Markdown("This is a CPU-only demo for `govtech/jina-embeddings-v2-small-en-off-topic` and `govtech/stsb-roberta-base-off-topic`")
    with gr.Row():
        system_prompt = gr.TextArea(label="System Prompt", lines=5)
        user_prompt = gr.TextArea(label="User Prompt", lines=5)
    # Button to run the prediction (fixed typo: get_classfication -> get_classification)
    get_classification = gr.Button("Check Content")
    output_result = gr.Markdown(label="Classification and Probabilities")
    get_classification.click(
        fn=predict,
        inputs=[system_prompt, user_prompt],
        outputs=output_result
    )

if __name__ == "__main__":
    app.launch()
|