Spaces:
Running
on
Zero
Running
on
Zero
johnsonhung906
committed on
Commit
·
ef017a0
1
Parent(s):
d62afec
support llama, granite
Browse files
app.py
CHANGED
@@ -3,27 +3,45 @@ import gradio as gr
|
|
3 |
from utils import open_config, create_model
|
4 |
from detector.attn import AttentionDetector
|
5 |
|
6 |
-
#
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
@spaces.GPU(duration=30)
|
14 |
-
def respond(message, threshold):
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
#
|
19 |
-
|
20 |
|
|
|
|
|
|
|
|
|
21 |
if detect_result[0]:
|
22 |
response = "Prompt injection detected!"
|
23 |
else:
|
24 |
-
|
|
|
25 |
|
26 |
-
#
|
27 |
response += f"\n\nFocus score: {detect_result[1]['focus_score']}"
|
28 |
return response
|
29 |
|
@@ -44,24 +62,25 @@ benign_prompts = [
|
|
44 |
# Define Gradio interface components
|
45 |
message_input = gr.Textbox(placeholder="Enter your message here...", label="Your Message")
|
46 |
threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Threshold")
|
|
|
47 |
response_output = gr.Textbox(label="Response")
|
48 |
|
49 |
-
# Gradio interface
|
50 |
with gr.Interface(
|
51 |
fn=respond,
|
52 |
-
inputs=[message_input, threshold_slider],
|
53 |
outputs=response_output,
|
54 |
-
title="Attention Tracker
|
55 |
) as demo:
|
56 |
with gr.Tab("Benign Prompts"):
|
57 |
gr.Examples(
|
58 |
-
benign_prompts,
|
59 |
-
inputs=[message_input], #
|
60 |
)
|
61 |
with gr.Tab("Malicious Prompts (Prompt Injection Attack)"):
|
62 |
gr.Examples(
|
63 |
-
injection_prompts,
|
64 |
-
inputs=[message_input],
|
65 |
)
|
66 |
gr.Markdown(
|
67 |
"### This website is developed and maintained by [Kuo-Han Hung](https://khhung-906.github.io/)"
|
|
|
3 |
from utils import open_config, create_model
|
4 |
from detector.attn import AttentionDetector
|
5 |
|
6 |
+
# Define model configuration paths
|
7 |
+
model_configs = {
|
8 |
+
"granite3_8b": "./configs/model_configs/granite3_8b-attn_config.json",
|
9 |
+
"llama3_8b": "./configs/model_configs/llama3_8b-attn_config.json",
|
10 |
+
"qwen2_1.5b": "./configs/model_configs/qwen2-attn_config.json",
|
11 |
+
}
|
12 |
|
13 |
+
# Load all models and create their corresponding detectors
|
14 |
+
models = {}
|
15 |
+
detectors = {}
|
16 |
+
for name, config_path in model_configs.items():
|
17 |
+
config = open_config(config_path=config_path)
|
18 |
+
model_instance = create_model(config=config)
|
19 |
+
models[name] = model_instance
|
20 |
+
detectors[name] = AttentionDetector(model_instance)
|
21 |
|
22 |
@spaces.GPU(duration=30)
|
23 |
+
def respond(message, threshold, model_name):
|
24 |
+
"""
|
25 |
+
Run the prompt injection detection and inference using the selected model.
|
26 |
+
"""
|
27 |
+
# Select the model and its detector based on the user's choice
|
28 |
+
selected_detector = detectors[model_name]
|
29 |
+
selected_model = models[model_name]
|
30 |
|
31 |
+
# Set the detection threshold
|
32 |
+
selected_detector.threshold = threshold
|
33 |
|
34 |
+
# Perform prompt injection detection
|
35 |
+
detect_result = selected_detector.detect(message)
|
36 |
+
|
37 |
+
# If injection is detected, return a warning; otherwise, perform inference.
|
38 |
if detect_result[0]:
|
39 |
response = "Prompt injection detected!"
|
40 |
else:
|
41 |
+
# Unpack the response from inference (assuming the first element is the text)
|
42 |
+
response, _, _, _, _, _ = selected_model.inference("You are a friendly Chatbot.", message)
|
43 |
|
44 |
+
# Append the focus score to the response.
|
45 |
response += f"\n\nFocus score: {detect_result[1]['focus_score']}"
|
46 |
return response
|
47 |
|
|
|
62 |
# Define Gradio interface components
|
63 |
message_input = gr.Textbox(placeholder="Enter your message here...", label="Your Message")
|
64 |
threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Threshold")
|
65 |
+
model_selector = gr.Radio(choices=list(model_configs.keys()), value="llama3_8b", label="Select Model")
|
66 |
response_output = gr.Textbox(label="Response")
|
67 |
|
68 |
+
# Build the Gradio interface (using the Blocks API via Interface)
|
69 |
with gr.Interface(
|
70 |
fn=respond,
|
71 |
+
inputs=[message_input, threshold_slider, model_selector],
|
72 |
outputs=response_output,
|
73 |
+
title="Attention Tracker"
|
74 |
) as demo:
|
75 |
with gr.Tab("Benign Prompts"):
|
76 |
gr.Examples(
|
77 |
+
examples=benign_prompts,
|
78 |
+
inputs=[message_input], # Only the message input is prefilled by these examples
|
79 |
)
|
80 |
with gr.Tab("Malicious Prompts (Prompt Injection Attack)"):
|
81 |
gr.Examples(
|
82 |
+
examples=injection_prompts,
|
83 |
+
inputs=[message_input],
|
84 |
)
|
85 |
gr.Markdown(
|
86 |
"### This website is developed and maintained by [Kuo-Han Hung](https://khhung-906.github.io/)"
|
configs/model_configs/granite3_8b-attn_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_info": {
|
3 |
+
"provider": "attn-hf",
|
4 |
+
"name": "granite3-8b-attn",
|
5 |
+
"model_id": "ibm-granite/granite-3.1-8b-instruct"
|
6 |
+
},
|
7 |
+
"params": {
|
8 |
+
"temperature": 0.1,
|
9 |
+
"max_output_tokens": 32,
|
10 |
+
"important_heads": [[6, 9], [7, 20], [8, 1], [8, 13], [8, 14], [8, 15], [10, 2], [10, 3], [10, 6], [10, 21], [11, 4], [11, 30], [11, 31], [12, 2], [12, 28], [13, 8], [13, 9], [13, 12], [14, 15], [14, 16], [14, 19], [14, 27], [15, 6], [15, 7], [15, 20], [15, 23], [16, 12], [16, 14], [16, 16], [17, 7], [17, 11], [17, 15], [17, 19], [17, 21], [17, 25], [17, 26], [18, 9], [18, 17], [18, 20], [18, 28], [19, 1]]
|
11 |
+
}
|
12 |
+
}
|
configs/model_configs/llama3_8b-attn_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_info": {
|
3 |
+
"provider": "attn-hf",
|
4 |
+
"name": "llama3-8b-attn-tensor",
|
5 |
+
"model_id": "meta-llama/Meta-Llama-3-8B-Instruct"
|
6 |
+
},
|
7 |
+
"params": {
|
8 |
+
"temperature": 0.1,
|
9 |
+
"max_output_tokens": 32,
|
10 |
+
"important_heads": [[5, 18], [7, 12], [9, 29], [17, 2]]
|
11 |
+
}
|
12 |
+
}
|
configs/model_configs/qwen2-attn_config.json
CHANGED
@@ -7,6 +7,6 @@
|
|
7 |
"params": {
|
8 |
"temperature": 0.1,
|
9 |
"max_output_tokens": 32,
|
10 |
-
"important_heads": [[11, 8], [12, 8], [
|
11 |
}
|
12 |
}
|
|
|
7 |
"params": {
|
8 |
"temperature": 0.1,
|
9 |
"max_output_tokens": 32,
|
10 |
+
"important_heads": [[10, 6], [11, 0], [11, 2], [11, 8], [11, 9], [11, 11], [12, 8], [13, 10], [14, 8], [15, 7], [15, 11], [17, 0], [18, 9], [19, 7]]
|
11 |
}
|
12 |
}
|
models/attn_model.py
CHANGED
@@ -67,12 +67,16 @@ class AttentionModel(Model):
|
|
67 |
input_tokens = self.tokenizer.convert_ids_to_tokens(
|
68 |
model_inputs['input_ids'][0])
|
69 |
|
70 |
-
if "qwen
|
71 |
data_range = ((3, 3+instruction_len), (-5-data_len, -5))
|
72 |
-
elif "phi3
|
73 |
data_range = ((1, 1+instruction_len), (-2-data_len, -2))
|
74 |
-
elif "
|
75 |
data_range = ((5, 5+instruction_len), (-5-data_len, -5))
|
|
|
|
|
|
|
|
|
76 |
else:
|
77 |
raise NotImplementedError
|
78 |
|
|
|
67 |
input_tokens = self.tokenizer.convert_ids_to_tokens(
|
68 |
model_inputs['input_ids'][0])
|
69 |
|
70 |
+
if "qwen" in self.name:
|
71 |
data_range = ((3, 3+instruction_len), (-5-data_len, -5))
|
72 |
+
elif "phi3" in self.name:
|
73 |
data_range = ((1, 1+instruction_len), (-2-data_len, -2))
|
74 |
+
elif "llama3-8b" in self.name:
|
75 |
data_range = ((5, 5+instruction_len), (-5-data_len, -5))
|
76 |
+
elif "mistral-7b" in self.name:
|
77 |
+
data_range = ((3, 3+instruction_len), (-1-data_len, -1))
|
78 |
+
elif "granite3-8b" in self.name:
|
79 |
+
data_range = ((3, 3+instruction_len), (-5-data_len, -5))
|
80 |
else:
|
81 |
raise NotImplementedError
|
82 |
|