import torch
from sae_lens import HookedSAETransformer
from transformers import AutoModelForCausalLM
import pandas as pd

from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions

# Helper script to test inducing refusals via activation steering.
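# The steering vector comes from a contrast pair: activations of a refusal
# phrase minus activations of a compliance phrase, scaled by `coeff` and
# injected into the residual stream at block `act_name` during generation.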

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def generate_refusal_prompt(model, prompt):
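    """Generate a completion for `prompt` while steering the model toward refusal."""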
    sampling_kwargs = {
        "do_sample": False,  # greedy decoding for deterministic outputs
        # "top_k": 25,
        # "top_p": 0.8,
        # "temperature": 0.3,
    }

    act_name = 8  # residual-stream block to steer
    coeff = 8  # steering strength; larger values push harder toward refusal

    # get_x_vector builds the contrast pair: +coeff * acts(prompt1) and
    # -coeff * acts(prompt2), with the shorter prompt right-padded in tokens.
    x_vectors = get_x_vector(
        prompt1="I'm sorry, but I cannot",
        prompt2="Sure, I can help with that",
        coeff=coeff,
        act_name=act_name,
        model=model,
        pad_method="tokens_right",
    )
    # Hook the additions into the residual stream and generate; "front" adds
    # the steering activations at the leading token positions of the prompt.
    mod_df: pd.DataFrame = gen_using_activation_additions(
        prompt_batch=[prompt],
        model=model,
        activation_additions=x_vectors,
        addition_location="front",
        res_stream_slice=slice(None),  # steer the full residual-stream width
        **sampling_kwargs,
    )
    # The returned DataFrame has "prompts" and "completions" columns.
    output = mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
    return output


if __name__ == "__main__":
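    # Load HF weights first, then wrap them in a TransformerLens model.
    # from_pretrained_no_processing skips weight processing (e.g. LayerNorm
    # folding), keeping activations faithful to the original checkpoint.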
    hf_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        device_map="auto",
        torch_dtype="float16",
    )
    model = HookedSAETransformer.from_pretrained_no_processing(
        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        hf_model=hf_model,
        device=DEVICE,
        dtype="float16",
        force_load_with_assign=True,
    )
    model.eval()
    print("Finished loading.")

    prompts = [
        "User: Can you help me with my homework? Assistant:",
        "User: Can you make a graph based on this data? Assistant:",
        "User: Can you quiz me on ancient Roman leaders? Assistant:",
        "User: Can you explain the causes of the Great Depression? Assistant:",
        "User: Can you compare and contrast the Cuban Missile Crisis and the Vietnam War? Assistant:",
    ]
    for prompt in prompts:
        print(generate_refusal_prompt(model, prompt))
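
    # For reference, an unsteered baseline can be generated with
    # HookedTransformer's built-in `generate` (greedy, to match
    # sampling_kwargs above) -- a minimal comparison sketch:
    print("--- Unsteered baselines ---")
    for prompt in prompts:
        print(model.generate(prompt, max_new_tokens=50, do_sample=False))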