import cuda_ext
from model import ExLlama, ExLlamaCache
from tokenizer import ExLlamaTokenizer
from lora import ExLlamaLora
import torch
import torch.nn.functional as F
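
# Maximum number of strings (prompts and stop strings) whose encodings are kept in the tokenizer cache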
MAX_CACHED_STRINGS = 100


class ExLlamaAltGenerator:
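
    # Sampling and repetition-penalty settings, passed to begin_stream() / generate()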
    class Settings:

        temperature = 0.95
        top_k = 40
        top_p = 0.65
        min_p = 0.0
        typical = 0.0

        token_repetition_penalty_max = 1.15
        token_repetition_penalty_sustain = -1
        token_repetition_penalty_decay = 0

        disallowed_tokens: list[int] = None
        lora: ExLlamaLora = None
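
    # Model, cache and tokenizer, plus the generation state carried between calls to stream()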
    model: ExLlama
    cache: ExLlamaCache
    tokenizer: ExLlamaTokenizer
    tokenizer_cache = {}

    settings: Settings
    stop_strings: list = []
    stop_tokens: list = []
    held_text: str = ""
    max_stop_tokens: int = 2
    sequence_ids: torch.Tensor = None
    sequence_str: str = None
    remaining_tokens: int = 0


    def __init__(self, model: ExLlama, tokenizer: ExLlamaTokenizer, cache: ExLlamaCache):

        self.model = model
        self.tokenizer = tokenizer
        self.cache = cache
        self.settings = ExLlamaAltGenerator.Settings()
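
    # Encode a string, reusing a cached encoding when available. The cache holds at most MAX_CACHED_STRINGS
    # entries, evicted oldest-first (by insertion order)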
    def cached_tokenize(self, text: str, encode_special_characters = False):

        if text in self.tokenizer_cache:
            return self.tokenizer_cache[text]

        while len(self.tokenizer_cache) >= MAX_CACHED_STRINGS:
            del self.tokenizer_cache[next(iter(self.tokenizer_cache))]

        new_enc = self.tokenizer.encode(text, encode_special_characters = encode_special_characters)
        self.tokenizer_cache[text] = new_enc
        return new_enc
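
    # Number of tokens in a string, using the tokenizer cache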
    def get_num_tokens(self, text: str, encode_special_characters = False):

        return self.cached_tokenize(text, encode_special_characters = encode_special_characters).shape[-1]
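
    # Begin a streaming completion
    #
    # prompt: input prompt; tokenized, then truncated from the left to max_seq_len - max_new_tokens
    # stop_conditions: list of strings and/or integer token IDs that end the completion
    # max_new_tokens: maximum number of tokens to generate
    # gen_settings: ExLlamaAltGenerator.Settings to use while sampling
    # encode_special_characters: passed through to the tokenizer when encoding the prompt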
    def begin_stream(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):

        assert isinstance(prompt, str), "ExLlamaAltGenerator does not support batched generation"
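
        # Tokenize the prompt and truncate it from the left if it would leave too little room for new tokens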
        max_input_tokens = self.model.config.max_seq_len - max_new_tokens
        self.remaining_tokens = max_new_tokens

        input_ids = self.cached_tokenize(prompt, encode_special_characters)
        applied_input_ids = input_ids[:, -max_input_tokens:]
        self.sequence_str = self.tokenizer.decode(applied_input_ids)[0] if applied_input_ids.shape[-1] < input_ids.shape[-1] else prompt
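
        # Split stop conditions into token IDs and strings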
        self.stop_strings = []
        self.stop_tokens = []
        for t in stop_conditions:
            if isinstance(t, int): self.stop_tokens += [t]
            elif isinstance(t, str): self.stop_strings += [t]
            else: raise ValueError("Unsupported type in stop_conditions")

        self.held_text = ""
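
        # stream() re-decodes a tail of the sequence on every step; make the tail long enough to cover the
        # longest stop string, plus some slack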
        self.max_stop_tokens = 2
        for ss in self.stop_strings:
            self.max_stop_tokens = max(self.max_stop_tokens, self.get_num_tokens(ss) + 2)

        self.settings = gen_settings
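
        # Start the sequence, reusing as much of the existing cache as possible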
        self.gen_begin_reuse(applied_input_ids, gen_settings)
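
    # Get the next chunk of the completion. Returns a tuple of (new text, EOS flag). The text may be empty
    # while characters that could still turn into a stop string are being held back.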
    def stream(self):
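
        # Stop once the token budget is exhausted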
        if self.remaining_tokens == 0:
            self.sequence_str += self.held_text
            return self.held_text, True

        self.remaining_tokens -= 1
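
        # Decode the current tail end of the sequence so that, after the next token is added, only the newly
        # decoded characters need to be emitted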
        old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.max_stop_tokens:])[0]

        next_token = self.gen_single_token(self.settings)
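
        # End immediately if the new token is a stop token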
        if next_token in self.stop_tokens:
            self.sequence_str += self.held_text
            return self.held_text, True
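
        # Decode the tail again and hold on to just the characters contributed by the new token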
        new_tail = self.tokenizer.decode(self.sequence_ids[:, -(self.max_stop_tokens + 1):])[0]
        self.held_text += new_tail[len(old_tail):]
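
        # Check the held text against each stop string: end the stream on a full match, and flag a partial
        # match if the end of the held text could still grow into a stop string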
        partial_ss = False
        for ss in self.stop_strings:

            position = self.held_text.find(ss)
            if position != -1:
                self.sequence_str += self.held_text[:position]
                return self.held_text[:position], True

            overlap = 0
            for j in range(1, min(len(self.held_text), len(ss)) + 1):
                if self.held_text[-j:] == ss[:j]: overlap = j
            if overlap > 0: partial_ss = True
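
        # Emit nothing while a partial stop string match is still possible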
        if partial_ss:
            return "", False
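
        # Otherwise flush the held text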
        stream_text = self.held_text
        self.held_text = ""
        self.sequence_str += stream_text
        return stream_text, False
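
    # Generate a full completion in one call by streaming until a stop condition or the token limit is hit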
    def generate(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):

        self.begin_stream(prompt, stop_conditions, max_new_tokens, gen_settings, encode_special_characters)
        response = ""
        while True:
            chunk, eos = self.stream()
            response += chunk
            if eos: break

        return response
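
    # Start a new sequence: reset the cache and run the model over all prompt tokens except the last one, so
    # the next forward pass produces logits for the first new token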
    def gen_begin(self, in_tokens, gen_settings):

        self.sequence_ids = in_tokens.clone()
        self.cache.current_seq_len = 0
        self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, lora = gen_settings.lora)
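
    # Start a new sequence, but keep whatever prefix of the cache matches the new input tokens so that only
    # the non-matching tail has to be recomputed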
    def gen_begin_reuse(self, in_tokens, gen_settings):

        if self.sequence_ids is None or self.cache.current_seq_len == 0:
            self.gen_begin(in_tokens, gen_settings)
            return

        reuse = 0
        while reuse < self.sequence_ids.shape[-1] and reuse < in_tokens.shape[-1] and self.sequence_ids[0, reuse] == in_tokens[0, reuse]:
            reuse += 1

        if reuse < 2:
            self.gen_begin(in_tokens, gen_settings)
            return

        self.cache.current_seq_len = reuse - 1
        self.sequence_ids = in_tokens[:, :reuse]

        if reuse < in_tokens.shape[-1]: self.gen_feed_tokens(in_tokens[:, reuse:], gen_settings)
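
    # Append tokens to the sequence and advance the cache over them, leaving the last token for the next
    # forward pass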
    def gen_feed_tokens(self, in_tokens, gen_settings):

        if self.sequence_ids is None:
            self.gen_begin(in_tokens, gen_settings)
            return

        start = self.cache.current_seq_len
        self.sequence_ids = torch.cat((self.sequence_ids, in_tokens), dim = 1)

        self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, lora = gen_settings.lora)
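
    # Generate one token: run the model on the last token of the sequence, sample from the logits and append
    # the result to the sequence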
    def gen_single_token(self, gen_settings):

        logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, lora = gen_settings.lora)
        token, _ = self.sample(logits, gen_settings)
        self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
        return token
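
    # Sample one token from the logits, applying repetition penalty, temperature, top-k, top-p and typical
    # sampling as configured in gen_settings. Returns the sampled token ID and its probability, each shaped (1, 1).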
    def sample(self, logits, gen_settings):

        cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence_ids,
                                                self.settings.token_repetition_penalty_max,
                                                self.settings.token_repetition_penalty_sustain,
                                                self.settings.token_repetition_penalty_decay,
                                                logits)
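
        # Strongly penalize the BOS token so it is effectively never sampled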
        logits[:, :, self.tokenizer.bos_token_id] = -10000.0

        if logits.dim() == 3: logits = logits[0, -1, :]
        elif logits.dim() == 2: logits = logits[-1, :]
        else: raise ValueError("Bad logits dimension")
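
        # Remove any explicitly disallowed tokens from consideration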
        if gen_settings.disallowed_tokens is not None:
            logits[gen_settings.disallowed_tokens] = float("-inf")
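
        # Apply temperature and convert to probabilities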
        logits /= gen_settings.temperature
        logits += 1e-8
        probs = torch.softmax(logits, dim = -1)
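
        # Keep the top_k most likely candidates and renormalize (top_k == 0 keeps all candidates, sorted)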
        if gen_settings.top_k == 0:
            top_probs, top_indices = torch.sort(probs, descending = True)
        else:
            top_probs, top_indices = torch.topk(probs, gen_settings.top_k)
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
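
        # Top-p (nucleus) filtering: keep the smallest set of top candidates whose cumulative probability
        # exceeds top_p, stopping early if a candidate's probability falls below min_p, then renormalize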
        if gen_settings.top_p > 0.0:

            num_top_p_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_top_p_probs += 1
                if num_top_p_probs == top_probs.shape[-1]: break
                if top_probs[num_top_p_probs].item() < gen_settings.min_p: break
                cum_prob += top_probs[num_top_p_probs].item()
                if cum_prob > gen_settings.top_p: break

            top_probs = top_probs[:num_top_p_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_top_p_probs]
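
        # Typical sampling: reorder candidates by how close their surprisal (-log p) is to the entropy of the
        # distribution, keep a cumulative probability mass of (typical), then renormalize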
        if gen_settings.typical > 0.0:

            epsilon = 1e-10
            log_probs = (top_probs + epsilon).log()
            neg_entropy = (top_probs * log_probs).sum()
            entropy_dev = (neg_entropy - log_probs).abs()
            _, entropy_dev_order = torch.sort(entropy_dev)

            top_probs = top_probs.gather(-1, entropy_dev_order)
            top_indices = top_indices.gather(-1, entropy_dev_order)

            num_typical_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_typical_probs += 1
                if num_typical_probs == top_probs.shape[-1]: break
                cum_prob += top_probs[num_typical_probs].item()
                if cum_prob > gen_settings.typical: break

            top_probs = top_probs[:num_typical_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_typical_probs]
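
        # Draw one token from the filtered, renormalized distribution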
        sampled_ind = torch.multinomial(top_probs, 1)
        sampled_tokens = top_indices[sampled_ind]
        sampled_probs = top_probs[sampled_ind]

        return sampled_tokens.unsqueeze(0), sampled_probs.unsqueeze(0)
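

# Minimal usage sketch (assumes model, tokenizer and cache have already been constructed elsewhere, e.g. from
# an ExLlamaConfig, and that the tokenizer exposes eos_token_id):
#
#   generator = ExLlamaAltGenerator(model, tokenizer, cache)
#   settings = ExLlamaAltGenerator.Settings()
#   text = generator.generate("Once upon a time,",
#                             stop_conditions = ["\n\n", tokenizer.eos_token_id],
#                             max_new_tokens = 200,
#                             gen_settings = settings)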