amueller committed on
Commit
dadf4bf
·
verified ·
1 Parent(s): 11d6630

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_featurizer filter=lfs diff=lfs merge=lfs -text
37
+ 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
38
+ 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_featurizer filter=lfs diff=lfs merge=lfs -text
39
+ 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a36425ac3f132d84665263a53b6933afd65a48479529d0eb4ba7f75a85932b2
3
+ size 21531300
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_indices ADDED
@@ -0,0 +1 @@
 
 
1
+ null
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_inverse_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:847ce30d6a5bb0b6da2308c61e49ff979f843eb9c9d13250b2936336c63456d9
3
+ size 21531356
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca074860461cf9ede6cf8175332be8bfa7380722184ea0d7969f0796797cf2be
3
+ size 21531208
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_indices ADDED
@@ -0,0 +1 @@
 
 
1
+ null
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_inverse_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb1378f957f695d398cdbf01c41dd47a386f950ab94ca2b9d57ddc1b0cd29211
3
+ size 21531328
featurizer.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copy of the existing SubspaceFeaturizer implementation for submission.
3
+ This file provides the same SubspaceFeaturizer functionality in a self-contained format.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import pyvene as pv
9
+ from CausalAbstraction.model_units.model_units import Featurizer
10
+
11
+
12
class SubspaceFeaturizerModuleCopy(torch.nn.Module):
    """Project an activation onto the rotation layer's weight and keep the remainder.

    The wrapped ``rotate_layer`` only needs to expose a ``weight`` tensor.
    NOTE(review): presumably ``weight`` has shape (hidden_dim, n_features) --
    confirm against the rotation layer that is passed in.
    """

    def __init__(self, rotate_layer):
        """Store the rotation layer whose ``weight`` defines the subspace."""
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, x):
        """Return ``(features, residual)`` for the activation ``x``.

        ``features`` is ``x`` (cast to the weight dtype) multiplied by the
        weight; ``residual`` is ``x`` minus the back-projection of those
        features, cast back to the dtype of ``x``.
        """
        # Read the weight a single time; it may be recomputed on every access
        # when the layer is parametrized.
        basis = self.rotate.weight
        features = x.to(basis.dtype) @ basis
        back_projection = features @ basis.T
        residual = x - back_projection.to(x.dtype)
        return features, residual
22
+
23
+
24
class SubspaceInverseFeaturizerModuleCopy(torch.nn.Module):
    """Rebuild an activation from subspace features plus a residual term.

    The wrapped ``rotate_layer`` only needs to expose a ``weight`` tensor.
    """

    def __init__(self, rotate_layer):
        """Store the rotation layer whose ``weight`` defines the subspace."""
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, f, error):
        """Return the features mapped back through the weight, plus ``error``.

        The matrix product is carried out in the weight's dtype; the result
        and ``error`` are both cast to the dtype of ``f`` before being summed.
        """
        basis_t = self.rotate.weight.T
        out_dtype = f.dtype
        back_projection = (f.to(basis_t.dtype) @ basis_t).to(out_dtype)
        return back_projection + error.to(out_dtype)
32
+
33
+
34
class SubspaceFeaturizerCopy(Featurizer):
    """Featurizer built around a low-rank rotation layer.

    The rotation layer is shared by the forward module
    (``SubspaceFeaturizerModuleCopy``) and the inverse module
    (``SubspaceInverseFeaturizerModuleCopy``).
    """

    def __init__(self, shape=None, rotation_subspace=None, trainable=True, id="subspace"):
        """Create the rotation layer and register the featurizer pair.

        Args:
            shape: Positional arguments unpacked into ``LowRankRotateLayer``.
                Takes precedence over ``rotation_subspace`` when both are given.
            rotation_subspace: Tensor copied into the layer's weight; its
                ``.shape`` is unpacked into ``LowRankRotateLayer``.
            trainable: When false, gradients are disabled on the rotation layer.
            id: Identifier forwarded to the ``Featurizer`` base class.

        Raises:
            AssertionError: If both ``shape`` and ``rotation_subspace`` are None.
        """
        assert shape is not None or rotation_subspace is not None, "Either shape or rotation_subspace must be provided."

        if shape is not None:
            self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
        elif rotation_subspace is not None:
            # Build the layer without orthogonal init, then load the given weights.
            self.rotate = pv.models.layers.LowRankRotateLayer(
                *rotation_subspace.shape, init_orth=False
            )
            self.rotate.weight.data.copy_(rotation_subspace)

        # Wrap the layer with torch's orthogonal parametrization.
        self.rotate = torch.nn.utils.parametrizations.orthogonal(self.rotate)

        if not trainable:
            self.rotate.requires_grad_(False)

        # Both directions share the same rotation layer instance.
        forward_module = SubspaceFeaturizerModuleCopy(self.rotate)
        inverse_module = SubspaceInverseFeaturizerModuleCopy(self.rotate)
        feature_count = self.rotate.weight.shape[1]

        super().__init__(forward_module, inverse_module, n_features=feature_count, id=id)
token_position.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Token position definitions for MCQA task submission.
3
+ This file provides token position functions that identify key tokens in MCQA prompts.
4
+ """
5
+
6
+ import re
7
+ from CausalAbstraction.model_units.LM_units import TokenPosition
8
+
9
+
10
def get_last_token_index(prompt, pipeline):
    """
    Return the position of the final token of the tokenized prompt.

    Args:
        prompt (str): The input prompt
        pipeline: Object whose ``load(prompt)`` result exposes ``["input_ids"]``

    Returns:
        list[int]: Single-element list holding the last token's index
    """
    # Only the first row of "input_ids" is considered.
    first_row = pipeline.load(prompt)["input_ids"][0]
    token_count = len(list(first_row))
    return [token_count - 1]
23
+
24
+
25
def get_correct_symbol_index(prompt, pipeline, task):
    """
    Locate the token index of the correct answer symbol inside the prompt.

    The causal model is run on the prompt to obtain ``answer_pointer`` and the
    matching ``symbol<pointer>`` value. The first standalone uppercase letter
    in the prompt equal to that symbol is then located, and the prompt prefix
    ending at that letter is tokenized; the last token of the prefix is taken
    to be the symbol.

    Args:
        prompt (str): The prompt text
        pipeline: Object whose ``load(text)`` result exposes ``["input_ids"]``
        task: Object providing ``causal_model.run_forward`` and ``input_loader``

    Returns:
        list[int]: Single-element list holding the symbol token's index

    Raises:
        ValueError: If the symbol never appears as a standalone uppercase letter
    """
    # Ask the causal model which symbol is the correct answer.
    output = task.causal_model.run_forward(task.input_loader(prompt))
    pointer = output["answer_pointer"]
    correct_symbol = output[f"symbol{pointer}"]

    # First standalone capital letter that equals the correct symbol, if any.
    hit = next(
        (m for m in re.finditer(r"\b[A-Z]\b", prompt) if m.group() == correct_symbol),
        None,
    )
    if hit is None:
        raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

    # Tokenize everything up to and including the symbol; its token comes last.
    prefix = prompt[:hit.end()]
    prefix_ids = list(pipeline.load(prefix)["input_ids"][0])
    return [len(prefix_ids) - 1]
61
+
62
+
63
def get_token_positions(pipeline, task):
    """
    Build the TokenPosition objects used for the MCQA task.

    Two positions are produced, in this order:
    - correct_symbol: index computed by ``get_correct_symbol_index``
    - last_token: index computed by ``get_last_token_index``

    Args:
        pipeline: The language model pipeline with tokenizer
        task: The MCQA task object

    Returns:
        list[TokenPosition]: The two TokenPosition objects described above
    """

    def correct_symbol_indexer(x):
        # Closes over pipeline and task so the indexer only needs the prompt.
        return get_correct_symbol_index(x, pipeline, task)

    def last_token_indexer(x):
        # Closes over pipeline so the indexer only needs the prompt.
        return get_last_token_index(x, pipeline)

    return [
        TokenPosition(correct_symbol_indexer, pipeline, id="correct_symbol"),
        TokenPosition(last_token_indexer, pipeline, id="last_token"),
    ]