import sys
min_version = (3, 9)
if sys.version_info < min_version:
print("")
print(f" ## Warning: this project requires Python {min_version[0]}.{min_version[1]} or higher.")
print("")
import torch
from torch import nn
import torch.nn.functional as F
from safetensors import safe_open
import cuda_ext
import json
import math
import gc
from enum import Enum
try:
    from flash_attn import flash_attn_func
except ImportError:
    # Flash Attention 2 is optional; leave config.use_flash_attn_2 disabled if it isn't installed
    pass
class ParsedEnum(Enum):
def __str__(self):
return self.name.lower()
def __repr__(self):
return str(self)
@classmethod
def argparse(cls, s):
try:
return cls[s.upper()]
except KeyError:
return s
class ExLlamaConfig:
# Load config from Llama config.json
def __init__(self, model_config_path):
with open(model_config_path) as f:
read_config = json.load(f)
# Loaded/automatic settings
self.bos_token_id = read_config["bos_token_id"] if "bos_token_id" in read_config else 1
self.eos_token_id = read_config["eos_token_id"] if "eos_token_id" in read_config else 2
self.pad_token_id = read_config["pad_token_id"] if "pad_token_id" in read_config else 0
self.hidden_size = read_config["hidden_size"]
self.initializer_range = read_config["initializer_range"]
self.intermediate_size = read_config["intermediate_size"]
self.num_attention_heads = read_config["num_attention_heads"]
self.num_hidden_layers = read_config["num_hidden_layers"]
self.rms_norm_eps = read_config["rms_norm_eps"]
self.vocab_size = read_config["vocab_size"]
if "num_key_value_heads" in read_config:
self.num_key_value_heads = read_config["num_key_value_heads"]
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
else:
self.num_key_value_heads = self.num_attention_heads
self.num_key_value_groups = 1
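        # Grouped-query attention example (illustrative head counts): 64 attention heads with
        # 8 key/value heads gives num_key_value_groups == 8, i.e. each KV head is shared by
        # 8 query heads.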
self.rotary_embedding_base = read_config["rope_theta"] if "rope_theta" in read_config else 10000.0
self.head_dim = self.hidden_size // self.num_attention_heads
self.groupsize = None # Autodetected
self.act_order = False # Autodetected
self.empty_g_idx = False # Autodetected
# Required settings
self.model_path = None # str or list[str]
self.device_map = ExLlamaDeviceMap(self.num_hidden_layers)
# Optional settings
        self.max_seq_len = 2048 # Reduce to save memory. Can also be increased, ideally while also using compress_pos_emb and a compatible model/LoRA
self.max_input_len = 2048 # Maximum length of input IDs in a single forward pass. Sequences longer than this will be processed in multiple steps
self.max_attention_size = 2048**2 # Sequences will be processed in chunks to keep the size of the attention weights matrix <= this
self.compress_pos_emb = 1.0 # Increase to compress positional embeddings applied to sequence
        self.alpha_value = 1.0 # Alpha value for NTK RoPE scaling. Similar to compress_pos_emb, higher values increase the usable context length but also add perplexity.
        self.gpu_peer_fix = False # Torch can sometimes have problems transferring tensors directly from one GPU to another. Enable this to explicitly move tensors via system RAM instead, where needed
self.auto_map = None # List of floats with memory allocation in GB, per CUDA device, overrides device_map
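        # A minimal usage sketch (assumed file locations, not part of this module): after
        # constructing the config, point it at the quantized weights and optionally adjust the
        # context settings before building the model, e.g.
        #
        #   config = ExLlamaConfig("/models/llama-13b-4bit/config.json")
        #   config.model_path = "/models/llama-13b-4bit/model.safetensors"
        #   config.max_seq_len = 4096
        #   config.compress_pos_emb = 2.0   # or set alpha_value and call calculate_rotary_embedding_base()
        #   model = ExLlama(config)
        #   cache = ExLlamaCache(model)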
# Tuning
self.use_flash_attn_2 = False
self.matmul_recons_thd = 8
self.fused_mlp_thd = 2
self.sdp_thd = 8
self.fused_attn = True
self.matmul_fused_remap = False
self.rmsnorm_no_half2 = False
self.rope_no_half2 = False
self.matmul_no_half2 = False
self.silu_no_half2 = False
self.concurrent_streams = False
# Copy tuning params to C++ extension
def set_tuning_params(self):
cuda_ext.exllama_ext.set_tuning_params(self.matmul_recons_thd,
self.fused_mlp_thd,
self.sdp_thd,
self.matmul_fused_remap,
self.rmsnorm_no_half2,
self.rope_no_half2,
self.matmul_no_half2,
self.silu_no_half2,
self.concurrent_streams)
# Parse and set list of GPU VRAM allocations
def set_auto_map(self, map_string):
if map_string is None: self.auto_map = None
else: self.auto_map = [float(alloc) for alloc in map_string.split(",")]
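        # Example (hypothetical budgets): set_auto_map("17.2,24") reserves roughly 17.2 GB of
        # weights for cuda:0 and 24 GB for cuda:1; the loader assigns layers to each device in
        # order until its budget is used up, then moves on to the next one.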
def calculate_rotary_embedding_base(self):
self.rotary_embedding_base = self.rotary_embedding_base * self.alpha_value ** (self.head_dim / (self.head_dim-2))
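        # The line above implements NTK-aware RoPE scaling: with head dimension d and scaling
        # factor alpha, base' = base * alpha ** (d / (d - 2)). For example (illustrative), with
        # d = 128 and alpha_value = 2.0 the base grows from 10000 to roughly 20222.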
# 4-bit linear layer implementation
class Ex4bitLinear:
def __init__(self, config, in_features, out_features, has_bias, tensors, key):
self.config = config
self.key = key
self.in_features = in_features
self.out_features = out_features
self.qweight = tensors[key + ".qweight"]
self.qzeros = tensors[key + ".qzeros"]
self.scales = tensors[key + ".scales"]
self.g_idx = tensors[key + ".g_idx"].cpu() if key + ".g_idx" in tensors else None
self.bias = tensors[key + ".bias"] if has_bias else None
if self.g_idx is not None and (self.g_idx == 0).all():
self.config.empty_g_idx = True
self.g_idx = None
self.device = self.qweight.device
self.device_index = self.device.index
self.q4 = cuda_ext.ext_make_q4(self.qweight,
self.qzeros,
self.scales,
self.g_idx,
self.device_index)
self.height = tensors[key + ".qweight"].shape[0] * 8
self.width = tensors[key + ".qweight"].shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // self.qzeros.shape[0]
if self.config.groupsize is None:
self.config.groupsize = self.groupsize
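        # Illustrative numbers for the inference above: GPTQ packs eight 4-bit weights per
        # int32 row, so a 4096-row float matrix has qweight.shape[0] == 512; with one qzeros
        # row per group, qzeros.shape[0] == 32 gives groupsize = (512 * 8) // 32 = 128.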
# Handle act-order matrix
if self.g_idx is not None:
            if self.groupsize is None: raise ValueError("Found group index but could not infer groupsize.")
self.config.act_order = True
def lora_applies(self, lora):
if lora is None: return False
return self.key + ".lora_A.weight" in lora.tensors
def lora_apply(self, lora, x):
lora_a = lora.tensors[self.key + ".lora_A.weight"]
lora_b = lora.tensors[self.key + ".lora_B.weight"]
out = torch.matmul(x, lora_a)
out = torch.matmul(out, lora_b)
# out = cuda_ext.ext_half_matmul(x, lora_a.contiguous(), cublas = True)
# out = cuda_ext.ext_half_matmul(out, lora_b.contiguous(), cublas = True)
return out
def get_lora_tensors_or_meta(self, lora):
if not self.lora_applies(lora):
return cuda_ext.none_tensor, cuda_ext.none_tensor
else:
lora_a = lora.tensors[self.key + ".lora_A.weight"]
lora_b = lora.tensors[self.key + ".lora_B.weight"]
return lora_a, lora_b
def forward(self, x, lora):
if self.lora_applies(lora):
lora_a = lora.tensors[self.key + ".lora_A.weight"]
lora_b = lora.tensors[self.key + ".lora_B.weight"]
out = cuda_ext.ext_q4_matmul(x, self.q4, self.width, lora_a, lora_b)
else:
out = cuda_ext.ext_q4_matmul(x, self.q4, self.width)
# out = cuda_ext.ext_q4_matmul(x, self.q4, self.width)
# if self.lora_applies(lora):
# out += self.lora_apply(lora, x)
if self.bias is not None: out.add_(self.bias)
return out
# Llama MLP
class ExLlamaMLP:
def __init__(self, config, tensors, key):
self.config = config
self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj")
self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj")
self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj")
self.act_fn = nn.SiLU()
def fused(self, x, buffer, post_attention_layernorm, lora):
bsz, q_len, _ = x.size()
gate_a, gate_b = self.gate_proj.get_lora_tensors_or_meta(lora)
up_a, up_b = self.up_proj.get_lora_tensors_or_meta(lora)
down_a, down_b = self.down_proj.get_lora_tensors_or_meta(lora)
temp_size = 0
if not gate_a.is_meta: temp_size = max(temp_size, bsz * q_len * gate_a.shape[1])
if not up_a.is_meta: temp_size = max(temp_size, bsz * q_len * up_a.shape[1])
if not down_a.is_meta: temp_size = max(temp_size, bsz * q_len * down_a.shape[1])
if temp_size > 0: lora_temp = torch.empty((1, temp_size), dtype = torch.float16, device = x.device)
else: lora_temp = cuda_ext.none_tensor
cuda_ext.exllama_ext.q4_mlp(x.view(-1, x.shape[-1]),
post_attention_layernorm.weight,
self.config.rms_norm_eps,
self.gate_proj.q4,
self.up_proj.q4,
self.down_proj.q4,
gate_a, gate_b,
up_a, up_b,
down_a, down_b,
lora_temp)
def forward(self, x, buffer, lora):
y = self.gate_proj.forward(x, lora)
y = self.act_fn(y)
y *= self.up_proj.forward(x, lora)
y = self.down_proj.forward(y, lora)
return y
# RMS Layer norm.
class ExLlamaRMSNorm:
def __init__(self, config, tensors, key):
self.config = config
self.variance_epsilon = self.config.rms_norm_eps
self.weight = tensors[key]
def forward(self, hidden_states, buffer):
hidden_states = cuda_ext.ext_rms_norm(hidden_states, self.weight, self.variance_epsilon)
return hidden_states
# Llama attention
class ExLlamaAttention:
def __init__(self, config, tensors, key, sin, cos, index):
self.config = config
self.sin = sin
self.cos = cos
self.index = index
self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj")
self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".k_proj")
self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".v_proj")
self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj")
def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
# TODO: This seems inefficient. It should be possible to broadcast in the attention matmul to avoid building
# temporary K/V tensors like this
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1: return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
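        # Shape example for grouped-query attention (illustrative head counts): with 32 query
        # heads and 8 KV heads, n_rep == 4, and a (bsz, 8, seq_len, head_dim) cache tensor
        # becomes (bsz, 32, seq_len, head_dim) so it can be matmul'd against the query heads.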
def fused(self, hidden_states, cache, buffer, input_layernorm, lora):
bsz, q_len, _ = hidden_states.size()
past_len = cache.current_seq_len
# Lora tensors
q_a, q_b = self.q_proj.get_lora_tensors_or_meta(lora)
k_a, k_b = self.k_proj.get_lora_tensors_or_meta(lora)
v_a, v_b = self.v_proj.get_lora_tensors_or_meta(lora)
o_a, o_b = self.o_proj.get_lora_tensors_or_meta(lora)
temp_size = 0
if not q_a.is_meta: temp_size = max(temp_size, bsz * q_len * q_a.shape[1])
if not k_a.is_meta: temp_size = max(temp_size, bsz * q_len * k_a.shape[1])
if not v_a.is_meta: temp_size = max(temp_size, bsz * q_len * v_a.shape[1])
if not o_a.is_meta: temp_size = max(temp_size, bsz * q_len * o_a.shape[1])
if temp_size > 0: lora_temp = torch.empty((1, temp_size), dtype = torch.float16, device = hidden_states.device)
else: lora_temp = cuda_ext.none_tensor
# Project q, k, v, apply position embeddings to k and v, update cache
query_states = torch.empty((bsz, q_len, self.config.num_attention_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
key_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
value_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
cuda_ext.exllama_ext.q4_attn(hidden_states,
input_layernorm.weight,
self.config.rms_norm_eps,
query_states,
key_states,
value_states,
self.q_proj.q4,
self.k_proj.q4,
self.v_proj.q4,
self.sin,
self.cos,
q_len,
past_len,
self.config.num_attention_heads,
self.config.num_key_value_heads,
self.config.head_dim,
cache.key_states[self.index],
cache.value_states[self.index],
cache.max_seq_len,
q_a, q_b,
k_a, k_b,
v_a, v_b,
lora_temp)
query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim)
# Get k, v with past
key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
        # Repeat K/V heads if num_key_value_heads < num_attention_heads
query_states.transpose_(1, 2)
key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)
# Attention
# TODO: Figure out if we can use cublasHgemmStridedBatched() to do this matmul without reshaping. Torch uses
# gemmStridedBatchedEx() internally, so it should be possible.
# -- Flash Attention 2.0
if self.config.use_flash_attn_2 and (past_len == 0 or q_len == 1):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = query_states.transpose(1, 2)
attn_output = flash_attn_func(query_states, key_states, value_states, causal = (past_len == 0))
# -- HF Transformers regular attention, faster on shorter sequences, same VRAM usage
else:
key_states.transpose_(2, 3)
attn_weights = torch.matmul(query_states, key_states)
attn_weights /= math.sqrt(self.config.head_dim)
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
# Output projection
cuda_ext.exllama_ext.q4_attn_2(hidden_states,
attn_output,
self.o_proj.q4,
o_a, o_b,
lora_temp)
# return hidden_states
def forward(self, hidden_states, cache, buffer, lora):
bsz, q_len, _ = hidden_states.size()
past_len = cache.current_seq_len
# Project q, k, v, apply position embeddings to k and v
query_states = self.q_proj.forward(hidden_states, lora)
key_states = self.k_proj.forward(hidden_states, lora)
cuda_ext.exllama_ext.rope_(query_states, self.sin, self.cos, past_len, self.config.num_attention_heads, self.config.head_dim)
cuda_ext.exllama_ext.rope_(key_states, self.sin, self.cos, past_len, self.config.num_key_value_heads, self.config.head_dim)
query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)
value_states = self.v_proj.forward(hidden_states, lora).view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)
# Add keys and values to cache
new_keys = cache.key_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)
new_values = cache.value_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)
new_keys.copy_(key_states)
new_values.copy_(value_states)
# Key/value tensors with past
key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
# Attention
# -- Flash Attention 2.0
if self.config.use_flash_attn_2 and (past_len == 0 or q_len == 1):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = query_states.transpose(1, 2)
attn_output = flash_attn_func(query_states, key_states, value_states, causal = (past_len == 0))
# -- HF Transformers regular attention, faster on shorter sequences, same VRAM usage
elif self.config.sdp_thd == 0 or q_len < self.config.sdp_thd:
key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
attn_weights /= math.sqrt(self.config.head_dim)
if buffer.attn_mask is not None: attn_weights = attn_weights + buffer.attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
# -- Scaled dot-product attention from PyTorch 2, should be comparable to xformers (?)
else:
# Torch's SDP attention has a built-in causal mask feature which we can use only when there is no past, i.e.
# it can only apply a square attention mask. It saves quite a bit of VRAM but in practice Torch seems to use
# the same amount of memory at peak anyway.
#
# TODO: Apparently flash attention is disabled when supplying an attention mask tensor. Figure out if this
# is true and maybe drop SDP altogether. If causal masking in flash-attn is updated eventually there should
# be no need for this anyway.
key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)
if past_len > 0 or (bsz > 1 and buffer.attn_mask is not None):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = buffer.attn_mask, is_causal = False)
else:
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = None, is_causal = True)
attn_output = attn_output.transpose(1, 2)
# Output projection
attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
attn_output = self.o_proj.forward(attn_output, lora)
return attn_output
def _rows(x):
xdp = 1
for y in x.shape[:-1]: xdp *= y
return xdp
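# _rows(x) above is the product of every dimension of x except the last, e.g. a
# (2, 5, 4096) hidden_states tensor gives _rows == 10. The decoder layer below uses it
# to decide whether the fused attention and MLP paths apply.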
class ExLlamaDecoderLayer:
def __init__(self, config, tensors, key, index, sin, cos):
self.config = config
self.index = index
self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index)
self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp")
self.input_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".input_layernorm.weight")
self.post_attention_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".post_attention_layernorm.weight")
def forward(self, hidden_states, cache, buffer, lora):
# Self-attention
if self.config.fused_attn and _rows(hidden_states) == 1:
self.self_attn.fused(hidden_states, cache, buffer, self.input_layernorm, lora)
else:
residual = hidden_states
hidden_states = self.input_layernorm.forward(hidden_states, buffer)
hidden_states = self.self_attn.forward(hidden_states, cache, buffer, lora)
hidden_states = residual + hidden_states
# MLP
if self.config.fused_mlp_thd > 0 and _rows(hidden_states) <= self.config.fused_mlp_thd:
self.mlp.fused(hidden_states, buffer, self.post_attention_layernorm, lora)
else:
residual = hidden_states
hidden_states = self.post_attention_layernorm.forward(hidden_states, buffer)
hidden_states = self.mlp.forward(hidden_states, buffer, lora)
hidden_states = residual + hidden_states
return hidden_states
# Persistent cache for inference. Allocate the whole thing up front.
class ExLlamaCache:
def __init__(self, model, batch_size = 1, max_seq_len = -1, copy_from = None):
self.model = model
self.config = self.model.config
self.max_seq_len = max_seq_len if max_seq_len != -1 else self.config.max_seq_len
self.batch_size = batch_size
self.key_states = []
self.value_states = []
self.current_seq_len = 0
# Preallocate full-length cache
for i in range(self.config.num_hidden_layers):
if copy_from is None:
p_key_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
p_value_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
else:
p_key_states = copy_from.key_states[i].clone()
p_value_states = copy_from.value_states[i].clone()
self.key_states.append(p_key_states)
self.value_states.append(p_value_states)
def zero(self):
for i in range(self.config.num_hidden_layers):
self.key_states[i].zero_()
self.value_states[i].zero_()
def clone(self):
new = ExLlamaCache(self.model, batch_size = self.batch_size, max_seq_len = self.max_seq_len, copy_from = self)
return new
def roll_left(self):
for i in range(self.config.num_hidden_layers):
self.key_states[i] = torch.roll(self.key_states[i], shifts = -1, dims = 2)
self.value_states[i] = torch.roll(self.value_states[i], shifts = -1, dims = 2)
self.current_seq_len -= 1
def copy_states(self, target, from_column, from_columns, to_column, to_columns, from_row, from_rows, to_row, to_rows):
assert from_rows == 1
assert from_columns == to_columns
assert to_column + to_columns <= target.max_seq_len
assert from_column + from_columns <= self.max_seq_len
for i in range(self.config.num_hidden_layers):
source_view_k = self.key_states[i].narrow(0, from_row, from_rows).narrow(2, from_column, from_columns)
source_view_v = self.value_states[i].narrow(0, from_row, from_rows).narrow(2, from_column, from_columns)
target_view_k = target.key_states[i].narrow(0, to_row, to_rows).narrow(2, to_column, to_columns)
target_view_v = target.value_states[i].narrow(0, to_row, to_rows).narrow(2, to_column, to_columns)
if to_rows > 1:
source_view_k = source_view_k.expand_as(target_view_k)
source_view_v = source_view_v.expand_as(target_view_v)
target_view_k.copy_(source_view_k)
target_view_v.copy_(source_view_v)
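        # For instance (hypothetical names and sizes), replicating an evaluated prompt from
        # row 0 of this cache into every row of a larger batch cache:
        #
        #   cache.copy_states(target, 0, prompt_len, 0, prompt_len, 0, 1, 0, batch_size)
        #
        # With to_rows > 1 the single source row is broadcast into all target rows.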
# Device map for the model.
class ExLlamaDeviceMap:
def __init__(self, num_layers):
self.num_layers = num_layers
self.embed_tokens = "cpu" # Embedding table on CPU saves 400 MB on the 30B model with no measurable impact on performance
self.lm_head = "cuda:0"
self.norm = "cuda:0"
self.layers = ["cuda:0"] * self.num_layers
def get_layers_devs(self):
return sorted(list(set(self.layers)))
def get_all_devs(self):
return sorted(list(set(self.layers + [self.lm_head, self.norm, self.embed_tokens])))
def map(self, key):
if key.startswith("lm_head."): return self.lm_head
if key.startswith("model.embed_tokens."): return self.embed_tokens
if key.startswith("model.norm."): return self.norm
if key.startswith("model.layers."):
num = int(key.split(".")[2])
return self.layers[num]
raise ValueError("Unknown key: " + key)
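        # e.g. map("model.layers.17.self_attn.q_proj.qweight") returns self.layers[17], while
        # map("lm_head.weight") returns self.lm_head.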
class ExLlamaBuffer:
config: ExLlamaConfig
def __init__(self, config):
self.config = config
# Attention mask
attn_mask: torch.Tensor = None
# Move to device
def to(self, device):
new = ExLlamaBuffer(self.config)
new.attn_mask = None if self.attn_mask is None else _move_tensor(self.attn_mask, device, "attn_mask", self.config)
return new
def _device_to_int(device):
return int(device[device.find(":") + 1:])
def _skip_key(key):
if key.endswith("_proj.bias"): return True
if key.endswith(".rotary_emb.inv_freq"): return True
return False
def _move_tensor(tensor, new_device, name, config):
device = str(tensor.device)
if device == new_device: return tensor
if config.gpu_peer_fix:
if str(device).startswith("cuda:") and str(new_device).startswith("cuda:"):
tensor = tensor.to("cpu")
return tensor.to(new_device)
def _layer_dtype_size(key):
if key.endswith(".weight"): return 2
if key.endswith(".qweight"): return 4
if key.endswith(".qzeros"): return 4
if key.endswith(".scales"): return 2
if key.endswith(".g_idx"): return 0
raise ValueError("Unrecognized layer: " + key)
class ExLlama:
def __init__(self, config):
self.config = config
# Copy tuning parameters to C++ extension
self.config.set_tuning_params()
# Read tensor list from file(s)
if isinstance(self.config.model_path, str): model_path = [self.config.model_path]
else: model_path = self.config.model_path
# Read tensor list from file(s), and measure layer sizes
load_keys = {}
decoder_size = 0
norm_size = 0
head_size = 0
for path in model_path:
with safe_open(path, framework = "pt", device = "cpu") as f:
for key in f.keys():
if _skip_key(key): continue
load_keys[key] = path
if key.startswith("model.layers.0."):
tensor_slice = f.get_slice(key)
shape = tensor_slice.get_shape()
decoder_size += math.prod(shape) * _layer_dtype_size(key)
del tensor_slice
if key.startswith("model.norm."):
tensor_slice = f.get_slice(key)
shape = tensor_slice.get_shape()
norm_size += math.prod(shape) * _layer_dtype_size(key)
del tensor_slice
if key.startswith("lm_head."):
tensor_slice = f.get_slice(key)
shape = tensor_slice.get_shape()
head_size += math.prod(shape) * _layer_dtype_size(key)
del tensor_slice
# Begin auto mapping if enabled
if self.config.auto_map is not None:
self.config.device_map.embed_tokens = "cpu"
self.config.device_map.layers = ["cuda:0"] + ["?"] * (self.config.num_hidden_layers - 1)
# Assign layers automatically
device_usage = 0
device_index = 0
layer_index_device = 0
max_usage = self.config.auto_map[device_index] * (1024 ** 3)
for layer in range(self.config.num_hidden_layers + 2):
this_layer_size = decoder_size
if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
elif layer == self.config.num_hidden_layers + 1: this_layer_size = head_size
while device_usage + this_layer_size > max_usage:
device_index += 1
device_usage = 0
layer_index_device = 0
max_usage = self.config.auto_map[device_index] * (1024 ** 3)
if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")
target = f"cuda:{device_index}"
if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
else: self.config.device_map.layers[layer] = f"cuda:{device_index}"
device_usage += this_layer_size
layer_index_device += 1
# Load up to 1 GB of tensors at a time, closing and reopening the file in between each chunk
max_dq_buffer_size = 0
tensors = {}
st_mem = 0
MAX_ST_MEM = 1024**3
f = None
prev_path = ""
for key, path in load_keys.items():
device = self.config.device_map.map(key)
if f is None or st_mem > MAX_ST_MEM or path != prev_path:
if f is not None: del f
f = safe_open(path, framework = "pt", device = "cpu")
prev_path = path
st_mem = 0
tensor = f.get_tensor(key)
size = tensor.numel() * tensor.element_size()
st_mem += size
if key.endswith(".scales"): tensor = tensor.half()
if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
if key == "model.norm.weight": tensor = tensor.half()
if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()
if device == "cpu": keep_tensor = tensor.clone()
else: keep_tensor = tensor.to(device)
del tensor
if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, keep_tensor.numel() * 8)
tensors[key] = keep_tensor
del f
# Head
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias = False, device = "meta")
self.lm_head.weight = nn.Parameter(tensors["lm_head.weight"])
# self.lm_head_data = tensors["lm_head.weight"].transpose(0, 1).contiguous()
# Token embeddings
self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.config.pad_token_id, device = "meta")
self.embed_tokens.weight = nn.Parameter(tensors["model.embed_tokens.weight"])
with torch.no_grad():
self.embed_tokens.weight[self.config.pad_token_id] = 0
# Norm
self.norm = ExLlamaRMSNorm(self.config, tensors, "model.norm.weight")
# Prepare position embeddings for max seq length
devs = self.config.device_map.get_layers_devs()
self.sincos = {}
for device in devs:
inv_freq = 1.0 / (self.config.rotary_embedding_base ** (torch.arange(0, self.config.head_dim, 2, device = device).float() / self.config.head_dim))
t = torch.arange(self.config.max_seq_len, device = device, dtype = torch.float32)
if self.config.compress_pos_emb != 1.0: t /= self.config.compress_pos_emb
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim = -1)
sin = emb.sin()[None, None, :, :].half()
cos = emb.cos()[None, None, :, :].half()
self.sincos[device] = (sin, cos)
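            # Each sin/cos table has shape (1, 1, max_seq_len, head_dim); the attention code
            # above passes past_len into rope_ so the precomputed tables cover the whole
            # sequence and never have to be rebuilt during generation.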
# Decoder layers
modules = []
device_layer_index = [0] * len(devs)
for i in range(self.config.num_hidden_layers):
device = self.config.device_map.layers[i]
sin, cos = self.sincos[device]
layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos)
modules.append(layer)
self.layers = modules
# Prepare CUDA buffers
self.buffers = []
for dev in self.config.device_map.get_layers_devs():
device_buffers = {}
self.buffers.append(device_buffers)
temp_state = torch.zeros((config.max_input_len, config.intermediate_size), dtype = torch.float16, device = dev)
temp_mlp = torch.zeros((config.fused_mlp_thd * 2, config.intermediate_size), dtype = torch.float16, device = dev)
temp_zeros_float = torch.zeros((1, 65536), dtype = torch.float32, device = dev)
temp_dq = torch.zeros((1, max_dq_buffer_size), dtype = torch.float16, device = dev)
device_buffers["temp_state"] = temp_state
device_buffers["temp_mlp"] = temp_mlp
device_buffers["temp_zeros_float"] = temp_zeros_float
device_buffers["temp_dq"] = temp_dq
cuda_ext.exllama_ext.prepare_buffers(torch.device(dev),
temp_state,
temp_mlp,
temp_zeros_float,
temp_dq)
# Clear the cache
torch.cuda.empty_cache()
def forward(self,
input_ids,
cache,
last_id_only = True,
preprocess_only = False,
lora = None,
output_device = None,
input_mask = None):
q_len = input_ids.shape[-1]
remaining_q_len = q_len
bsz = input_ids.shape[0]
assert input_mask is None or (input_mask.shape[-1] >= input_ids.shape[-1] and input_mask.shape[-2] == input_ids.shape[-2])
# The buffers can only fit max_input_len tokens, so with larger batch sizes we reduce our work size correspondingly.
effective_max_input_len = self.config.max_input_len // bsz
# Split sequence
result = None
chunk_begin = 0
while chunk_begin < q_len:
# Limit chunk_size to max_input_len
chunk_size = min(remaining_q_len, effective_max_input_len)
# Limit chunk_size to keep size of attention operation <= max_attention_size, unless using flash-attn
if not self.config.use_flash_attn_2 or chunk_begin > 0:
past_len = cache.current_seq_len
attn_size = (past_len + remaining_q_len) * remaining_q_len
max_a = self.config.max_attention_size
if attn_size > max_a:
cs = (math.sqrt(past_len ** 2 + 4 * max_a) - past_len) / 2
chunk_size = min(chunk_size, math.floor(cs))
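                    # The quadratic above solves (past_len + cs) * cs <= max_attention_size for cs,
                    # e.g. with past_len = 0 and the default max_attention_size of 2048**2 the chunk
                    # is capped at 2048 tokens.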
# Process chunk
chunk_end = min(chunk_begin + chunk_size, q_len)
_last_id_only = last_id_only
_preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only)
r = self._forward(input_ids[:, chunk_begin : chunk_end],
cache,
_last_id_only,
_preprocess_only,
lora,
output_device,
input_mask)
if not _preprocess_only:
result = r if result is None else torch.cat((result, r), dim = 1)
chunk_begin = chunk_end
remaining_q_len -= chunk_size
return result
def _forward(self,
input_ids,
cache,
last_id_only = True,
preprocess_only = False,
lora = None,
output_device = None,
input_mask = None):
# if torch.is_grad_enabled():
# raise ValueError("Forward pass called with gradients enabled. Back propagation is not supported yet.")
with torch.no_grad():
batch_size, seq_len = input_ids.shape
past_len = cache.current_seq_len
if output_device is None: output_device = input_ids.device
buffer = ExLlamaBuffer(self.config)
# Build attention mask on first device, copy to others if necessary
devs = self.config.device_map.get_layers_devs()
# if not self.config.use_flash_attn_2:
if seq_len > 1 or input_mask is not None:
attn_mask = torch.zeros(batch_size, 1, seq_len, past_len + seq_len, dtype = torch.float16, device = devs[0])
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), -65504.))
attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu
if input_mask is not None:
input_mask = input_mask[:, :past_len + seq_len]
input_mask = _move_tensor(input_mask, devs[0], "input_mask", self.config)
input_mask = torch.where(input_mask, 0, -65504.).half()
input_mask = input_mask.unsqueeze(1).unsqueeze(2)
attn_mask = torch.minimum(attn_mask, input_mask)
else:
attn_mask = None
# attn_mask = torch.zeros(batch_size, 1, seq_len, seq_len + past_len, dtype = torch.float16, device = devs[0])
buffer.attn_mask = attn_mask
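            # Illustrative mask layout for past_len = 2, seq_len = 3 (m = -65504, the float16
            # stand-in for -inf):
            #
            #        k0  k1  k2  k3  k4
            #   q0    0   0   0   m   m
            #   q1    0   0   0   0   m
            #   q2    0   0   0   0   0
            #
            # input_mask, when supplied, additionally sets masked key positions to m via
            # torch.minimum, so padded tokens are never attended to.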
# else:
#
# buffer.attn_mask = None
# Embeddings
# TODO: Allow passing input embeddings instead of IDs
input_ids = _move_tensor(input_ids, self.config.device_map.embed_tokens, "input_ids", self.config)
hidden_states = self.embed_tokens(input_ids)
# Split buffers to devices
buffers = {devs[0]: buffer}
for device in devs[1:]:
buffers[device] = buffer.to(device)
# Decoder layers
for i, decoder_layer in enumerate(self.layers):
device = self.config.device_map.layers[i]
hidden_states = _move_tensor(hidden_states, device, "hidden_states", self.config)
hidden_states = decoder_layer.forward(hidden_states, cache, buffers[device], lora)
cache.current_seq_len += seq_len
# Early exit when we don't need logits
if preprocess_only: return None
# Norm
hidden_states = _move_tensor(hidden_states, self.config.device_map.norm, "hidden_states", self.config)
hidden_states = self.norm.forward(hidden_states, buffer)
# Head
if last_id_only: hidden_states = hidden_states[:, -1:, :].contiguous()
if self.config.device_map.lm_head == "cpu": hidden_states = hidden_states.float()
hidden_states = _move_tensor(hidden_states, self.config.device_map.lm_head, "hidden_states", self.config)
logits = self.lm_head(hidden_states)
# logits = cuda_ext.matmul_half(hidden_states, self.lm_head_data, cublas = False)
logits = logits.float()
logits = _move_tensor(logits, output_device, "logits", self.config)
return logits
    # Free unmanaged resources allocated by the C++ extension. Call this before dereferencing the ExLlama object,
    # e.g. if you intend to create a new instance to load another model. Don't call it from a destructor that wraps
    # the object, since it relies on CUDA function calls and the CUDA context is one of the first things to go when
    # a PyTorch application terminates, before other managed objects are destroyed.
def free_unmanaged(self):
cuda_ext.exllama_ext.cleanup()