Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_indices +1 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_inverse_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_indices +1 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_inverse_featurizer +3 -0
- featurizer.py +52 -0
- token_position.py +91 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_featurizer filter=lfs diff=lfs merge=lfs -text
|
37 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
38 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_featurizer filter=lfs diff=lfs merge=lfs -text
|
39 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a36425ac3f132d84665263a53b6933afd65a48479529d0eb4ba7f75a85932b2
|
3 |
+
size 21531300
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_indices
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
null
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:correct_symbol)_inverse_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:847ce30d6a5bb0b6da2308c61e49ff979f843eb9c9d13250b2936336c63456d9
|
3 |
+
size 21531356
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ca074860461cf9ede6cf8175332be8bfa7380722184ea0d7969f0796797cf2be
|
3 |
+
size 21531208
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_indices
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
null
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer:0,Token:last_token)_inverse_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eb1378f957f695d398cdbf01c41dd47a386f950ab94ca2b9d57ddc1b0cd29211
|
3 |
+
size 21531328
|
featurizer.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copy of the existing SubspaceFeaturizer implementation for submission.
|
3 |
+
This file provides the same SubspaceFeaturizer functionality in a self-contained format.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import pyvene as pv
|
9 |
+
from CausalAbstraction.model_units.model_units import Featurizer
|
10 |
+
|
11 |
+
|
12 |
+
class SubspaceFeaturizerModuleCopy(torch.nn.Module):
|
13 |
+
def __init__(self, rotate_layer):
|
14 |
+
super().__init__()
|
15 |
+
self.rotate = rotate_layer
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
r = self.rotate.weight.T
|
19 |
+
f = x.to(r.dtype) @ r.T
|
20 |
+
error = x - (f @ r).to(x.dtype)
|
21 |
+
return f, error
|
22 |
+
|
23 |
+
|
24 |
+
class SubspaceInverseFeaturizerModuleCopy(torch.nn.Module):
|
25 |
+
def __init__(self, rotate_layer):
|
26 |
+
super().__init__()
|
27 |
+
self.rotate = rotate_layer
|
28 |
+
|
29 |
+
def forward(self, f, error):
|
30 |
+
r = self.rotate.weight.T
|
31 |
+
return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)
|
32 |
+
|
33 |
+
|
34 |
+
class SubspaceFeaturizerCopy(Featurizer):
|
35 |
+
def __init__(self, shape=None, rotation_subspace=None, trainable=True, id="subspace"):
|
36 |
+
assert shape is not None or rotation_subspace is not None, "Either shape or rotation_subspace must be provided."
|
37 |
+
if shape is not None:
|
38 |
+
self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
|
39 |
+
elif rotation_subspace is not None:
|
40 |
+
shape = rotation_subspace.shape
|
41 |
+
self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
|
42 |
+
self.rotate.weight.data.copy_(rotation_subspace)
|
43 |
+
self.rotate = torch.nn.utils.parametrizations.orthogonal(self.rotate)
|
44 |
+
|
45 |
+
if not trainable:
|
46 |
+
self.rotate.requires_grad_(False)
|
47 |
+
|
48 |
+
# Create module-based featurizer and inverse_featurizer
|
49 |
+
featurizer = SubspaceFeaturizerModuleCopy(self.rotate)
|
50 |
+
inverse_featurizer = SubspaceInverseFeaturizerModuleCopy(self.rotate)
|
51 |
+
|
52 |
+
super().__init__(featurizer, inverse_featurizer, n_features=self.rotate.weight.shape[1], id=id)
|
token_position.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Token position definitions for MCQA task submission.
|
3 |
+
This file provides token position functions that identify key tokens in MCQA prompts.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import re
|
7 |
+
from CausalAbstraction.model_units.LM_units import TokenPosition
|
8 |
+
|
9 |
+
|
10 |
+
def get_last_token_index(prompt, pipeline):
|
11 |
+
"""
|
12 |
+
Get the index of the last token in the prompt.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
prompt (str): The input prompt
|
16 |
+
pipeline: The tokenizer pipeline
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
list[int]: List containing the index of the last token
|
20 |
+
"""
|
21 |
+
input_ids = list(pipeline.load(prompt)["input_ids"][0])
|
22 |
+
return [len(input_ids) - 1]
|
23 |
+
|
24 |
+
|
25 |
+
def get_correct_symbol_index(prompt, pipeline, task):
|
26 |
+
"""
|
27 |
+
Find the index of the correct answer symbol in the prompt.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
prompt (str): The prompt text
|
31 |
+
pipeline: The tokenizer pipeline
|
32 |
+
task: The task object containing causal model
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
list[int]: List containing the index of the correct answer symbol token
|
36 |
+
"""
|
37 |
+
# Run the model to get the answer position
|
38 |
+
output = task.causal_model.run_forward(task.input_loader(prompt))
|
39 |
+
pointer = output["answer_pointer"]
|
40 |
+
correct_symbol = output[f"symbol{pointer}"]
|
41 |
+
|
42 |
+
# Find all single uppercase letters in the prompt
|
43 |
+
matches = list(re.finditer(r"\b[A-Z]\b", prompt))
|
44 |
+
|
45 |
+
# Find the match corresponding to our correct symbol
|
46 |
+
symbol_match = None
|
47 |
+
for match in matches:
|
48 |
+
if prompt[match.start():match.end()] == correct_symbol:
|
49 |
+
symbol_match = match
|
50 |
+
break
|
51 |
+
|
52 |
+
if not symbol_match:
|
53 |
+
raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")
|
54 |
+
|
55 |
+
# Get the substring up to the symbol match end
|
56 |
+
substring = prompt[:symbol_match.end()]
|
57 |
+
tokenized_substring = list(pipeline.load(substring)["input_ids"][0])
|
58 |
+
|
59 |
+
# The symbol token will be at the end of the substring
|
60 |
+
return [len(tokenized_substring) - 1]
|
61 |
+
|
62 |
+
|
63 |
+
def get_token_positions(pipeline, task):
|
64 |
+
"""
|
65 |
+
Get token positions for the MCQA task.
|
66 |
+
|
67 |
+
This function identifies key token positions in MCQA prompts:
|
68 |
+
- correct_symbol: The position of the correct answer symbol (A, B, C, or D)
|
69 |
+
- last_token: The position of the last token in the prompt
|
70 |
+
|
71 |
+
Args:
|
72 |
+
pipeline: The language model pipeline with tokenizer
|
73 |
+
task: The MCQA task object
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
list[TokenPosition]: List of TokenPosition objects for intervention experiments
|
77 |
+
"""
|
78 |
+
# Create TokenPosition objects
|
79 |
+
token_positions = [
|
80 |
+
TokenPosition(
|
81 |
+
lambda x: get_correct_symbol_index(x, pipeline, task),
|
82 |
+
pipeline,
|
83 |
+
id="correct_symbol"
|
84 |
+
),
|
85 |
+
TokenPosition(
|
86 |
+
lambda x: get_last_token_index(x, pipeline),
|
87 |
+
pipeline,
|
88 |
+
id="last_token"
|
89 |
+
)
|
90 |
+
]
|
91 |
+
return token_positions
|