---
license: cc-by-4.0
---

# Finetuned `Gemma-2-2B` for generating subspaces for `Gemma-2-2B-it` given any natural language description

In the AxBench paper, we finetuned a subspace generator. The subspace generator is a hyper-network that, given a concept description in natural language, produces a subspace (a single steering direction) for `Gemma-2-2B-it`. **A high-quality subspace generator can bypass dictionary training entirely!**

## How to use the subspace generator?

```py
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

class RegressionWrapper(torch.nn.Module):
    """Gemma-2-2B with a linear head that maps the last-token representation
    of a concept description to a unit-norm subspace direction."""
    def __init__(self, base_model, hidden_size, output_dim):
        super().__init__()
        self.base_model = base_model
        self.regression_head = torch.nn.Linear(hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        # Encode the concept description and keep the final hidden state of the last token.
        outputs = self.base_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        last_hiddens = outputs.hidden_states[-1]
        last_token_representations = last_hiddens[:, -1]
        preds = self.regression_head(last_token_representations)
        # L2-normalize so the generated subspace is a unit-norm direction.
        preds = F.normalize(preds, p=2, dim=-1)
        return preds

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b", torch_dtype=torch.bfloat16)
base_tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-2b", model_max_length=512)

# Gemma-2-2B and Gemma-2-2B-it share a 2304-dimensional residual stream,
# so the regression head maps hidden_size -> hidden_size.
hidden_size = base_model.config.hidden_size
output_dim = base_model.config.hidden_size

subspace_gen = RegressionWrapper(
    base_model, hidden_size, output_dim).bfloat16().to("cuda")
subspace_gen.load_state_dict(torch.load('model.pth'))
subspace_gen.eval()

your_new_concept = "terms related to Stanford University"

inputs = base_tokenizer(your_new_concept, return_tensors="pt").to("cuda")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
with torch.no_grad():
    direction = subspace_gen(input_ids, attention_mask)[0]  # unit-norm vector of size 2304
```
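
The returned `direction` is a unit-norm vector in the residual stream of `Gemma-2-2B-it`. Below is a minimal, illustrative sketch of one way to use it: add the scaled direction to the hidden states of a single decoder layer of `Gemma-2-2B-it` via a forward hook while generating. It continues from the snippet above (reusing `direction`); the layer index, steering strength, and prompt are placeholder choices, not tuned values from the paper.

```py
# Illustrative steering sketch; layer_idx and strength are placeholders to sweep.
it_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", torch_dtype=torch.bfloat16).to("cuda")
it_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

layer_idx, strength = 20, 8.0  # placeholder values

def add_direction(module, args, output):
    # Gemma-2 decoder layers return a tuple whose first element is the hidden states;
    # add the scaled subspace direction at every position.
    steered = output[0] + strength * direction.to(output[0].dtype)
    return (steered,) + output[1:]

handle = it_model.model.layers[layer_idx].register_forward_hook(add_direction)
prompt_ids = it_tokenizer.apply_chat_template(
    [{"role": "user", "content": "Tell me about your weekend."}],
    add_generation_prompt=True, return_tensors="pt").to("cuda")
with torch.no_grad():
    generation = it_model.generate(prompt_ids, max_new_tokens=64, do_sample=False)
handle.remove()
print(it_tokenizer.decode(generation[0], skip_special_tokens=True))
```

A good steering strength depends on the concept and the layer, so expect to sweep it rather than reuse the placeholder above.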