bakhil-aissa commited on
Commit
79efd3c
·
verified ·
1 Parent(s): f9360bd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +72 -39
  2. requirements.txt +5 -9
app.py CHANGED
@@ -1,9 +1,8 @@
1
- import streamlit as st
2
- import pandas as pd
3
  import numpy as np
4
  import onnxruntime as ort
5
  from transformers import AutoTokenizer
6
- from huggingface_hub import hf_hub_download
7
  import os
8
 
9
  # Global variables to store loaded models
@@ -19,10 +18,10 @@ def load_models():
19
 
20
  if sess is None:
21
  if os.path.exists("model_f16.onnx"):
22
- st.write("Model already downloaded.")
23
  model_path = "model_f16.onnx"
24
  else:
25
- st.write("Downloading model...")
26
  model_path = hf_hub_download(
27
  repo_id="bakhil-aissa/anti_prompt_injection",
28
  filename="model_f16.onnx",
@@ -33,41 +32,75 @@ def load_models():
33
 
34
  return tokenizer, sess
35
 
36
- def predict(text):
37
  """Predict function that uses the loaded models"""
38
- enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
39
- inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
40
- logits = sess.run(["logits"], inputs)[0]
41
- exp = np.exp(logits)
42
- probs = exp / exp.sum(axis=1, keepdims=True) # shape (1, num_classes)
43
- return probs
44
-
45
- def main():
46
- st.title("Anti Prompt Injection Detection")
47
-
48
- # Load models when needed
49
- global tokenizer, sess
50
- tokenizer, sess = load_models()
51
 
52
- st.subheader("Enter your text to check for prompt injection:")
53
- text_input = st.text_area("Text Input", height=200)
54
- confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- if st.button("Check"):
57
- if text_input:
58
- try:
59
- with st.spinner("Processing..."):
60
- # Call the predict function
61
- probs = predict(text_input)
62
- jailbreak_prob = float(probs[0][1]) # index into batch
63
- is_jailbreak = jailbreak_prob >= confidence_threshold
64
-
65
- st.success(f"Is Jailbreak: {is_jailbreak}")
66
- st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
67
- except Exception as e:
68
- st.error(f"Error: {str(e)}")
69
- else:
70
- st.warning("Please enter some text to check.")
71
 
72
- # Only define functions, don't execute anything
73
- # Streamlit will automatically run the script when it's ready
 
 
 
 
 
 
1
+ import gradio as gr
 
2
  import numpy as np
3
  import onnxruntime as ort
4
  from transformers import AutoTokenizer
5
+ from huggingface_hub import hf_hub_download
6
  import os
7
 
8
  # Global variables to store loaded models
 
18
 
19
  if sess is None:
20
  if os.path.exists("model_f16.onnx"):
21
+ print("Model already downloaded.")
22
  model_path = "model_f16.onnx"
23
  else:
24
+ print("Downloading model...")
25
  model_path = hf_hub_download(
26
  repo_id="bakhil-aissa/anti_prompt_injection",
27
  filename="model_f16.onnx",
 
32
 
33
  return tokenizer, sess
34
 
35
+ def predict(text, confidence_threshold):
36
  """Predict function that uses the loaded models"""
37
+ if not text.strip():
38
+ return "Please enter some text to check.", 0.0, False
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ try:
41
+ # Load models if not already loaded
42
+ load_models()
43
+
44
+ # Make prediction
45
+ enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
46
+ inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
47
+ logits = sess.run(["logits"], inputs)[0]
48
+ exp = np.exp(logits)
49
+ probs = exp / exp.sum(axis=1, keepdims=True)
50
+
51
+ jailbreak_prob = float(probs[0][1])
52
+ is_jailbreak = jailbreak_prob >= confidence_threshold
53
+
54
+ result_text = f"Is Jailbreak: {is_jailbreak}"
55
+ return result_text, jailbreak_prob, is_jailbreak
56
+
57
+ except Exception as e:
58
+ return f"Error: {str(e)}", 0.0, False
59
+
60
+ # Create Gradio interface
61
+ def create_interface():
62
+ with gr.Blocks(title="Anti Prompt Injection Detection") as demo:
63
+ gr.Markdown("# 🚫 Anti Prompt Injection Detection")
64
+ gr.Markdown("Enter your text to check for prompt injection attempts.")
65
+
66
+ with gr.Row():
67
+ with gr.Column():
68
+ text_input = gr.Textbox(
69
+ label="Text Input",
70
+ placeholder="Enter text to analyze...",
71
+ lines=5,
72
+ max_lines=10
73
+ )
74
+ confidence_threshold = gr.Slider(
75
+ minimum=0.0,
76
+ maximum=1.0,
77
+ value=0.5,
78
+ step=0.01,
79
+ label="Confidence Threshold"
80
+ )
81
+ check_button = gr.Button("Check Text", variant="primary")
82
+
83
+ with gr.Column():
84
+ result_text = gr.Textbox(label="Result", interactive=False)
85
+ probability = gr.Number(label="Jailbreak Probability", precision=4)
86
+ is_jailbreak = gr.Checkbox(label="Is Jailbreak", interactive=False)
87
+
88
+ # Set up the prediction
89
+ check_button.click(
90
+ fn=predict,
91
+ inputs=[text_input, confidence_threshold],
92
+ outputs=[result_text, probability, is_jailbreak]
93
+ )
94
+
95
+ gr.Markdown("---")
96
+ gr.Markdown("**How it works:** This tool analyzes text to detect potential prompt injection attempts that could bypass AI safety measures.")
97
 
98
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ # Create and launch the interface
101
+ if __name__ == "__main__":
102
+ demo = create_interface()
103
+ demo.launch()
104
+ else:
105
+ # For Hugging Face Spaces
106
+ demo = create_interface()
requirements.txt CHANGED
@@ -1,9 +1,5 @@
1
- fastapi
2
- huggingface_hub
3
- numpy
4
- onnxruntime
5
- pandas
6
- pydantic
7
- streamlit
8
- transformers
9
- torch
 
1
+ gradio>=4.0.0
2
+ transformers>=4.30.0
3
+ onnxruntime>=1.15.0
4
+ numpy>=1.21.0
5
+ huggingface_hub>=0.16.0