Update README.md
Browse files
README.md
CHANGED
@@ -2,8 +2,46 @@
|
|
2 |
tags:
|
3 |
- model_hub_mixin
|
4 |
- pytorch_model_hub_mixin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
---
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
tags:
|
3 |
- model_hub_mixin
|
4 |
- pytorch_model_hub_mixin
|
5 |
+
- crosscoder
|
6 |
+
license: mit
|
7 |
+
datasets:
|
8 |
+
- HuggingFaceFW/fineweb
|
9 |
+
- lmsys/lmsys-chat-1m
|
10 |
+
base_model:
|
11 |
+
- google/gemma-2-2b-it
|
12 |
+
- google/gemma-2-2b
|
13 |
+
pipeline_tag: feature-extraction
|
14 |
---
|
15 |
+
This crosscoder was trained on parallel activations of Gemma 2 2B and Gemma 2 2B IT at layer 13 on a subset of fineweb and lsmsy-chat-1m dataset.
|
16 |
|
17 |
+
You can load it using our branch of the `dictionary_learning` library:
|
18 |
+
```py
|
19 |
+
!pip install git+https://github.com/jkminder/dictionary_learning
|
20 |
+
from dictionary_learning import CrossCoder
|
21 |
+
from nnsight import LanguageModel
|
22 |
+
from dlabutils import model_path
|
23 |
+
import torch as th
|
24 |
+
|
25 |
+
crosscoder = CrossCoder.from_pretrained("Butanium/gemma-2-2b-crosscoder-l13-mu4.1e-02-lr1e-04", from_hub=True)
|
26 |
+
gemma_2 = LanguageModel(model_path("google/gemma-2-2b"), device_map="cuda:0")
|
27 |
+
gemma_2_it = LanguageModel(model_path("google/gemma-2-2b-it"), device_map="cuda:1")
|
28 |
+
prompt = "quick fox brown"
|
29 |
+
|
30 |
+
with gemma_2.trace(prompt):
|
31 |
+
l13_act_base = gemma_2.model.layers[13].output[0][:, -1].save() # (1, 2304)
|
32 |
+
gemma_2.model.layers[13].output.stop()
|
33 |
+
|
34 |
+
with gemma_2_it.trace(prompt):
|
35 |
+
l13_act_it = gemma_2_it.model.layers[13].output[0][:, -1].save() # (1, 2304)
|
36 |
+
gemma_2_it.model.layers[13].output.stop()
|
37 |
+
|
38 |
+
|
39 |
+
crosscoder_input = th.cat([l13_act_base, l13_act_it], dim=0).unsqueeze(0).cpu() # (batch, 2, 2304)
|
40 |
+
print(crosscoder_input.shape)
|
41 |
+
reconstruction, features = crosscoder(crosscoder_input, output_features=True)
|
42 |
+
|
43 |
+
# print metrics
|
44 |
+
print(f"MSE loss: {th.nn.functional.mse_loss(reconstruction, crosscoder_input).item():.2f}")
|
45 |
+
print(f"L1 sparsity: {features.abs().sum():.1f}")
|
46 |
+
print(f"L0 sparsity: {(features > 1e-4).sum()}")
|
47 |
+
```
|