Shing Yee commited on
Commit
a402b79
Β·
unverified Β·
1 Parent(s): 147774f

Update application

Browse files
Files changed (1) hide show
  1. app.py +20 -23
app.py CHANGED
@@ -12,21 +12,25 @@ from utils import (
12
  cross_encoder_predict_relevance
13
  )
14
 
15
- def predict(system_prompt, user_prompt, selected_model):
16
- if selected_model == "jinaai/jina-embeddings-v2-small-en":
17
- predicted_label, probabilities = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
18
- elif selected_model == "cross-encoder/stsb-roberta-base":
19
- predicted_label, probabilities = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
20
- elif selected_model == "cross-encoder/ms-marco-MiniLM-L-6-v2":
21
- predicted_label, probabilities = cross_encoder_predict_relevance(system_prompt, user_prompt, ms_model, ms_tokenizer, device)
22
-
23
- probability_off_topic = probabilities[0][1] * 100
24
- label = "Off-topic" if predicted_label==1 else "On-topic"
25
  result = f"""
26
- **Prediction Summary**:
 
 
 
 
 
 
 
 
27
 
28
- - **Predicted Label**: {label}
29
- - **Probability of Off-topic**: {probability_off_topic:.3f}%
 
30
  """
31
 
32
  return result
@@ -36,15 +40,8 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:
36
  gr.Markdown("# Off-Topic Classification using Fine-tuned Embeddings and Cross-Encoder Models")
37
 
38
  with gr.Row():
39
- system_prompt = gr.Textbox(label="System Prompt")
40
- user_prompt = gr.Textbox(label="User Prompt")
41
-
42
- with gr.Row():
43
- selected_model = gr.Dropdown(
44
- ["jinaai/jina-embeddings-v2-small-en",
45
- "cross-encoder/stsb-roberta-base",
46
- "cross-encoder/ms-marco-MiniLM-L-6-v2"],
47
- label="Select a model")
48
 
49
  # Button to run the prediction
50
  get_classfication = gr.Button("Check Content")
@@ -53,7 +50,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:
53
 
54
  get_classfication.click(
55
  fn=predict,
56
- inputs=[system_prompt, user_prompt, selected_model],
57
  outputs=output_result
58
  )
59
 
 
12
  cross_encoder_predict_relevance
13
  )
14
 
15
+ def predict(system_prompt, user_prompt):
16
+ predicted_label_jina, probabilities_jina = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
17
+ predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
18
+ predicted_label_ms, probabilities_ms = cross_encoder_predict_relevance(system_prompt, user_prompt, ms_model, ms_tokenizer, device)
19
+
 
 
 
 
 
20
  result = f"""
21
+ **Prediction Summary**
22
+
23
+ **1. Model: jinaai/jina-embeddings-v2-small-en**
24
+ - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_jina==1 else "🟩 On-topic"}
25
+ - **Probability of being off-topic**: {probabilities_jina[0][1]:.2%}
26
+
27
+ **2. Model: cross-encoder/stsb-roberta-base**
28
+ - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_stsb==1 else "🟩 On-topic"}
29
+ - **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
30
 
31
+ **3. Model: cross-encoder/ms-marco-MiniLM-L-6-v2**
32
+ - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_ms==1 else "🟩 On-topic"}
33
+ - **Probability of being off-topic**: {probabilities_ms[0][1]:.2%}
34
  """
35
 
36
  return result
 
40
  gr.Markdown("# Off-Topic Classification using Fine-tuned Embeddings and Cross-Encoder Models")
41
 
42
  with gr.Row():
43
+ system_prompt = gr.TextArea(label="System Prompt", lines=5)
44
+ user_prompt = gr.TextArea(label="User Prompt", lines=5)
 
 
 
 
 
 
 
45
 
46
  # Button to run the prediction
47
  get_classfication = gr.Button("Check Content")
 
50
 
51
  get_classfication.click(
52
  fn=predict,
53
+ inputs=[system_prompt, user_prompt],
54
  outputs=output_result
55
  )
56