from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator

import os, glob

# Directory containing the model, tokenizer and quantized weights
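# (The path below is just an example; point it at your own model directory.)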
model_directory = "/mnt/str/models/llama-13b-4bit-128g/"

# Locate the files we need within that directory
tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)[0]                    # use the first .safetensors file found

# Batched prompts
prompts = [
    "Once upon a time,",
    "I don't like to",
    "A turbo encabulator is a",
    "In the words of Mark Twain,"
]

# Create config, model, tokenizer and generator
config = ExLlamaConfig(model_config_path)                # create config from config.json
config.model_path = model_path                           # supply path to model weights file

model = ExLlama(config)                                  # create ExLlama instance and load the weights
tokenizer = ExLlamaTokenizer(tokenizer_path)             # create tokenizer from tokenizer model file

cache = ExLlamaCache(model, batch_size = len(prompts))   # create cache for inference, sized for the whole batch
generator = ExLlamaGenerator(model, tokenizer, cache)    # create generator

# Configure generator and sampling settings
generator.disallow_tokens([tokenizer.eos_token_id])      # ban the EOS token so completions don't end early

generator.settings.token_repetition_penalty_max = 1.2
generator.settings.temperature = 0.95
generator.settings.top_p = 0.65
generator.settings.top_k = 100
generator.settings.typical = 0.5

# Generate, batched
for line in prompts:
    print(line)

output = generator.generate_simple(prompts, max_new_tokens = 200)

for line in output:
    print("---")
    print(line)
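# Each string in output should be the original prompt followed by its generated continuation,
# so the loop above prints every completed sequence, separated by "---".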