File size: 2,942 Bytes
88a170b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
#!/usr/bin/env python
import torch
import os
import argparse
import logging
import sys
from tensorizer import TensorSerializer
from transformers import AutoModelForCausalLM, AutoConfig
# Module-level logger; INFO-level output is routed to stdout so progress
# messages are visible when the script runs in containers/CI.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
def tensorize_model(
    model_name: str,
    model_path: str,
    tensorizer_path: str,
    dtype: str = "fp32",
    device: str = "cuda:0",
) -> dict:
    """
    Create a tensorized version of a model's weights.

    Loads the model from `model_path` in the requested dtype, streams its
    tensors to `tensorizer_path` via `TensorSerializer`, and saves the model
    config alongside the tensorized weights so they can be reloaded together.

    Args:
        model_name (str): Name of model on the Hugging Face hub (used for logging).
        model_path (str): Local path (or hub id) where model weights are loaded from.
        tensorizer_path (str): Local file path where tensorized weights are written.
        dtype (str): One of `"fp32"`, `"fp16"`, and `"bf16"`. Defaults to `"fp32"`.
        device (str): Device the model is moved to before serialization.
            Defaults to `"cuda:0"`.

    Returns:
        dict: Dictionary containing the tensorized model path and dtype.

    Raises:
        ValueError: If `dtype` is not one of the supported strings.
    """
    # Map the dtype string to a torch dtype; reject unknown values up front
    # instead of failing later with an UnboundLocalError.
    if dtype is None or dtype == 'fp32':
        torch_dtype = torch.float32
    elif dtype == 'bf16':
        torch_dtype = torch.bfloat16
    elif dtype == 'fp16':
        torch_dtype = torch.float16
    else:
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected one of 'fp32', 'fp16', 'bf16'."
        )
    logger.info(f"Loading {model_name} in {dtype} from {model_path}...")
    # torch_dtype must be passed here, otherwise the requested dtype is ignored
    # and the model is loaded in its checkpoint precision.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device)
    logger.info(f"Tensorizing model {model_name} in {dtype} and writing tensors to {tensorizer_path}...")
    serializer = TensorSerializer(tensorizer_path)
    serializer.write_module(model)
    serializer.close()
    # Write the config next to the tensorized weights (not into a directory
    # named after the model) so a loader can find both in one place.
    config_dir = os.path.dirname(tensorizer_path) or "."
    model.config.save_pretrained(config_dir)
    logger.info(f"Tensorized model {model_name} in {dtype} and wrote tensors to {tensorizer_path} and config to {config_dir}...")
    return {"tensorized_weights_path": tensorizer_path, "dtype": dtype}
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A simple script for tensorizing a torch model."
    )
    # model_name is required: a None value would be passed straight into the
    # Hugging Face loading path and fail with a confusing error.
    parser.add_argument(
        "--model_name", type=str, required=True,
        help="Name of the model on the Hugging Face hub (used for logging).",
    )
    parser.add_argument(
        "--model_path", type=str, default=None,
        help="Local path (or hub id) to load model weights from.",
    )
    parser.add_argument(
        "--tensorizer_path", type=str, default=None,
        help="Output file path for the tensorized weights.",
    )
    parser.add_argument(
        "--dtype", type=str, default="fp32",
        help="One of 'fp32', 'fp16', 'bf16'. Defaults to 'fp32'.",
    )
    args = parser.parse_args()
    tensorize_model(
        args.model_name,
        model_path=args.model_path,
        tensorizer_path=args.tensorizer_path,
        dtype=args.dtype,
    )