import torch
import pandas as pd
from sae_lens import HookedSAETransformer
from transformers import AutoModelForCausalLM

from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions

# Helper script for testing activation-addition (ActAdd) steering across layers.

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def generate_steered(model, prompt):
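    """Sweep an ActAdd steering vector over a range of layers, generating one
    steered completion of `prompt` per layer, and return the outputs."""
    # Greedy decoding keeps the layer sweep deterministic; the commented-out
    # kwargs below switch to sampling.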
    sampling_kwargs = {
        "do_sample": False,
        # "top_k": 50,
        # "top_p": 0.95,
        # "temperature": 0.5,
    }

    outputs = []
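    # Sweep even-numbered residual-stream layers (Llama-3-8B has 32 layers).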
    for layer in range(4, 28, 2):
        # get_x_vector pairs prompt1 at +coeff with prompt2 at -coeff, so this
        # steers toward the refusal phrasing and away from compliance.
        x_vectors = get_x_vector(
            prompt1="I'm sorry, but I should not",
            prompt2="Sure, I can help with that",
            coeff=8,
            act_name=layer,
            model=model,
            pad_method="tokens_right",
        )

        # Add the steering vectors at the front token positions, across the
        # full residual stream (res_stream_slice=slice(None)).
        mod_df: pd.DataFrame = gen_using_activation_additions(
            prompt_batch=[prompt],
            model=model,
            activation_additions=x_vectors,
            addition_location="front",
            res_stream_slice=slice(None),
            **sampling_kwargs,
        )
        output = mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
        outputs.append(output)
        print(layer, output)
    return outputs

if __name__ == "__main__":
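    # Load the HF checkpoint in fp16, then wrap it in a HookedSAETransformer.
    # from_pretrained_no_processing skips weight processing (LayerNorm folding,
    # weight centering) so the hooked activations match the original model.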
    hf_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        device_map="auto",
        torch_dtype="float16",
    )
    model = HookedSAETransformer.from_pretrained_no_processing(
        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        hf_model=hf_model,
        device=DEVICE,
        dtype="float16",
        force_load_with_assign=True,
    )
    model.eval()
    print("Finished loading.")

    # Plain-text prompt rather than the Llama-3 chat template; fine for a
    # quick smoke test.
    prompt = "User: Can you help me with my homework? Assistant:"
    generate_steered(model, prompt)