Butanium commited on
Commit
05f2ce9
·
verified ·
1 Parent(s): 56fc52c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -3
README.md CHANGED
@@ -2,8 +2,46 @@
2
  tags:
3
  - model_hub_mixin
4
  - pytorch_model_hub_mixin
 
 
 
 
 
 
 
 
 
5
  ---
 
6
 
7
- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
- - Library: [More Information Needed]
9
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  tags:
3
  - model_hub_mixin
4
  - pytorch_model_hub_mixin
5
+ - crosscoder
6
+ license: mit
7
+ datasets:
8
+ - HuggingFaceFW/fineweb
9
+ - lmsys/lmsys-chat-1m
10
+ base_model:
11
+ - google/gemma-2-2b-it
12
+ - google/gemma-2-2b
13
+ pipeline_tag: feature-extraction
14
  ---
15
+ This crosscoder was trained on parallel activations of Gemma 2 2B and Gemma 2 2B IT at layer 13 on a subset of fineweb and lsmsy-chat-1m dataset.
16
 
17
+ You can load it using our branch of the `dictionary_learning` library:
18
+ ```py
19
+ !pip install git+https://github.com/jkminder/dictionary_learning
20
+ from dictionary_learning import CrossCoder
21
+ from nnsight import LanguageModel
22
+ from dlabutils import model_path
23
+ import torch as th
24
+
25
+ crosscoder = CrossCoder.from_pretrained("Butanium/gemma-2-2b-crosscoder-l13-mu4.1e-02-lr1e-04", from_hub=True)
26
+ gemma_2 = LanguageModel(model_path("google/gemma-2-2b"), device_map="cuda:0")
27
+ gemma_2_it = LanguageModel(model_path("google/gemma-2-2b-it"), device_map="cuda:1")
28
+ prompt = "quick fox brown"
29
+
30
+ with gemma_2.trace(prompt):
31
+ l13_act_base = gemma_2.model.layers[13].output[0][:, -1].save() # (1, 2304)
32
+ gemma_2.model.layers[13].output.stop()
33
+
34
+ with gemma_2_it.trace(prompt):
35
+ l13_act_it = gemma_2_it.model.layers[13].output[0][:, -1].save() # (1, 2304)
36
+ gemma_2_it.model.layers[13].output.stop()
37
+
38
+
39
+ crosscoder_input = th.cat([l13_act_base, l13_act_it], dim=0).unsqueeze(0).cpu() # (batch, 2, 2304)
40
+ print(crosscoder_input.shape)
41
+ reconstruction, features = crosscoder(crosscoder_input, output_features=True)
42
+
43
+ # print metrics
44
+ print(f"MSE loss: {th.nn.functional.mse_loss(reconstruction, crosscoder_input).item():.2f}")
45
+ print(f"L1 sparsity: {features.abs().sum():.1f}")
46
+ print(f"L0 sparsity: {(features > 1e-4).sum()}")
47
+ ```