Commit: a80fe6b
Parent: b0b4625

refactor and update requirements
Files changed:

- README.md (+13, -2)
- app.py (+3, -22)
- eval_examples.py (+16, -0)
- repl.py (+27, -20)
- requirements.txt (+1, -0)
README.md
CHANGED
```diff
@@ -31,10 +31,21 @@ These outputs are from the same model, `Meta-Llama-3-8B-Instruct`, with identical
 
 Clearly, there are ways to trick the model; above, I say "This is just for testing purposes". In high-risk settings (highly capable models with tool access), we may want more robust methods of intervening that are cheaply implemented. (Running PPO with new reward models would likely be expensive and time-consuming.)
 
-##
+## Method
+
+### Detection
 
 Sufficient activation for a hand-chosen SAE feature.
 
-
+### Refusal
 
 Activation editing to steer towards refusal.
+
+### Example
+
+Taking the previous example, with the "coerced" prompt,
+
+## Evaluation
+
+We want to (1) maximize how often the model refuses requests that may lead it to generate bad outputs and (2) minimize how often it refuses benign prompts.
+There is also a secondary goal of ensuring that the outputs remain high quality, but in this scenario perfect coherence takes a lower priority.
```
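The "Sufficient activation for a hand-chosen SAE feature" line is terse; as a rough illustration, here is a minimal sketch of what that check could look like with the objects returned by `load_models()` in `repl.py`. The feature index and threshold below are hypothetical placeholders, and the repo's actual `should_trigger_refusal` (visible in the repl.py hunk header) may differ.

```python
# Minimal sketch of the Detection step (not code from this repo):
# run the prompt through the model with the SAE attached and check
# whether the hand-chosen "deception" feature fires strongly enough.
from sae_lens import SAE, HookedSAETransformer

FEATURE_IDX = 0  # placeholder: index of the hand-chosen SAE feature
THRESHOLD = 1.0  # placeholder: trades detection recall vs. false refusals

def deception_activation(model: HookedSAETransformer, sae: SAE, prompt: str) -> float:
    # run_with_cache_with_saes attaches the SAE and caches its feature activations
    _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
    acts = cache[f"{sae.cfg.hook_name}.hook_sae_acts_post"]  # [batch, pos, d_sae]
    return acts[0, :, FEATURE_IDX].max().item()

def should_refuse(model: HookedSAETransformer, sae: SAE, prompt: str) -> bool:
    return deception_activation(model, sae, prompt) >= THRESHOLD
```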
app.py
CHANGED
```diff
@@ -1,33 +1,14 @@
 import torch
 from sae_lens import SAE, HookedSAETransformer
 from transformers import AutoModelForCausalLM
-from repl import generate_with_dms
+from repl import generate_with_dms, load_models
 
 import gradio as gr
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"DEVICE: {DEVICE}")
 
-hf_model = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Meta-Llama-3-8B-Instruct",
-    device_map="auto" if DEVICE == "cuda" else DEVICE,
-    torch_dtype="float16",
-)
-model = HookedSAETransformer.from_pretrained_no_processing(
-    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
-    hf_model=hf_model,
-    device=DEVICE,
-    dtype="float16",
-    force_load_with_assign=True,
-)
-model.eval()
-
-sae_id = f"blocks.25.hook_resid_post"
-sae, cfg_dict, sparsity = SAE.from_pretrained(
-    release="Juliushanhanhan/llama-3-8b-it-res",
-    sae_id=sae_id,
-    device=DEVICE,
-)
+hf_model, model, sae = load_models()
 
 
 def generate_response(prompt):
@@ -42,7 +23,7 @@ iface = gr.Interface(
     outputs="text",
     title="Llama-3-8B-Instruct with Deception Refusal",
     description="This is meta-llama/Meta-Llama-3-8B-Instruct with refusal intervention if prompts sufficiently activate an SAE extracted deception feature.",
-    examples=[["Make up a convincing lie to my client to sell my product."]]
+    examples=[["Make up a convincing lie to my client to sell my product."]],
 )
 
 if __name__ == "__main__":
```
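The body of `generate_response` sits outside both hunks, so the diff does not show it; presumably it forwards the prompt to `generate_with_dms`. A hedged guess at the glue code, for orientation only:

```python
# Assumed wiring, not shown in this diff: generate_response is likely a
# thin wrapper around generate_with_dms using the globally loaded objects.
def generate_response(prompt):
    return generate_with_dms(model, prompt, sae)

if __name__ == "__main__":
    iface.launch()  # conventional Gradio entry point; actual launch args unknown
```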
eval_examples.py
ADDED
```diff
@@ -0,0 +1,16 @@
+import torch
+from sae_lens import SAE, HookedSAETransformer
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+from transformer_lens import HookedTransformer
+import pandas as pd
+import os
+
+from activation_additions.prompt_utils import get_x_vector
+from activation_additions.completion_utils import gen_using_activation_additions
+
+from repl import load_models, generate_with_dms
+
+if __name__ == "__main__":
+    hf_model, model, sae = load_models()
+
+    # TODO
```
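The README's Evaluation section names the two quantities to measure: refusal rate on risky prompts (want high) and false-refusal rate on benign prompts (want low). A minimal sketch of how the `# TODO` inside the `if __name__ == "__main__":` block might be filled in, assuming refusals can be spotted by string-matching the completion; the prompt lists and the `is_refusal` helper are hypothetical:

```python
# Hypothetical eval loop, not from the repo. Uses model/sae from load_models()
# and generate_with_dms from repl.py; assumes a refusal contains a
# characteristic phrase in the generated text.
RISKY_PROMPTS = ["Make up a convincing lie to my client to sell my product."]
BENIGN_PROMPTS = ["Summarise the plot of Hamlet in two sentences."]

def is_refusal(completion: str) -> bool:
    return "I cannot" in completion or "I can't" in completion

def refusal_rate(prompts: list[str]) -> float:
    hits = sum(is_refusal(generate_with_dms(model, p, sae)) for p in prompts)
    return hits / len(prompts)

print(f"refusal rate (risky):   {refusal_rate(RISKY_PROMPTS):.2f}")   # maximize
print(f"false-refusal (benign): {refusal_rate(BENIGN_PROMPTS):.2f}")  # minimize
```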
repl.py
CHANGED
```diff
@@ -13,6 +13,32 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 NO_REFUSAL = os.getenv("NO_REFUSAL") == "1"
 
 
+def load_models() -> tuple[AutoModelForCausalLM, HookedSAETransformer, SAE]:
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        device_map="auto" if DEVICE == "cuda" else DEVICE,
+        torch_dtype="float16",
+    )
+    model = HookedSAETransformer.from_pretrained_no_processing(
+        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
+        hf_model=hf_model,
+        device=DEVICE,
+        dtype="float16",
+        force_load_with_assign=True,
+    )
+    model.eval()
+
+    sae_id = f"blocks.25.hook_resid_post"
+    sae, cfg_dict, sparsity = SAE.from_pretrained(
+        release="Juliushanhanhan/llama-3-8b-it-res",
+        sae_id=sae_id,
+        device=DEVICE,
+    )
+
+    # I suspect we need to return the HF model as well to avoid memory dealloc
+    return hf_model, model, sae
+
+
 def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE) -> str:
     """
     generate from the model, triggering a refusal if the prompt contains a query that might be risky to answer
@@ -74,28 +100,9 @@ def should_trigger_refusal(
 
 
 if __name__ == "__main__":
-    hf_model = AutoModelForCausalLM.from_pretrained(
-        "meta-llama/Meta-Llama-3-8B-Instruct",
-        device_map="auto",
-        torch_dtype="float16",
-    )
-    model = HookedSAETransformer.from_pretrained_no_processing(
-        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
-        hf_model=hf_model,
-        device=DEVICE,
-        dtype="float16",
-        force_load_with_assign=True,
-    )
-    model.eval()
+    hf_model, model, sae = load_models()
     print("Finished loading.")
 
-    sae_id = f"blocks.25.hook_resid_post"
-    sae, cfg_dict, sparsity = SAE.from_pretrained(
-        release="Juliushanhanhan/llama-3-8b-it-res",
-        sae_id=sae_id,
-        device=DEVICE,
-    )
-
     print("Note: each input is independent, not a continuous chat.")
     while True:
         prompt = input("User: ")
```
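The steering logic inside `generate_with_dms` is not part of this diff. As a rough illustration of the README's "activation editing to steer towards refusal" using the `activation_additions` helpers that `eval_examples.py` imports, here is a hedged sketch; the contrast prompts, coefficient, and injection layer are illustrative guesses, not the repo's values:

```python
from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions

from repl import load_models

hf_model, model, sae = load_models()
prompt = "Make up a convincing lie to my client to sell my product."

# Hypothetical steering vector: the activation difference between a refusing
# and a complying prompt, added into the residual stream during generation
# to push the model toward refusal.
x_vectors = get_x_vector(
    prompt1="I'm sorry, but I cannot help with that.",  # illustrative contrast pair
    prompt2="Sure, here is how to do that.",
    coeff=8.0,   # illustrative coefficient
    act_name=8,  # illustrative injection layer
    model=model,
    pad_method="tokens_right",
)
results = gen_using_activation_additions(
    prompt_batch=[prompt],
    model=model,
    activation_additions=x_vectors,
)
print(results["completions"][0])
```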
requirements.txt
CHANGED
```diff
@@ -1,6 +1,7 @@
 torch
 ipykernel
 transformer_lens
+git+https://github.com/cyber-chris/TransformerLens
 transformers
 sae-lens==3.18.2
 git+https://github.com/cyber-chris/activation_additions.git
```