johnsonhung906 committed on
Commit
ef017a0
·
1 Parent(s): d62afec

support llama, granite

Browse files
app.py CHANGED
@@ -3,27 +3,45 @@ import gradio as gr
3
  from utils import open_config, create_model
4
  from detector.attn import AttentionDetector
5
 
6
# Load the single (qwen2) model configuration and build its injection detector.
# Fix: the path literal had an f-string prefix with no placeholders — plain string.
model_config_path = "./configs/model_configs/qwen2-attn_config.json"
model_config = open_config(config_path=model_config_path)
model = create_model(config=model_config)

detector = AttentionDetector(model)
 
 
 
 
 
 
 
12
 
13
@spaces.GPU(duration=30)
def respond(message, threshold):
    """Screen *message* for prompt injection and answer it when clean.

    The detector's sensitivity is taken from *threshold* before detection;
    the returned string always ends with the detector's focus score.
    """
    detector.threshold = threshold
    detection = detector.detect(message)

    if not detection[0]:
        # Clean prompt: let the chat model answer (first tuple element is the text).
        response, _, _, _, _, _ = model.inference("You are a friendly Chatbot.", message)
    else:
        response = "Prompt injection detected!"

    return response + f"\n\nFocus score: {detection[1]['focus_score']}"
29
 
@@ -44,24 +62,25 @@ benign_prompts = [
44
# Gradio interface components.
message_input = gr.Textbox(placeholder="Enter your message here...", label="Your Message")
threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Threshold")
response_output = gr.Textbox(label="Response")

# Wire the interface and attach the example tabs.
with gr.Interface(
    fn=respond,
    inputs=[message_input, threshold_slider],
    outputs=response_output,
    title="Attention Tracker - Qwen-1.5b-instruct",
) as demo:
    with gr.Tab("Benign Prompts"):
        gr.Examples(
            benign_prompts,
            inputs=[message_input],
        )
    with gr.Tab("Malicious Prompts (Prompt Injection Attack)"):
        gr.Examples(
            injection_prompts,
            inputs=[message_input],
        )
    gr.Markdown(
        "### This website is developed and maintained by [Kuo-Han Hung](https://khhung-906.github.io/)"
    )
 
3
  from utils import open_config, create_model
4
  from detector.attn import AttentionDetector
5
 
6
# Paths to the attention-detector configuration files, keyed by display name.
model_configs = {
    "granite3_8b": "./configs/model_configs/granite3_8b-attn_config.json",
    "llama3_8b": "./configs/model_configs/llama3_8b-attn_config.json",
    "qwen2_1.5b": "./configs/model_configs/qwen2-attn_config.json",
}

# Instantiate every model once at startup and pair it with its own detector.
models = {}
detectors = {}
for model_name, path in model_configs.items():
    loaded = create_model(config=open_config(config_path=path))
    models[model_name] = loaded
    detectors[model_name] = AttentionDetector(loaded)
 
22
@spaces.GPU(duration=30)
def respond(message, threshold, model_name):
    """Run prompt-injection detection and, if the prompt is clean, answer it.

    Args:
        message: The user prompt to screen and (possibly) answer.
        threshold: Focus-score threshold applied to the attention detector.
        model_name: Key into the pre-loaded ``models``/``detectors`` dicts.

    Returns:
        A string: either an injection warning or the model's reply, with the
        detector's focus score appended.
    """
    # Guard against an unexpected model key: return a readable message
    # instead of surfacing a raw KeyError traceback in the UI.
    if model_name not in detectors:
        return f"Unknown model: {model_name!r}"

    selected_detector = detectors[model_name]
    selected_model = models[model_name]

    # Apply the user-chosen sensitivity before detecting.
    selected_detector.threshold = threshold

    # Perform prompt injection detection.
    detect_result = selected_detector.detect(message)

    if detect_result[0]:
        response = "Prompt injection detected!"
    else:
        # inference() returns the generated text first; the rest is ignored.
        response, _, _, _, _, _ = selected_model.inference("You are a friendly Chatbot.", message)

    # Append the focus score to the response.
    response += f"\n\nFocus score: {detect_result[1]['focus_score']}"
    return response
47
 
 
62
# Gradio widgets shared by the interface.
message_input = gr.Textbox(placeholder="Enter your message here...", label="Your Message")
threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Threshold")
model_selector = gr.Radio(choices=list(model_configs.keys()), value="llama3_8b", label="Select Model")
response_output = gr.Textbox(label="Response")

# Assemble the interface and attach the example tabs.
with gr.Interface(
    fn=respond,
    inputs=[message_input, threshold_slider, model_selector],
    outputs=response_output,
    title="Attention Tracker",
) as demo:
    with gr.Tab("Benign Prompts"):
        gr.Examples(
            examples=benign_prompts,
            inputs=[message_input],  # examples prefill only the message box
        )
    with gr.Tab("Malicious Prompts (Prompt Injection Attack)"):
        gr.Examples(
            examples=injection_prompts,
            inputs=[message_input],
        )
    gr.Markdown(
        "### This website is developed and maintained by [Kuo-Han Hung](https://khhung-906.github.io/)"
    )
configs/model_configs/granite3_8b-attn_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_info": {
3
+ "provider": "attn-hf",
4
+ "name": "granite3-8b-attn",
5
+ "model_id": "ibm-granite/granite-3.1-8b-instruct"
6
+ },
7
+ "params": {
8
+ "temperature": 0.1,
9
+ "max_output_tokens": 32,
10
+ "important_heads": [[6, 9], [7, 20], [8, 1], [8, 13], [8, 14], [8, 15], [10, 2], [10, 3], [10, 6], [10, 21], [11, 4], [11, 30], [11, 31], [12, 2], [12, 28], [13, 8], [13, 9], [13, 12], [14, 15], [14, 16], [14, 19], [14, 27], [15, 6], [15, 7], [15, 20], [15, 23], [16, 12], [16, 14], [16, 16], [17, 7], [17, 11], [17, 15], [17, 19], [17, 21], [17, 25], [17, 26], [18, 9], [18, 17], [18, 20], [18, 28], [19, 1]]
11
+ }
12
+ }
configs/model_configs/llama3_8b-attn_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_info": {
3
+ "provider": "attn-hf",
4
+ "name": "llama3-8b-attn-tensor",
5
+ "model_id": "meta-llama/Meta-Llama-3-8B-Instruct"
6
+ },
7
+ "params": {
8
+ "temperature": 0.1,
9
+ "max_output_tokens": 32,
10
+ "important_heads": [[5, 18], [7, 12], [9, 29], [17, 2]]
11
+ }
12
+ }
configs/model_configs/qwen2-attn_config.json CHANGED
@@ -7,6 +7,6 @@
7
  "params": {
8
  "temperature": 0.1,
9
  "max_output_tokens": 32,
10
- "important_heads": [[11, 8], [12, 8], [14, 10], [19, 7]]
11
  }
12
  }
 
7
  "params": {
8
  "temperature": 0.1,
9
  "max_output_tokens": 32,
10
+ "important_heads": [[10, 6], [11, 0], [11, 2], [11, 8], [11, 9], [11, 11], [12, 8], [13, 10], [14, 8], [15, 7], [15, 11], [17, 0], [18, 9], [19, 7]]
11
  }
12
  }
models/attn_model.py CHANGED
@@ -67,12 +67,16 @@ class AttentionModel(Model):
67
  input_tokens = self.tokenizer.convert_ids_to_tokens(
68
  model_inputs['input_ids'][0])
69
 
70
- if "qwen-attn" in self.name:
71
  data_range = ((3, 3+instruction_len), (-5-data_len, -5))
72
- elif "phi3-attn" in self.name:
73
  data_range = ((1, 1+instruction_len), (-2-data_len, -2))
74
- elif "llama2-13b" in self.name or "llama3-8b" in self.name:
75
  data_range = ((5, 5+instruction_len), (-5-data_len, -5))
 
 
 
 
76
  else:
77
  raise NotImplementedError
78
 
 
67
  input_tokens = self.tokenizer.convert_ids_to_tokens(
68
  model_inputs['input_ids'][0])
69
 
70
+ if "qwen" in self.name:
71
  data_range = ((3, 3+instruction_len), (-5-data_len, -5))
72
+ elif "phi3" in self.name:
73
  data_range = ((1, 1+instruction_len), (-2-data_len, -2))
74
+ elif "llama3-8b" in self.name:
75
  data_range = ((5, 5+instruction_len), (-5-data_len, -5))
76
+ elif "mistral-7b" in self.name:
77
+ data_range = ((3, 3+instruction_len), (-1-data_len, -1))
78
+ elif "granite3-8b" in self.name:
79
+ data_range = ((3, 3+instruction_len), (-5-data_len, -5))
80
  else:
81
  raise NotImplementedError
82