# cats_exp / cats.py
# Hugging Face Hub page header: uploaded by vxbrandon ("Upload model"),
# commit ec539e9 (verified).
import importlib
import json
import os
from typing import List
import numpy as np
import torch
import torch.nn as nn
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoConfig, AutoModelForCausalLM,
)
from utils.constants import MISTRAL_7B
from utils.utils import _get_submodules
class Cats(nn.Module):
    """Wrap a module, zero its small-magnitude outputs, and histogram them.

    ``forward`` runs ``wrapped_module`` on the input, optionally accumulates
    histograms of the raw output and of its absolute value, and returns the
    output with every element whose magnitude is below ``threshold`` set to 0.

    Args:
        wrapped_module: Module whose output is thresholded.
        threshold: Magnitude cut-off; elements with ``|x| < threshold`` become 0.
        hist_num_bins: Total number of bin edges, including the two infinite
            outer edges, so ``hist_num_bins - 1`` counts are kept.
        hist_min: Lowest finite bin edge.
        hist_max: Highest finite bin edge.
    """

    def __init__(
        self,
        wrapped_module: nn.Module,
        threshold: float = 0,
        hist_num_bins: int = 1000,
        hist_min: int = -1,
        hist_max: int = 1,
    ):
        super().__init__()
        self.wrapped_module = wrapped_module
        # float() keeps the parameter floating point even for the integer
        # default 0; an int64 parameter would truncate any fractional
        # threshold later copied into it (e.g. by load_state_dict).
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold)), requires_grad=False
        )
        # Edges: -inf, (hist_num_bins - 2) evenly spaced values, +inf, so
        # values outside [hist_min, hist_max] land in the two outer bins.
        finite_edges = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
        self.histogram_bins = torch.cat(
            [torch.tensor([-torch.inf]), finite_edges, torch.tensor([torch.inf])]
        )
        # Plain attributes (not buffers): they are not part of the state
        # dict and do not follow the module across devices.
        self.hist_counts = torch.zeros(hist_num_bins - 1)
        self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
        self.collect_stats = True

    def disable_collect_stats(self):
        """Stop accumulating histograms in ``forward``."""
        self.collect_stats = False

    def enable_collect_stats(self):
        """Resume accumulating histograms in ``forward``."""
        self.collect_stats = True

    def set_threshold(self, threshold: float):
        """Replace the threshold parameter, keeping its current device."""
        current = self.threshold
        dtype = (
            current.dtype
            if current.is_floating_point()
            else torch.get_default_dtype()
        )
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold), dtype=dtype, device=current.device),
            requires_grad=False,
        )

    def _update_histograms(self, x: torch.Tensor) -> None:
        """Add the values of ``x`` and ``|x|`` to the running histograms."""
        with torch.no_grad():
            # torch.histogram needs the input to match the bin edges' dtype,
            # and the edges/counts stay wherever they were created, so the
            # activation is converted to match them before counting.
            values = x.detach().to(
                device=self.histogram_bins.device,
                dtype=self.histogram_bins.dtype,
            )
            self.hist_counts += torch.histogram(
                values, bins=self.histogram_bins
            )[0]
            self.abs_hist_counts += torch.histogram(
                values.abs(), bins=self.histogram_bins
            )[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped module and zero outputs below the threshold."""
        x = self.wrapped_module(x)
        if self.collect_stats:
            self._update_histograms(x)
        # Out-of-place so the wrapped module's output (which may alias the
        # caller's input or be saved for backward) is never mutated.
        return x.masked_fill(x.abs() < self.threshold, 0)
def load_data(file_path):
    """Return the parsed JSON content of ``file_path``.

    An empty dictionary is returned when the file does not exist.
    """
    try:
        handle = open(file_path, "r")
    except FileNotFoundError:
        # Nothing saved yet: start from an empty mapping.
        return {}
    with handle:
        return json.load(handle)
def save_to_json(data, file_path):
    """Write ``data`` to ``file_path`` as indented JSON.

    Missing parent directories are created first. A path with no directory
    component (e.g. ``"out.json"``) is written to the current directory.
    """
    parent_dir = os.path.dirname(file_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually names one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)
class CatsConfig(PretrainedConfig):
    """Configuration for ``CatsModel``.

    The attributes of ``wrapped_model_config`` are copied onto this config,
    alongside the CATS-specific settings.

    Args:
        wrapped_model_config: Config of the model being wrapped. When omitted,
            the config of ``MISTRAL_7B`` is loaded on first use.
        wrapped_model_class_name: Name of the ``transformers`` class used to
            build the wrapped model.
        target_modules: Names of the submodules to wrap with ``Cats``;
            defaults to ``["act_fn"]``.
        target_sparsity: Desired sparsity level; only stored by this class.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "cats_model"

    # Fallback wrapped config, loaded lazily and at most once per process.
    # It used to be a default argument, which ran the load at import time.
    _default_wrapped_model_config = None

    def __init__(
        self,
        wrapped_model_config=None,
        wrapped_model_class_name: str = "MistralForCausalLM",
        target_modules: List[str] = None,
        target_sparsity: float = 0.5,
        **kwargs,
    ):
        if wrapped_model_config is None:
            wrapped_model_config = self._get_default_wrapped_model_config()
        # A fresh list per instance avoids sharing one mutable default.
        self.target_modules = (
            ["act_fn"] if target_modules is None else list(target_modules)
        )
        self.target_sparsity = target_sparsity
        self.wrapped_model_class_name = wrapped_model_class_name
        # Copy the wrapped config's attributes after the CATS fields; the
        # explicit kwargs passed to the base class are applied last.
        self.__dict__.update(wrapped_model_config.__dict__)
        super().__init__(**kwargs)

    @classmethod
    def _get_default_wrapped_model_config(cls):
        """Load (once) and return the default wrapped-model config."""
        if CatsConfig._default_wrapped_model_config is None:
            CatsConfig._default_wrapped_model_config = AutoConfig.from_pretrained(
                MISTRAL_7B
            )
        return CatsConfig._default_wrapped_model_config
class CatsModel(PreTrainedModel):
    """``PreTrainedModel`` that wraps another model and injects ``Cats`` layers.

    The wrapped model class is looked up by name in ``transformers``; every
    submodule whose name is listed in ``config.target_modules`` is replaced by
    a ``Cats`` wrapper around it.
    """

    config_class = CatsConfig

    def __init__(self, config, wrapped_model_pretrained_dir: str = None, **kwargs):
        """Build the wrapped model and wrap its target submodules.

        Args:
            config: Config providing ``wrapped_model_class_name`` and
                ``target_modules``.
            wrapped_model_pretrained_dir: If given, the wrapped model is
                loaded with ``from_pretrained`` from this location; otherwise
                it is constructed from ``config``.
            **kwargs: Accepted for compatibility; unused.
        """
        super().__init__(config)
        transformers_module = importlib.import_module("transformers")
        self.wrapped_model_class = getattr(
            transformers_module, config.wrapped_model_class_name
        )
        # Build only the model that is kept, rather than constructing one
        # from config and discarding it when pretrained weights are requested.
        if wrapped_model_pretrained_dir is not None:
            self.wrapped_model = self.wrapped_model_class.from_pretrained(
                wrapped_model_pretrained_dir
            )
        else:
            self.wrapped_model = self.wrapped_model_class(config)
        self.inject_cats()

    def inject_cats(self):
        """Replace each target submodule of the wrapped model with ``Cats``."""
        # Snapshot the module list: the loop replaces submodules, and the
        # tree should not change under the iterator.
        for name, module in list(self.wrapped_model.named_modules()):
            if isinstance(module, Cats):
                # Already wrapped (e.g. inject_cats called twice).
                continue
            # NOTE(review): _get_submodules is defined in utils.utils; it is
            # used here as returning (parent, target, target_name) for `name`.
            parent, target, target_name = _get_submodules(self.wrapped_model, name)
            if target_name in self.config.target_modules:
                print(f"{name} is replaced.")
                # Replace target module with target module + CATS
                setattr(parent, target_name, Cats(wrapped_module=target))

    def _cats_modules(self):
        """Yield every ``Cats`` layer inside the wrapped model."""
        # modules() yields modules; named_modules() yields (name, module)
        # tuples, which would never pass the isinstance check.
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                yield module

    def enable_collect_stats(self):
        """Turn histogram collection on for every ``Cats`` layer."""
        for cats_layer in self._cats_modules():
            cats_layer.enable_collect_stats()

    def disable_collect_stats(self) -> None:
        """Turn histogram collection off for every ``Cats`` layer."""
        for cats_layer in self._cats_modules():
            cats_layer.disable_collect_stats()

    def disable_adapters(self) -> None:
        """Backward-compatible alias for ``disable_collect_stats``.

        NOTE(review): this name may shadow an adapter-related method on the
        base class in some transformers versions - confirm before relying on
        either meaning.
        """
        self.disable_collect_stats()
def simple_exp():
    """Build a CatsModel from the Mistral config, push it, and reload it.

    The wrapped model is constructed from the config (no pretrained
    directory is passed). The config and model classes are registered for
    the auto classes, the model is pushed to the Hub, and it is loaded back
    through ``AutoModelForCausalLM`` with remote code enabled.
    """
    base_config = AutoConfig.from_pretrained(MISTRAL_7B)
    wrapper_config = CatsConfig(
        base_config,
        wrapped_model_class_name="MistralForCausalLM",
    )
    model = CatsModel(wrapper_config, wrapped_model_pretrained_dir=None)

    # Show the wrapper, the wrapped model, and the merged config.
    for item in (model, model.wrapped_model, model.config):
        print(item)

    CatsConfig.register_for_auto_class()
    CatsModel.register_for_auto_class("AutoModelForCausalLM")

    hub_repo = "thrunlab/cats_exp"
    model.push_to_hub(hub_repo)
    model = AutoModelForCausalLM.from_pretrained(hub_repo, trust_remote_code=True)


if __name__ == "__main__":
    simple_exp()