bakhil-aissa committed on
Commit
e54e4ba
·
verified ·
1 Parent(s): f98f467

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +57 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from transformers import AutoTokenizer
7
+ from huggingface_hub import hf_hub_download
8
+
9
+
10
+ import os
11
+
12
+
13
+
14
# Name of the ONNX weights file, both in the Hub repo and in the working directory.
MODEL_FILENAME = "model_f16.onnx"


@st.cache_resource(show_spinner=False)
def _load_tokenizer():
    """Load the ModernBERT tokenizer once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    keeps the tokenizer from being rebuilt on each rerun.
    """
    return AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')


@st.cache_resource(show_spinner=False)
def _load_session(path):
    """Create the CPU-only ONNX Runtime session for ``path`` once per process."""
    return ort.InferenceSession(path, providers=["CPUExecutionProvider"])


tokenizer = _load_tokenizer()

# Resolve the model file: use a copy in the working directory when present,
# otherwise fetch it from the Hugging Face Hub (hf_hub_download returns the
# local path of the downloaded file).
if os.path.exists(MODEL_FILENAME):
    st.write("Model already downloaded.")
    # model_path must be bound on this branch too; the session creation
    # below reads it regardless of which branch ran.
    model_path = MODEL_FILENAME
else:
    st.write("Downloading model...")
    # local_dir_use_symlinks is deprecated and ignored by current
    # huggingface_hub releases, so it is not passed.
    model_path = hf_hub_download(
        repo_id="bakhil-aissa/anti_prompt_injection",
        filename=MODEL_FILENAME,
    )

st.title("Anti Prompt Injection Detection")

# Load the ONNX model
sess = _load_session(model_path)
31
+ # Define the input form
32
def predict(text):
    """Return softmax class probabilities for a single piece of text.

    The text is tokenized (truncated to at most 2048 tokens), run through the
    module-level ONNX session ``sess``, and the ``logits`` output is turned
    into probabilities.

    Args:
        text: The string to classify.

    Returns:
        A float32 NumPy array of probabilities, one row per input text
        (a single row here, since one text is passed).
    """
    enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    logits = sess.run(["logits"], inputs)[0]

    # Upcast before exponentiating. NOTE(review): the logits are presumably
    # float16 (the model file is named model_f16.onnx); np.exp overflows to
    # inf in float16 for values above ~11, which would yield NaN
    # probabilities. The cast is harmless if they are already float32.
    logits = np.asarray(logits, dtype=np.float32)

    # Subtract the per-row maximum so the largest exponent is exp(0) == 1;
    # this leaves the softmax result unchanged but prevents overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)  # shape (1, num_classes)
    return probs
39
+
40
# Input form: a free-text box, a decision threshold, and a trigger button.
st.subheader("Enter your text to check for prompt injection:")
text_input = st.text_area("Text Input", height=200)
confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)

check_clicked = st.button("Check")
if check_clicked:
    if not text_input:
        # Nothing to classify; prompt for input instead of calling the model.
        st.warning("Please enter some text to check.")
    else:
        try:
            with st.spinner("Processing..."):
                # Row 0 is the single submitted text; column 1 is the score
                # this app reports as the jailbreak probability.
                jailbreak_prob = float(predict(text_input)[0][1])
                is_jailbreak = jailbreak_prob >= confidence_threshold

            st.success(f"Is Jailbreak: {is_jailbreak}")
            st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
        except Exception as e:
            # UI boundary: surface any failure to the user rather than crash.
            st.error(f"Error: {e!s}")
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
fastapi==0.116.1
huggingface_hub==0.33.5
# numpy 1.21.5 is older than the minimum NumPy that pandas 2.3.1 declares,
# so the previous pin set could not be resolved by pip.
numpy==1.26.4
onnxruntime==1.22.0
pandas==2.3.1
pydantic==2.11.7
streamlit==1.44.1
transformers==4.53.3