from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import torch
import torch.nn.functional as F
import os, glob
import cuda_ext

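# Directory containing the quantized model, tokenizer and config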
model_directory = "/mnt/str/models/_test_models/TheBloke_Llama-2-13B-chat-GPTQ/" |
|
|
|
|
|
|
|
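# Locate the files we need within that directory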
tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)[0]    # glob returns a list; take the first .safetensors file

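# Create config, model, tokenizer, a batch-size-2 cache (one row per prompt) and the generator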
config = ExLlamaConfig(model_config_path)               # create config from config.json
config.model_path = model_path                          # supply path to the quantized weights file

model = ExLlama(config)                                 # create ExLlama instance and load the weights
tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file

cache = ExLlamaCache(model, batch_size = 2)             # KV cache with room for two sequences
generator = ExLlamaGenerator(model, tokenizer, cache)   # create generator

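# Configure sampling settings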
generator.settings.token_repetition_penalty_max = 1.15
generator.settings.temperature = 0.95
generator.settings.top_k = 40
generator.settings.top_p = 0.75

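# Two Llama-2-chat prompt formats to mix: f1 carries the standard helpful-assistant
# system prompt, f2 an intentionally unpleasant persona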
f1 = \
"""[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
{prompt}[/INST]"""

f2 = \
"""[INST] <<SYS>>
<</SYS>>
You are a rude and obnoxious assistant. You hate everything and everyone.
{prompt}[/INST]"""

prompts = \
[
    f1.replace("{prompt}", "Tell me about Homer Simpson"),
    f2.replace("{prompt}", "Tell me about Homer Simpson"),
]

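# Generate up to max_new_tokens tokens, sampling each one from a mixture of the two
# sequences' next-token distributions: alpha = 0 follows f1 only, alpha = 1 follows f2
# only, and values outside [0, 1] extrapolate away from one prompt towards the other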
def generate_cfg(prompts, alpha, max_new_tokens):

    ids, mask = tokenizer.encode(prompts, return_mask = True)   # encode both prompts as one padded batch
    generator.gen_begin(ids, mask = mask)                        # run the prompts through the model to fill the cache

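    # Sampling loop: generate one token at a time for the batch of two sequences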
    for _ in range(max_new_tokens):

        # Forward pass on the last token of both sequences
        logits = model.forward(generator.sequence[:, -1:], cache, input_mask = mask)
        generator.apply_rep_penalty(logits)

        # Mix the two distributions in log space; alpha weights f2 against f1
        logits = F.log_softmax(logits, dim = -1)
        logits_mixed = (1 - alpha) * logits[0] + alpha * logits[1]

        # Sample a single token from the mixed distribution, stopping on EOS
        sampled_token, _ = generator.sample_current(logits_mixed)
        if sampled_token.item() == tokenizer.eos_token_id: break

        # Append the same sampled token to both sequences in the batch
        batch_token = sampled_token.repeat(2, 1)
        generator.gen_accept_token(batch_token)

    # Decode the first sequence (prompt f1 plus the shared completion)
    output = tokenizer.decode(generator.sequence[0])
    return output

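# Sweep alpha from -0.4 to 1.4 in steps of 0.2; values below 0 and above 1 push the
# output beyond either prompt on its own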
for i in range(10):

    alpha = i / 5.0 - 0.4
    print()
    print("--------------------------------------")
    print(f"alpha = {alpha:.1f}")
    print("--------------------------------------")

    output = generate_cfg(prompts, alpha, 200)
    print(output[len(prompts[0]):].strip())