from model import ExLlama, ExLlamaCache, ExLlamaConfig
from lora import ExLlamaLora
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import argparse
import torch
import sys
import os
import glob
import model_init
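
# Inference only: disable gradient tracking and make sure CUDA is initialized before loading the model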
torch.set_grad_enabled(False)
torch.cuda._lazy_init()
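
# Parse command-line arguments (model_init adds the shared model-loading options)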
parser = argparse.ArgumentParser(description = "Simple chatbot example for ExLlama")

model_init.add_args(parser)

parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during chat")
parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during chat")
parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to directory containing both LoRA config and binary")

parser.add_argument("-p", "--prompt", type = str, help = "Prompt file")
parser.add_argument("-un", "--username", type = str, help = "Display name of user", default = "User")
parser.add_argument("-bn", "--botname", type = str, help = "Display name of chatbot", default = "Chatbort")
parser.add_argument("-bf", "--botfirst", action = "store_true", help = "Start chat on bot's turn")

parser.add_argument("-nnl", "--no_newline", action = "store_true", help = "Do not break bot's response on newline (allow multi-paragraph responses)")
parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 0.95)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 20)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.15)
parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 1)
parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 1)

args = parser.parse_args()
model_init.post_parse(args)
model_init.get_model_files(args)
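
# If a LoRA directory was given, derive the adapter config and weights paths from it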
if args.lora_dir is not None:
    args.lora_config = os.path.join(args.lora_dir, "adapter_config.json")
    args.lora = os.path.join(args.lora_dir, "adapter_model.bin")
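
# Feedback: print the generation settings in effect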
print(f" -- Sequence length: {args.length}")
print(f" -- Temperature: {args.temperature:.2f}")
print(f" -- Top-K: {args.top_k}")
print(f" -- Top-P: {args.top_p:.2f}")
print(f" -- Min-P: {args.min_p:.2f}")
print(f" -- Repetition penalty: {args.repetition_penalty:.2f}")
print(f" -- Beams: {args.beams} x {args.beam_length}")

print_opts = []
if args.no_newline: print_opts.append("no_newline")
if args.botfirst: print_opts.append("botfirst")

model_init.print_options(args, print_opts)

model_init.set_globals(args)
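
# Pick display names and load the initial prompt, substituting the {username} and {bot_name} placeholders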
username = args.username
bot_name = args.botname

if args.prompt is not None:
    with open(args.prompt, "r") as f:
        past = f.read()
        past = past.replace("{username}", username)
        past = past.replace("{bot_name}", bot_name)
        past = past.strip() + "\n"
else:
    past = f"{bot_name}: Hello, {username}\n"
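
# Instantiate model, cache and tokenizer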
config = model_init.make_config(args)

model = ExLlama(config)
cache = ExLlamaCache(model)
tokenizer = ExLlamaTokenizer(args.tokenizer)

model_init.print_stats(model)
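
# Load the LoRA, if one was specified; the adapter config is required alongside the weights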
lora = None
if args.lora:
    print(f" -- LoRA config: {args.lora_config}")
    print(f" -- Loading LoRA: {args.lora}")
    if args.lora_config is None:
        print(f" ## Error: please specify path to adapter_config.json with --lora_config")
        sys.exit()
    lora = ExLlamaLora(model, args.lora_config, args.lora)
    if lora.bias_ignored:
        print(f" !! Warning: LoRA zero bias ignored")
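
# Set up the generator and copy the sampling settings from the command line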
generator = ExLlamaGenerator(model, tokenizer, cache)
generator.settings = ExLlamaGenerator.Settings()
generator.settings.temperature = args.temperature
generator.settings.top_k = args.top_k
generator.settings.top_p = args.top_p
generator.settings.min_p = args.min_p
generator.settings.token_repetition_penalty_max = args.repetition_penalty
generator.settings.token_repetition_penalty_sustain = args.repetition_penalty_sustain
generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2
generator.settings.beams = args.beams
generator.settings.beam_length = args.beam_length

generator.lora = lora

break_on_newline = not args.no_newline
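
# Chat parameters: force a minimum reply length, cap the reply length, and prune some extra context beyond
# the minimum whenever the sequence limit is reached so pruning doesn't recur on every turn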
min_response_tokens = 4
max_response_tokens = 256
extra_prune = 256
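
# Print the initial prompt and prime the generator's cache with it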
print(past, end = "")
ids = tokenizer.encode(past)
generator.gen_begin(ids)

next_userprompt = username + ": "

first_round = True
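
# Main chat loop: read a line from the user, feed it to the model and stream the bot's reply token by token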
while True:

    res_line = bot_name + ":"
    res_tokens = tokenizer.encode(res_line)
    num_res_tokens = res_tokens.shape[-1]
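
    # With -bf/--botfirst, the bot opens the conversation, so skip reading user input on the first round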
    if first_round and args.botfirst: in_tokens = res_tokens

    else:
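
        # Read the user's input and format it as a new chat line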
        in_line = input(next_userprompt)
        in_line = username + ": " + in_line.strip() + "\n"

        next_userprompt = username + ": "
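
        # Keep a plain-text transcript as well; the working history lives in the generator's token sequence
        # and the cache, so this is only needed if we want to log the chat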
        past += in_line
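
        # Tokenize the user's line and append the "{bot_name}:" tokens so the model replies in the bot's turn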
        in_tokens = tokenizer.encode(in_line)
        in_tokens = torch.cat((in_tokens, res_tokens), dim = 1)
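
    # If the next exchange might overflow the context window, prune whole lines from the start of the
    # context, plus a little extra so pruning doesn't have to happen again on the very next turn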
    expect_tokens = in_tokens.shape[-1] + max_response_tokens
    max_tokens = config.max_seq_len - expect_tokens
    if generator.gen_num_tokens() >= max_tokens:
        generator.gen_prune_to(config.max_seq_len - expect_tokens - extra_prune, tokenizer.newline_token_id)
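
    # Feed the new tokens to the model, extending the cache without sampling anything yet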
    generator.gen_feed_tokens(in_tokens)
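
    # Echo the bot name prefix, then generate and stream the response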
    print(res_line, end = "")
    sys.stdout.flush()

    generator.begin_beam_search()

    for i in range(max_response_tokens):
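
        # Disallowing newline and EOS for the first few tokens forces a minimum response length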
        if i < min_response_tokens:
            generator.disallow_tokens([tokenizer.newline_token_id, tokenizer.eos_token_id])
        else:
            generator.disallow_tokens(None)
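
        # Generate one token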
        gen_token = generator.beam_search()
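
        # If the token is EOS, replace it with a newline before continuing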
        if gen_token.item() == tokenizer.eos_token_id:
            generator.replace_last_token(tokenizer.newline_token_id)
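
        # Decode the response so far and print only the newly added text. SentencePiece folds leading spaces
        # into tokens, so re-decoding the running sequence and diffing is more reliable than decoding one token
        # at a time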
        num_res_tokens += 1
        text = tokenizer.decode(generator.sequence_actual[:, -num_res_tokens:][0])
        new_text = text[len(res_line):]

        skip_space = res_line.endswith("\n") and new_text.startswith(" ")
        res_line += new_text
        if skip_space: new_text = new_text[1:]

        print(new_text, end = "")
        sys.stdout.flush()
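
        # End conditions: stop on newline (unless multi-paragraph replies are allowed) and on EOS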
        if break_on_newline and gen_token.item() == tokenizer.newline_token_id: break
        if gen_token.item() == tokenizer.eos_token_id: break
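
        # Some models won't reliably emit EOS and instead start writing the user's next turn. Catch that,
        # rewind the generated "{username}:" tokens and hand the turn back to the user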
        if res_line.endswith(f"{username}:"):
            plen = tokenizer.encode(f"{username}:").shape[-1]
            generator.gen_rewind(plen)
            next_userprompt = " "
            break

    generator.end_beam_search()

    past += res_line
    first_round = False