|
|
|
import torch |
|
import os |
|
import argparse |
|
import logging |
|
import sys |
|
|
|
from tensorizer import TensorSerializer |
|
from transformers import AutoModelForCausalLM, AutoConfig |
|
|
|
|
|
# Module-level logger; INFO and above is emitted to stdout so progress is
# visible when this script runs inside a container/job log.
logger = logging.getLogger(__name__)

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
|
|
def tensorize_model(
    model_name: str,
    model_path: str = None,
    tensorizer_path: str = None,
    dtype: str = "fp32",
) -> dict:
    """
    Create a tensorized version of model weights.

    The model is loaded with the requested torch dtype, serialized with
    ``TensorSerializer``, and its config is saved via ``save_pretrained``.

    If `model_path` is None, weights are loaded from
    `./model_weights/torch_weights/<model_name>`.
    If `tensorizer_path` is None, tensorized weights are written to
    `./model_weights/tensorizer_weights/<model_name>/<dtype>/model.tensors`.

    Args:
        model_name (str): Name of model on the Hugging Face hub.
        model_path (str, optional): Local path where model weights are saved.
        tensorizer_path (str, optional): Local file path where tensorized
            weights are written.
        dtype (str): One of `"fp32"`, `"fp16"`, and `"bf16"`. Defaults to `"fp32"`.

    Returns:
        dict: Dictionary containing the tensorized model path and dtype.

    Raises:
        ValueError: If `dtype` is not one of the supported values.
    """
    # Map the string flag to a torch dtype; `None` is treated as fp32 for
    # backward compatibility with callers that pass dtype=None.
    dtype_map = {
        None: torch.float32,
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if dtype not in dtype_map:
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected one of 'fp32', 'fp16', 'bf16'."
        )
    torch_dtype = dtype_map[dtype]
    dtype_str = dtype if dtype is not None else "fp32"

    # Apply the documented default locations when paths are not provided.
    if model_path is None:
        model_path = os.path.join(".", "model_weights", "torch_weights", model_name)
    if tensorizer_path is None:
        tensorizer_path = os.path.join(
            ".", "model_weights", "tensorizer_weights",
            model_name, dtype_str, "model.tensors",
        )

    # Ensure the output directory exists before the serializer opens the file.
    out_dir = os.path.dirname(tensorizer_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    logger.info(f"Loading {model_name} in {dtype_str} from {model_path}...")

    # torch_dtype is what actually applies the fp16/bf16 conversion promised
    # in the docstring; without it the model always loads in its saved dtype.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, torch_dtype=torch_dtype,
    ).to('cuda:0')

    logger.info(f"Tensorizing model {model_name} in {dtype_str} and writing tensors to {tensorizer_path}...")

    serializer = TensorSerializer(tensorizer_path)
    serializer.write_module(model)
    serializer.close()

    # NOTE(review): the config is saved to a directory named after the hub
    # model name (in the CWD), matching the original behavior — confirm this
    # is the intended location rather than alongside the tensorized weights.
    config_path = model_name
    model_config = model.config
    model_config.save_pretrained(config_path)

    logger.info(f"Tensorized model {model_name} in {dtype_str} and wrote tensors to {tensorizer_path} and config to {config_path}...")

    return {"tensorized_weights_path": tensorizer_path, "dtype": dtype}
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=(
        "A simple script for tensorizing a torch model."
    )
    )

    # model_name is the only mandatory input: it names the hub model and is
    # used to derive default paths; fail fast if it is missing.
    parser.add_argument(
        "--model_name", type=str, required=True,
        help="Name of the model on the Hugging Face hub.",
    )
    parser.add_argument(
        "--model_path", type=str, default=None,
        help="Local path to the torch model weights (optional).",
    )
    parser.add_argument(
        "--tensorizer_path", type=str, default=None,
        help="Local path where tensorized weights are written (optional).",
    )
    parser.add_argument(
        "--dtype", type=str, default="fp32",
        help="Weight dtype: one of fp32, fp16, bf16 (default: fp32).",
    )

    args = parser.parse_args()

    model_info = tensorize_model(
        args.model_name,
        model_path=args.model_path,
        tensorizer_path=args.tensorizer_path,
        dtype=args.dtype
    )