from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
import argparse, sys, os, glob
from torch import version as torch_version
from globals import set_affinity_str


def add_args(parser):

    parser.add_argument("-t", "--tokenizer", type = str, help = "Tokenizer model path")
    parser.add_argument("-c", "--config", type = str, help = "Model config path (config.json)")
    parser.add_argument("-m", "--model", type = str, help = "Model weights path (.pt or .safetensors file)")
    parser.add_argument("-d", "--directory", type = str, help = "Path to directory containing config.json, tokenizer.model and *.safetensors")

    parser.add_argument("-gs", "--gpu_split", type = str, help = "Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. -gs 20,7,7")
    parser.add_argument("-l", "--length", type = int, help = "Maximum sequence length", default = 2048)
    parser.add_argument("-cpe", "--compress_pos_emb", type = float, help = "Compression factor for positional embeddings", default = 1.0)
    parser.add_argument("-a", "--alpha", type = float, help = "Alpha for context size extension via RoPE base scaling", default = 1.0)
    parser.add_argument("-theta", "--theta", type = float, help = "Theta (base) for RoPE embeddings")

    parser.add_argument("-gpfix", "--gpu_peer_fix", action = "store_true", help = "Prevent direct copies of data between GPUs")

    parser.add_argument("-flash", "--flash_attn", nargs = '?', const = 'default', metavar = "LENGTH", help = "Use Flash Attention with the specified maximum input length (must have Flash Attention 2.0 installed)")

    parser.add_argument("-mmrt", "--matmul_recons_thd", type = int, help = "No. of rows at which to use reconstruction and cuBLAS for quant matmul. 0 = never, 1 = always", default = 8)
    parser.add_argument("-fmt", "--fused_mlp_thd", type = int, help = "Maximum no. of rows for which to use fused MLP. 0 = never", default = 2)
    parser.add_argument("-sdpt", "--sdp_thd", type = int, help = "No. of rows at which to switch to scaled_dot_product_attention. 0 = never, 1 = always", default = 8)
    parser.add_argument("-mmfr", "--matmul_fused_remap", action = "store_true", help = "Fuse column remapping in Q4 matmul kernel")
    parser.add_argument("-nfa", "--no_fused_attn", action = "store_true", help = "Disable fused attention")

    parser.add_argument("-rnnh2", "--rmsnorm_no_half2", action = "store_true", help = "Don't use half2 in RMS norm kernel")
    parser.add_argument("-rpnh2", "--rope_no_half2", action = "store_true", help = "Don't use half2 in RoPE kernel")
    parser.add_argument("-mmnh2", "--matmul_no_half2", action = "store_true", help = "Don't use half2 in Q4 matmul kernel")
    parser.add_argument("-snh2", "--silu_no_half2", action = "store_true", help = "Don't use half2 in SiLU kernel")
    parser.add_argument("-nh2", "--no_half2", action = "store_true", help = "(All of the above) disable half2 in all kernels")
    parser.add_argument("-fh2", "--force_half2", action = "store_true", help = "Force enable half2 even if unsupported")
    parser.add_argument("-cs", "--concurrent_streams", action = "store_true", help = "Use concurrent CUDA streams")

    parser.add_argument("-aff", "--affinity", type = str, help = "Comma-separated list, sets processor core affinity. E.g.: -aff 0,1,2,3")
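
# Usage sketch (hypothetical entry point, not part of this module): any script that calls
# add_args() on its own ArgumentParser accepts the flags above, e.g.
#
#   python some_script.py -d /path/to/model-dir -l 4096 -gs 20,7
#
# where "some_script.py" is a placeholder name and the paths/values are illustrative only.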


def post_parse(args):

    # Disable half2 kernels when explicitly requested, or by default on ROCm (HIP) builds
    # of PyTorch unless --force_half2 is given
    if args.no_half2 or (torch_version.hip and not args.force_half2):
        args.rmsnorm_no_half2 = True
        args.rope_no_half2 = True
        args.matmul_no_half2 = True
        args.silu_no_half2 = True
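
# Minimal sketch of how a caller is expected to wire these helpers together (assumes this
# file is importable as model_init; names outside this module are illustrative):
#
#   parser = argparse.ArgumentParser(description = "Example")
#   model_init.add_args(parser)
#   args = parser.parse_args()
#   model_init.post_parse(args)
#   model_init.get_model_files(args)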


def get_model_files(args):

    if args.directory is not None:
        args.tokenizer = os.path.join(args.directory, "tokenizer.model")
        args.config = os.path.join(args.directory, "config.json")
        st_pattern = os.path.join(args.directory, "*.safetensors")
        st = glob.glob(st_pattern)
        if len(st) == 0:
            print(f" !! No files matching {st_pattern}")
            sys.exit()
        args.model = st
    else:
        if args.tokenizer is None or args.config is None or args.model is None:
            print(" !! Please specify either -d or all of -t, -c and -m")
            sys.exit()
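
# Note: when -d is used, args.model becomes a list of .safetensors paths (one per shard)
# rather than a single string; print_options() handles both forms and make_config() passes
# the value through unchanged as config.model_path.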


def _common_chars(names):

    cname = max(names, key = len)
    for x in names:
        for p, c in enumerate(x):
            if c != cname[p] and cname[p] != "*": cname = cname[:p] + "*" + cname[p+1:]
    return cname
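
# For instance, with hypothetical shard names,
#   _common_chars(["model-00001-of-00003.safetensors", "model-00002-of-00003.safetensors"])
# returns "model-0000*-of-00003.safetensors", so print_options() can show one wildcard
# pattern instead of listing every shard.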


def print_options(args, extra_options = None):

    print_opts = []
    if args.gpu_split is not None: print_opts.append(f"gpu_split: {args.gpu_split}")
    if args.gpu_peer_fix: print_opts.append("gpu_peer_fix")
    if args.affinity: print_opts.append(f"affinity: {args.affinity}")

    if extra_options is not None: print_opts += extra_options

    print(f" -- Tokenizer: {args.tokenizer}")
    print(f" -- Model config: {args.config}")

    if isinstance(args.model, str): print(f" -- Model: {args.model}")
    else: print(f" -- Model: {_common_chars(args.model)}")

    print(f" -- Sequence length: {args.length}")
    if args.compress_pos_emb != 1.0:
        print(f" -- RoPE compression factor: {args.compress_pos_emb}")

    if args.alpha != 1.0:
        print(f" -- RoPE alpha factor: {args.alpha}")

    print(f" -- Tuning:")

    if args.flash_attn: print(f" -- --flash_attn")
    else: print(f" -- --sdp_thd: {args.sdp_thd}" + (" (disabled)" if args.sdp_thd == 0 else ""))

    print(f" -- --matmul_recons_thd: {args.matmul_recons_thd}" + (" (disabled)" if args.matmul_recons_thd == 0 else ""))
    print(f" -- --fused_mlp_thd: {args.fused_mlp_thd}" + (" (disabled)" if args.fused_mlp_thd == 0 else ""))
    if args.matmul_fused_remap: print(f" -- --matmul_fused_remap")
    if args.no_fused_attn: print(f" -- --no_fused_attn")
    if args.rmsnorm_no_half2: print(f" -- --rmsnorm_no_half2")
    if args.rope_no_half2: print(f" -- --rope_no_half2")
    if args.matmul_no_half2: print(f" -- --matmul_no_half2")
    if args.silu_no_half2: print(f" -- --silu_no_half2")
    if args.concurrent_streams: print(f" -- --concurrent_streams")

    print(f" -- Options: {print_opts}")
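
# Callers can append their own entries via extra_options, e.g. (illustrative only):
#
#   print_options(args, ["benchmark", f"validate: {args.validate}"])
#
# where "validate" is a hypothetical extra argument defined by the calling script.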


def make_config(args):

    config = ExLlamaConfig(args.config)
    config.model_path = args.model

    config.max_seq_len = args.length
    config.compress_pos_emb = args.compress_pos_emb
    config.set_auto_map(args.gpu_split)
    config.gpu_peer_fix = args.gpu_peer_fix
    config.alpha_value = args.alpha
    config.calculate_rotary_embedding_base()

    if args.flash_attn:
        config.use_flash_attn_2 = True
        try:
            # -flash may carry an explicit input length, or be given bare (const = 'default'),
            # in which case int() fails and the default max_input_len is kept
            config.max_input_len = int(args.flash_attn)
        except ValueError:
            pass

    config.matmul_recons_thd = args.matmul_recons_thd
    config.fused_mlp_thd = args.fused_mlp_thd
    config.sdp_thd = args.sdp_thd
    config.matmul_fused_remap = args.matmul_fused_remap
    config.fused_attn = not args.no_fused_attn

    config.rmsnorm_no_half2 = args.rmsnorm_no_half2
    config.rope_no_half2 = args.rope_no_half2
    config.matmul_no_half2 = args.matmul_no_half2
    config.silu_no_half2 = args.silu_no_half2
    config.concurrent_streams = args.concurrent_streams

    if args.theta:
        config.rotary_embedding_base = args.theta

    return config
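
# Sketch of the usual call chain once a config has been built (constructor signatures are
# assumed from the imports at the top of this file, not verified here):
#
#   config = make_config(args)
#   model = ExLlama(config)
#   cache = ExLlamaCache(model)
#   tokenizer = ExLlamaTokenizer(args.tokenizer)
#   print_stats(model)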


def set_globals(args):

    if args.affinity: set_affinity_str(args.affinity)


def print_stats(model):

    print(f" -- Groupsize (inferred): {model.config.groupsize if model.config.groupsize is not None else 'None'}")
    print(f" -- Act-order (inferred): {'yes' if model.config.act_order else 'no'}")
    if model.config.empty_g_idx:
        print(f" !! Model has empty group index (discarded)")