---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
- crosscoder
license: mit
datasets:
- HuggingFaceFW/fineweb
- lmsys/lmsys-chat-1m
base_model:
- google/gemma-2-2b-it
- google/gemma-2-2b
pipeline_tag: feature-extraction
---
This crosscoder was trained on parallel activations of Gemma 2 2B and Gemma 2 2B IT at layer 13, using a subset of the FineWeb and LMSYS-Chat-1M datasets.

You can load it using our branch of the `dictionary_learning` library:

```bash
pip install git+https://github.com/jkminder/dictionary_learning
```

```py
from dictionary_learning import CrossCoder
from nnsight import LanguageModel
import torch as th

# Load the crosscoder and both Gemma 2 2B variants (base and instruction-tuned, one per GPU)
crosscoder = CrossCoder.from_pretrained("Butanium/gemma-2-2b-crosscoder-l13-mu4.1e-02-lr1e-04", from_hub=True)
gemma_2 = LanguageModel("google/gemma-2-2b", device_map="cuda:0")
gemma_2_it = LanguageModel("google/gemma-2-2b-it", device_map="cuda:1")
prompt = "quick fox brown"

# Grab the layer 13 residual-stream activation of the last token from the base model
with gemma_2.trace(prompt):
    l13_act_base = gemma_2.model.layers[13].output[0][:, -1].save()  # (1, 2304)
    gemma_2.model.layers[13].output.stop()  # no need to run the layers after 13

# Same activation from the instruction-tuned model
with gemma_2_it.trace(prompt):
    l13_act_it = gemma_2_it.model.layers[13].output[0][:, -1].save()  # (1, 2304)
    gemma_2_it.model.layers[13].output.stop()


# Stack the base and IT activations into a single crosscoder input;
# move both to CPU first since the two models live on different GPUs
crosscoder_input = th.cat([l13_act_base.cpu(), l13_act_it.cpu()], dim=0).unsqueeze(0)  # (batch, 2, 2304)
print(crosscoder_input.shape)
reconstruction, features = crosscoder(crosscoder_input, output_features=True)

# Reconstruction and sparsity metrics
print(f"MSE loss: {th.nn.functional.mse_loss(reconstruction, crosscoder_input).item():.2f}")
print(f"L1 sparsity: {features.abs().sum():.1f}")
print(f"L0 sparsity: {(features > 1e-4).sum()}")  # number of active latents
```
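
If you want to see which dictionary latents fire on a given prompt, you can rank the returned feature activations. The snippet below is a minimal sketch, assuming `features` has shape `(batch, dict_size)` as in the example above:

```py
# Minimal sketch: list the most strongly activating crosscoder latents for this prompt
# (assumes `features` has shape (batch, dict_size), as returned above)
top_acts, top_idx = th.topk(features[0], k=10)
for idx, act in zip(top_idx.tolist(), top_acts.tolist()):
    print(f"latent {idx}: activation {act:.3f}")
```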