File size: 2,073 Bytes
3c2639a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147774f
 
 
 
 
 
 
3c2639a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147774f
3c2639a
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr

from utils import (
    device,
    jina_tokenizer,
    jina_model,
    embeddings_predict_relevance,
    stsb_model,
    stsb_tokenizer,
    ms_model,
    ms_tokenizer,
    cross_encoder_predict_relevance
)

def predict(system_prompt, user_prompt, selected_model):
    """Classify whether *user_prompt* is off-topic with respect to *system_prompt*.

    Args:
        system_prompt: System prompt text entered in the UI.
        user_prompt: User prompt text entered in the UI.
        selected_model: Model identifier chosen in the dropdown. May be
            ``None`` if the user has not selected a model yet.

    Returns:
        A Markdown string summarising the predicted label and the
        off-topic probability, or a Markdown error message when
        *selected_model* is not one of the supported models.
    """
    # Resolve which predict function / model / tokenizer to use before doing any
    # work, so a missing or unknown selection yields a readable message instead
    # of an UnboundLocalError on `predicted_label` below.
    if selected_model == "jinaai/jina-embeddings-v2-small-en":
        predict_fn, model, tokenizer = embeddings_predict_relevance, jina_model, jina_tokenizer
    elif selected_model == "cross-encoder/stsb-roberta-base":
        predict_fn, model, tokenizer = cross_encoder_predict_relevance, stsb_model, stsb_tokenizer
    elif selected_model == "cross-encoder/ms-marco-MiniLM-L-6-v2":
        predict_fn, model, tokenizer = cross_encoder_predict_relevance, ms_model, ms_tokenizer
    else:
        return (
            f"**Error**: unsupported or missing model selection ({selected_model!r}). "
            "Please select a model."
        )

    predicted_label, probabilities = predict_fn(system_prompt, user_prompt, model, tokenizer, device)

    # Column 1 of the first row is treated as the off-topic probability,
    # matching label 1 == "Off-topic" below (assumed from the utils helpers —
    # confirm). float() makes the f-string formatting independent of whether
    # the helper returns a tensor, a NumPy array, or a nested list.
    probability_off_topic = float(probabilities[0][1]) * 100
    label = "Off-topic" if predicted_label == 1 else "On-topic"
    result = f"""
    **Prediction Summary**:

    - **Predicted Label**: {label}
    - **Probability of Off-topic**: {probability_off_topic:.3f}%
    """

    return result

# Build the Gradio UI: two prompt boxes, a model selector, a button, and a
# Markdown area that shows the string returned by `predict`.
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:

    gr.Markdown("# Off-Topic Classification using Fine-tuned Embeddings and Cross-Encoder Models")

    with gr.Row():
        system_prompt = gr.Textbox(label="System Prompt")
        user_prompt = gr.Textbox(label="User Prompt")

    with gr.Row():
        # Pre-select the first model so that clicking "Check Content" without
        # touching the dropdown does not send `None` to `predict`.
        selected_model = gr.Dropdown(
            ["jinaai/jina-embeddings-v2-small-en",
             "cross-encoder/stsb-roberta-base",
             "cross-encoder/ms-marco-MiniLM-L-6-v2"],
            value="jinaai/jina-embeddings-v2-small-en",
            label="Select a model")

    # Button to run the prediction
    get_classification = gr.Button("Check Content")

    output_result = gr.Markdown(label="Classification and Probabilities")

    # On click, pass the three inputs to `predict` and render its Markdown result.
    get_classification.click(
        fn=predict,
        inputs=[system_prompt, user_prompt, selected_model],
        outputs=output_result
    )

if __name__ == "__main__":
    app.launch()