from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
from lora import ExLlamaLora
import perplexity
from perplexity import Perplexity
import time
import torch
import torch.nn.functional as F
import argparse
import json
import math
import sys
import os
import glob
import model_init

torch.cuda._lazy_init()
torch.set_printoptions(precision = 10)
torch_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
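# Globals holding the loaded model and its cache, shared by the helper functions below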
cache = None
model = None

def begin():
    global model, cache

    if cache is None: cache = ExLlamaCache(model)
    else: cache.current_seq_len = 0
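# Forward pass through the model, reusing the shared cache; by default only the logits for the last position are returned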
def next_logits(input_ids, apply_lora, last_id_only = True, input_mask = None):
    global model, cache

    n_logits = model.forward(input_ids, cache, last_id_only, lora = apply_lora, input_mask = input_mask)
    return n_logits


def tokenize(text):
    global tokenizer

    return tokenizer.encode(text)
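# Run func(), print the elapsed wall-clock time and return its result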
def timer(name, func):
    t = time.time()
    ret = func()
    t = time.time() - t
    print(f" ** Time, {name}: {t:.2f} seconds")
    return ret
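# Record a per-device peak-VRAM baseline so mem() can report per-step and total deltas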
mem_base = {}
mem_last = {}
for dev in torch_devices:
    torch.cuda.reset_peak_memory_stats(dev)
    mem_base[dev] = mem_last[dev] = torch.cuda.max_memory_allocated(dev)

def mem(name, total = False):
    global mem_base, mem_last

    res = f" ** VRAM, {name}: "
    first = True

    for device in torch_devices:
        mem_c = torch.cuda.max_memory_allocated(device)
        mem_this = mem_c - mem_last[device] if not total else mem_c - mem_base[device]
        mem_last[device] = mem_c

        if not first: res += " - "
        first = False
        res += f"[{device}] {mem_this / (1024 ** 2):,.2f} MB"

    print(res)
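# Parse arguments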
parser = argparse.ArgumentParser(description = "Benchmark tests for ExLlama")

model_init.add_args(parser)
perplexity.add_args(parser)

parser.add_argument("-p", "--perf", action = "store_true", help = "Benchmark speed and VRAM usage")
parser.add_argument("-v", "--validate", action = "count", help = "Run validation check and generate some sample output; specify twice for a more thorough test")
parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during benchmark")
parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during benchmark")
parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to LoRA config and binary to use during benchmark")

args = parser.parse_args()

model_init.post_parse(args)
perplexity.post_parse(args)
model_init.get_model_files(args)
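# Resolve LoRA file paths when a directory is given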
if args.lora_dir is not None:
    args.lora_config = os.path.join(args.lora_dir, "adapter_config.json")
    args.lora = os.path.join(args.lora_dir, "adapter_model.bin")
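# Feedback: echo the selected test options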
print_opts = []
if args.perf: print_opts.append("perf")
if args.validate: print_opts.append("validate")
if args.perplexity: print_opts.append("perplexity")
if args.perplexity_token: print_opts.append("perplexity_token")

model_init.print_options(args, print_opts)
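# Apply global settings from the parsed arguments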
model_init.set_globals(args)
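# Instantiate model and tokenizer, reporting load times and VRAM usage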
config = model_init.make_config(args)

model = timer("Load model", lambda: ExLlama(config))
tokenizer = timer("Load tokenizer", lambda: ExLlamaTokenizer(args.tokenizer))

model_init.print_stats(model)

torch.cuda.reset_peak_memory_stats("cuda")
mem("Model")

cache = ExLlamaCache(model)
mem("Cache")
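# Load LoRA, if requested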
lora = None
if args.lora:
    print(f" -- LoRA config: {args.lora_config}")
    print(f" -- Loading LoRA: {args.lora}")
    if args.lora_config is None:
        print(" ## Error: please specify path to adapter_config.json with --lora_config or --lora_dir")
        sys.exit()
    lora = ExLlamaLora(model, args.lora_config, args.lora)
    if lora.bias_ignored:
        print(" !! Warning: LoRA zero bias ignored")
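# Build a random test sequence, leaving room to generate gen_tokens more tokens within args.length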
gen_tokens = 128
max_seq_len = args.length
ids = torch.randint(0, 31999, (1, max_seq_len - gen_tokens)).cuda()
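# Benchmark memory and performance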
if args.perf:
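    # Warmup passes, so one-time initialization overhead doesn't skew the timed runs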
    for i in range(1, 3):
        print(f" -- Warmup pass {i}...")
        begin()
        logits = timer("Warmup", lambda: next_logits(ids, lora))
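    # Timed runs: one full-prompt pass, then two rounds of token-by-token generation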
    begin()

    t = time.time()

    print(" -- Inference, first pass.")
    logits = timer("Inference", lambda: next_logits(ids, lora))

    t = time.time() - t
    print(f" ** Speed: {ids.shape[-1] / t:.2f} tokens/second")

    for j in range(2):

        t = time.time()
        print(f" -- Generating {gen_tokens} tokens, {ids.shape[-1]} token prompt...")
        for i in range(gen_tokens):

            logits = logits[0, -1, :]
            token = torch.argmax(logits)
            next_id = token.unsqueeze(0).unsqueeze(0)
            logits = next_logits(next_id, lora)

        t = time.time() - t
        print(f" ** Speed: {gen_tokens / t:.2f} tokens/second")

        ids = ids[:, :4]
        cache.current_seq_len = 4

    mem("Inference")
    mem("Total", total = True)
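# Benchmark perplexity on the selected dataset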
if args.perplexity:

    ppl = Perplexity(args.perplexity, model, cache, tokenizer)

    print(" -- Loading dataset...")

    ppl.load(dataset_path = args.perplexity_dataset,
             chunk_size = args.perplexity_chunk_size,
             chunk_truncate = args.perplexity_chunk_truncate,
             overlap = args.perplexity_chunk_overlap,
             minlength = args.perplexity_chunk_min,
             json_key = args.perplexity_json_key)

    begin()

    ppl.test(args.perplexity_chunk_num,
             lora = lora,
             ppl_token = args.perplexity_token)
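# Validation: short perplexity runs and generation sanity checks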
if args.validate:

    ppl = Perplexity(args.perplexity, model, cache, tokenizer)

    ppl.load(dataset_path = "datasets/wikitext2_val_sample.jsonl",
             chunk_size = 2048,
             chunk_truncate = 2048,
             overlap = 0,
             minlength = 50,
             json_key = "text")
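    # Short perplexity tests with weight reconstruction enabled (matmul_recons_thd = 1) and disabled (= 0); the two results should be roughly comparable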
    begin()

    ppl.cache.zero()
    model.config.matmul_recons_thd = 1
    ppl.test(8, lora = lora, tag = " (reconstruct)")
    ppl.cache.zero()
    model.config.matmul_recons_thd = 0
    ppl.test(8, lora = lora, tag = " (quant, token)", ppl_token = True)
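    # Greedy (top_k = 1) completion of a short prompt as a quick check that generation isn't producing garbage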
    model.config.matmul_recons_thd = 4
    generator = ExLlamaGenerator(model, tokenizer, cache)
    generator.settings.top_k = 1
    generator.lora = lora
    text = generator.generate_simple("To be or not to be, that is the", max_new_tokens = 20 * args.validate)
    print(f" ** Generation: {repr(text)}")

    if args.validate > 1:
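        # Batched generation test with fixed seeds for reproducibility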
        bsz = 8
        gen_len = 20
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
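        # Rebuild the cache with room for the whole batch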
        del cache
        cache = ExLlamaCache(model, batch_size = bsz)
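        # Create a batch: several copies of the same prompt plus a few extended variants, tokenized with a padding mask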
        identical_batch_prompt = "When you have eliminated the impossible, whatever remains,"
        continuations = [
            " must be considered",
            " ought to be",
            " (and some scholars say this is",
            " however improbable, is a banana.",
        ]

        prompts = [identical_batch_prompt] * (bsz - len(continuations))
        for cont in continuations:
            prompts.append(identical_batch_prompt + cont)

        ids = tokenizer.encode(prompts)
        assert ids.shape[1] < model.config.max_seq_len, f"Max length {ids.shape[1]} exceeds model limit {model.config.max_seq_len}"

        mask = ids.ne(tokenizer.pad_token_id)
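        # Greedy decoding across the batch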
        sequence = torch.empty((bsz, 0), dtype = torch.long, device = "cpu")
        logits = next_logits(ids, lora, input_mask = mask)

        for i in range(gen_len):
            logits = logits[:, -1, :]
            id_per_batch = torch.argmax(logits, dim = -1)
            assert id_per_batch.shape == (bsz,), f"{id_per_batch.shape} != {(bsz,)}"
            next_id_per_batch = id_per_batch.unsqueeze(-1)
            sequence = torch.cat((sequence, next_id_per_batch), dim = -1)
            logits = next_logits(next_id_per_batch, lora)
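        # Print prompts and generated continuations; rows with identical prompts should produce identical output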
        print(f"\n ** Batching sanity check: 1-{bsz - len(continuations)} should be identical. All should be reasonable for the model you're using.\n")

        outputs = tokenizer.decode(sequence)
        for b in range(bsz):
            print(f"{b + 1} {repr(prompts[b])} -> {repr(outputs[b])}")