|
import importlib
import json
import os
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)

from utils.constants import MISTRAL_7B
from utils.utils import _get_submodules
|
|
|
class Cats(nn.Module):
    """Wrap a module and zero out the small-magnitude entries of its output.

    The wrapped module runs first. Every output element whose absolute value is
    strictly below ``threshold`` is then set to zero. While ``collect_stats`` is
    enabled, histograms of the raw (pre-threshold) outputs and of their absolute
    values are accumulated into ``hist_counts`` and ``abs_hist_counts`` (CPU
    tensors).

    Args:
        wrapped_module: Module whose output is sparsified.
        threshold: Magnitude below which outputs are zeroed.
        hist_num_bins: Total number of histogram bin edges. ``hist_num_bins - 2``
            finite edges span ``[hist_min, hist_max]``. Two open-ended edges
            (``-inf`` and ``+inf``) catch out-of-range values, giving
            ``hist_num_bins - 1`` bins.
        hist_min: Lowest finite bin edge.
        hist_max: Highest finite bin edge.
    """

    def __init__(
        self,
        wrapped_module: nn.Module,
        threshold: float = 0,
        hist_num_bins: int = 1000,
        hist_min: int = -1,
        hist_max: int = 1,
    ):
        super().__init__()
        self.wrapped_module = wrapped_module
        # Stored as float so later in-place updates (set_threshold) never get
        # truncated; torch.tensor(0) alone would produce an int64 tensor.
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold)), requires_grad=False
        )
        inner_edges = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
        self.histogram_bins = torch.cat(
            [torch.tensor([-torch.inf]), inner_edges, torch.tensor([torch.inf])]
        )
        self.hist_counts = torch.zeros(hist_num_bins - 1)
        self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
        self.collect_stats = True

    def disable_collect_stats(self):
        """Stop accumulating activation histograms."""
        self.collect_stats = False

    def enable_collect_stats(self):
        """Resume accumulating activation histograms."""
        self.collect_stats = True

    def set_threshold(self, threshold: float):
        """Set the magnitude threshold in place.

        The existing parameter is updated in place rather than replaced, so it
        keeps its device and stays the same object inside the module.
        """
        with torch.no_grad():
            self.threshold.fill_(float(threshold))

    def forward(self, x):
        """Run the wrapped module, record statistics, then apply the threshold."""
        x = self.wrapped_module(x)
        if self.collect_stats:
            # torch.histogram only supports CPU tensors whose dtype matches the
            # (float32) bin edges. Use a detached float32 CPU copy so that
            # CUDA / bf16 / fp64 activations work and autograd is untouched.
            stats_input = x.detach().float().cpu()
            self.hist_counts += torch.histogram(
                stats_input, bins=self.histogram_bins
            )[0]
            self.abs_hist_counts += torch.histogram(
                stats_input.abs(), bins=self.histogram_bins
            )[0]
        # Out-of-place masking. The wrapped module may return its own input
        # (e.g. identity) or a tensor autograd saved for backward, so writing
        # into it in place could corrupt caller data or gradients.
        return x.masked_fill(x.abs() < self.threshold, 0)
|
|
|
|
|
|
|
def load_data(file_path):
    """Return the JSON document stored at ``file_path``.

    A missing file is not treated as an error; an empty dict is returned
    instead.
    """
    try:
        handle = open(file_path, "r")
    except FileNotFoundError:
        return {}
    with handle:
        return json.load(handle)
|
|
|
|
|
|
|
def save_to_json(data, file_path):
    """Write ``data`` as indented JSON to ``file_path``.

    Missing parent directories are created. ``file_path`` may also be a bare
    filename with no directory component, in which case it is written to the
    current working directory.

    Args:
        data: JSON-serialisable object.
        file_path: Destination path.
    """
    directory = os.path.dirname(file_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # the path actually has one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)
|
|
|
|
|
class CatsConfig(PretrainedConfig):
    """Config for ``CatsModel``: the wrapped model's config plus CATS settings.

    The attributes of ``wrapped_model_config`` are copied onto this config, so
    it can be passed directly to the wrapped model class.

    Args:
        wrapped_model_config: Config of the model to wrap. When ``None``,
            ``AutoConfig.from_pretrained(MISTRAL_7B)`` is loaded at construction
            time.
        wrapped_model_class_name: Name of the ``transformers`` model class to
            instantiate.
        target_modules: Attribute names of submodules to wrap with ``Cats``.
            ``None`` means ``["act_fn"]``.
        target_sparsity: Desired sparsity level, stored on the config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. These take
            precedence over values inherited from ``wrapped_model_config``.
    """

    model_type = "cats_model"

    def __init__(
        self,
        wrapped_model_config=None,
        wrapped_model_class_name: str = "MistralForCausalLM",
        target_modules: Optional[List[str]] = None,
        target_sparsity: float = 0.5,
        **kwargs,
    ):
        # Resolved here rather than in the signature. A default expression there
        # would run at import time and share one config object (and one list)
        # across every instance.
        if wrapped_model_config is None:
            wrapped_model_config = AutoConfig.from_pretrained(MISTRAL_7B)
        if target_modules is None:
            target_modules = ["act_fn"]

        self.target_modules = list(target_modules)
        self.target_sparsity = target_sparsity
        self.wrapped_model_class_name = wrapped_model_class_name
        self.__dict__.update(wrapped_model_config.__dict__)

        # PretrainedConfig.__init__ resets the generic fields (token ids,
        # tie_word_embeddings, architectures, ...) to its own defaults, which
        # would clobber the values copied above. Feed the wrapped config's
        # serialised values in as defaults so they survive; explicit kwargs
        # still win. model_type is dropped so the wrapped model's type does not
        # shadow CatsConfig.model_type as an instance attribute.
        inherited = wrapped_model_config.to_dict()
        inherited.pop("model_type", None)
        super().__init__(**{**inherited, **kwargs})
|
|
|
|
|
class CatsModel(PreTrainedModel):
    """Build a ``transformers`` model by class name and wrap selected submodules in ``Cats``.

    The model class is looked up as ``config.wrapped_model_class_name`` in the
    ``transformers`` package. Every submodule whose name (as reported by
    ``_get_submodules``) is listed in ``config.target_modules`` is replaced by a
    ``Cats`` wrapper around it.
    """

    config_class = CatsConfig

    def __init__(self, config, wrapped_model_pretrained_dir: str = None, **kwargs):
        """Instantiate the wrapped model and inject ``Cats`` wrappers.

        Args:
            config: A ``CatsConfig``.
            wrapped_model_pretrained_dir: Optional path or repo id. When given,
                the wrapped model is loaded with ``from_pretrained``. Otherwise it
                is freshly constructed from ``config``.
            **kwargs: Accepted for compatibility; unused.
        """
        super().__init__(config)
        transformers_module = importlib.import_module("transformers")
        self.wrapped_model_class = getattr(
            transformers_module, config.wrapped_model_class_name
        )
        # Build the wrapped model exactly once. Constructing it from config and
        # then discarding it for the pretrained copy would needlessly allocate
        # and initialise a full model.
        if wrapped_model_pretrained_dir is not None:
            self.wrapped_model = self.wrapped_model_class.from_pretrained(
                wrapped_model_pretrained_dir
            )
        else:
            self.wrapped_model = self.wrapped_model_class(config)
        self.inject_cats()

    def inject_cats(self):
        """Replace each targeted submodule of the wrapped model with a ``Cats`` wrapper."""
        # Snapshot the names first so the module tree is not mutated while it
        # is being traversed.
        module_names = [name for name, _ in self.wrapped_model.named_modules()]
        for name in module_names:
            parent, target, target_name = _get_submodules(self.wrapped_model, name)
            if target_name in self.config.target_modules:
                print(f"{name} is replaced.")
                setattr(parent, target_name, Cats(wrapped_module=target))

    def _cats_modules(self):
        """Yield every ``Cats`` wrapper inside the wrapped model."""
        # modules() yields modules; named_modules() would yield (name, module)
        # tuples, which never pass an isinstance check.
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                yield module

    def enable_collect_stats(self):
        """Turn on histogram collection in every ``Cats`` wrapper."""
        for module in self._cats_modules():
            module.enable_collect_stats()

    def disable_collect_stats(self):
        """Turn off histogram collection in every ``Cats`` wrapper."""
        for module in self._cats_modules():
            module.disable_collect_stats()

    def disable_adapters(self) -> None:
        """Backward-compatible alias for ``disable_collect_stats``."""
        self.disable_collect_stats()

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model's forward pass."""
        return self.wrapped_model(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def simple_exp():
    """Smoke-test the CATS wrapper end to end.

    Steps:
        1. Build a randomly initialised ``MistralForCausalLM`` inside
           ``CatsModel``.
        2. Print the model, the wrapped model and the config.
        3. Register the custom classes for the ``Auto*`` API.
        4. Push the model to the Hub.
        5. Load it back with ``trust_remote_code=True``.
    """
    base_config = AutoConfig.from_pretrained(MISTRAL_7B)
    cats_model = CatsModel(
        CatsConfig(base_config, wrapped_model_class_name="MistralForCausalLM"),
        wrapped_model_pretrained_dir=None,
    )

    for item in (cats_model, cats_model.wrapped_model, cats_model.config):
        print(item)

    # Register the custom classes with the Auto* API before pushing.
    CatsConfig.register_for_auto_class()
    CatsModel.register_for_auto_class("AutoModelForCausalLM")

    hub_repo = "thrunlab/cats_exp"
    cats_model.push_to_hub(hub_repo)
    _reloaded = AutoModelForCausalLM.from_pretrained(hub_repo, trust_remote_code=True)


if __name__ == "__main__":
    simple_exp()
|
|