import cuda_ext
from model import ExLlama, ExLlamaCache
from lora import ExLlamaLora
import torch
import torch.nn.functional as F


class ExLlamaGenerator:

    class Settings:

        temperature = 0.95
        top_k = 40                              # consider only the 40 most probable tokens (0 = disabled)
        top_p = 0.65                            # then keep tokens up to 65% cumulative probability (0 = disabled)
        min_p = 0.0                             # during top-p filtering, cut off candidates below this probability
        typical = 0.0                           # locally typical sampling threshold (0 = disabled)

        token_repetition_penalty_max = 1.15     # maximum penalty for repeating a recent token
        token_repetition_penalty_sustain = 256  # most recent tokens that receive the full penalty
        token_repetition_penalty_decay = 128    # penalty then decays over this many older tokens

        beams = 1                               # beam search width
        beam_length = 1                         # beam search depth, in tokens

    model: ExLlama
    sequence: torch.Tensor or None
    sequence_actual: torch.Tensor or None       # committed sequence, excluding speculative beam tokens
    settings: Settings
    beams: list or None                         # list of Beam objects while beam search is active
    max_beam_length: int
    in_beam_search: bool
    disallowed_tokens: list[int] or None
    lora: ExLlamaLora or None

    def __init__(self, model, tokenizer, cache):

        self.model = model
        self.tokenizer = tokenizer
        self.cache = cache
        self.reset()


    def reset(self):

        self.cache.current_seq_len = 0
        self.sequence = None
        self.sequence_actual = None
        self.settings = ExLlamaGenerator.Settings()

        self.beams = None
        self.max_beam_length = 0
        self.in_beam_search = False
        self.disallowed_tokens = None
        self.lora = None

    # Build a per-token repetition penalty mask for the current sequence, via the C++/CUDA extension

    def make_rep_mask(self, penalty_max, sustain, decay):

        return cuda_ext.ext_rep_penalty_mask_cpu(self.model.config.vocab_size, self.sequence, penalty_max, sustain, decay)

    # Sample from a batch of logits: dispatch to sample() per row and concatenate the results

    def batched_sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):

        if logits.shape[0] == 1: return self.sample(logits, temperature, top_k, top_p, min_p, typical, num)

        samples = []
        scores = []
        for i in range(logits.shape[0]):
            t, s = self.sample(logits[i, :, :], temperature, top_k, top_p, min_p, typical)
            samples.append(t)
            scores.append(s)

        return torch.cat(samples, dim = 0), torch.cat(scores, dim = 0)

    def sample_current(self, logits, num = 1):

        return self.sample(logits,
                           self.settings.temperature,
                           self.settings.top_k,
                           self.settings.top_p,
                           self.settings.min_p,
                           self.settings.typical)

    # Sample one or more tokens from a logits tensor, applying temperature, top-k, top-p/min-p and
    # locally typical filtering in that order, then drawing from the surviving distribution

    def sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):

        # Sample from the last position only

        if logits.dim() == 3: logits = logits[0, -1, :]
        elif logits.dim() == 2: logits = logits[-1, :]
        else: raise ValueError("Bad logits dimension")

        # Mask out any disallowed tokens

        if self.disallowed_tokens is not None:
            logits[self.disallowed_tokens] = float("-inf")

        # Temperature and softmax

        logits /= temperature
        logits += 1e-8
        probs = torch.softmax(logits, dim = -1)

        # Top K (0 = no limit, just sort the full distribution)

        if top_k == 0:
            top_probs, top_indices = torch.sort(probs, descending = True)
        else:
            top_probs, top_indices = torch.topk(probs, top_k)
            top_probs = F.normalize(top_probs, p = 1, dim = -1)

        # Top P: keep candidates until the cumulative probability exceeds top_p, or a candidate falls below min_p

        if top_p > 0.0:

            num_top_p_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_top_p_probs += 1
                if num_top_p_probs == top_probs.shape[-1]: break
                if top_probs[num_top_p_probs].item() < min_p: break
                cum_prob += top_probs[num_top_p_probs].item()
                if cum_prob > top_p: break

            top_probs = top_probs[:num_top_p_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_top_p_probs]

        # Locally typical sampling: reorder candidates by how close their surprisal is to the
        # distribution's entropy, then keep candidates up to the typical mass

        if typical > 0.0:

            epsilon = 1e-10
            log_probs = (top_probs + epsilon).log()
            neg_entropy = (top_probs * log_probs).sum()
            entropy_dev = (neg_entropy - log_probs).abs()
            _, entropy_dev_order = torch.sort(entropy_dev)

            top_probs = top_probs.gather(-1, entropy_dev_order)
            top_indices = top_indices.gather(-1, entropy_dev_order)

            num_typical_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_typical_probs += 1
                if num_typical_probs == top_probs.shape[-1]: break
                cum_prob += top_probs[num_typical_probs].item()
                if cum_prob > typical: break

            top_probs = top_probs[:num_typical_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_typical_probs]

        # Multinomial sampling from the filtered, renormalized distribution

        sampled_ind = torch.multinomial(top_probs, top_probs.shape[-1] if num == -1 else min(num, top_probs.shape[-1]))
        sampled_tokens = top_indices[sampled_ind]
        sampled_probs = top_probs[sampled_ind]

        if sampled_tokens.shape[0] > 1:
            sampled_tokens, ind = sampled_tokens.sort()
            sampled_probs = sampled_probs[ind]

        return sampled_tokens.unsqueeze(0), sampled_probs.unsqueeze(0)

    def disallow_tokens(self, tokens):

        self.disallowed_tokens = tokens

    # Start a new sequence: reset the cache and prefill it with all but the last token of the prompt

    def gen_begin(self, in_tokens, mask = None):

        self.end_beam_search()

        self.sequence = in_tokens.clone()
        self.sequence_actual = in_tokens.clone()
        self.cache.current_seq_len = 0

        self.model.forward(self.sequence[:, :-1], self.cache, preprocess_only = True, lora = self.lora, input_mask = mask)

    def gen_begin_empty(self):

        self.end_beam_search()
        self.sequence = None
        self.sequence_actual = None
        self.cache.current_seq_len = 0

    # Start a new sequence, reusing any prefix of the cache that matches the new prompt. Returns the
    # number of tokens reused

    def gen_begin_reuse(self, in_tokens, mask = None):

        self.end_beam_search()
        if self.sequence is None or self.cache.current_seq_len == 0:
            self.gen_begin(in_tokens, mask = mask)
            return 0

        # Count matching leading tokens

        reuse = 0
        while reuse < self.sequence.shape[-1] and reuse < in_tokens.shape[-1] and self.sequence[0, reuse] == in_tokens[0, reuse]:
            reuse += 1

        if reuse < 2:
            self.gen_begin(in_tokens, mask = mask)
            return 0

        # Keep the matching prefix, then feed the remainder of the new prompt

        self.cache.current_seq_len = reuse - 1
        self.sequence = self.sequence[:, :reuse]
        self.sequence_actual = self.sequence.clone()

        if reuse < in_tokens.shape[-1]: self.gen_feed_tokens(in_tokens[:, reuse:], mask = mask)
        return reuse

    # Append tokens to the sequence and run them through the model to extend the cache

    def gen_feed_tokens(self, in_tokens, mask = None):

        if self.sequence is None:
            self.gen_begin(in_tokens, mask = mask)
            return

        self.end_beam_search()

        start = self.sequence.shape[-1] - 1
        if start < 0:
            start = 0
            self.sequence = in_tokens.clone()
        else:
            self.sequence = torch.cat((self.sequence, in_tokens), dim = 1)

        if start < self.sequence.shape[-1] - 1:
            self.model.forward(self.sequence[:, start : -1], self.cache, preprocess_only = True, lora = self.lora, input_mask = mask)

        self.sequence_actual = self.sequence

    def gen_accept_token(self, token):

        self.end_beam_search()
        if self.sequence is None: self.sequence = token
        else: self.sequence = torch.cat((self.sequence, token), dim = 1)
        self.sequence_actual = self.sequence

    def gen_rewind(self, num_tokens):

        if num_tokens == 0: return
        self.end_beam_search()
        self.sequence = self.sequence[:, :-num_tokens]
        self.cache.current_seq_len -= num_tokens
        self.sequence_actual = self.sequence

    def gen_prune_right(self, tokens, mask = None):

        self.end_beam_search()
        if tokens > self.sequence.shape[-1] - 1: return
        self.gen_begin(self.sequence[:, tokens:], mask = mask)
        self.sequence_actual = self.sequence

    # Prune from the left while the sequence exceeds min_tokens_to_keep, cutting at occurrences of
    # token_id, then rebuild the cache for the pruned sequence

    def gen_prune_to(self, min_tokens_to_keep, token_id, mask = None):

        self.end_beam_search()

        if self.gen_num_tokens() <= min_tokens_to_keep: return

        while self.gen_num_tokens() > min_tokens_to_keep:

            pruned = False
            for i in range(self.sequence.shape[-1] - 1):
                if self.sequence[0, i] == token_id:
                    self.sequence = self.sequence[:, i + 1:]
                    pruned = True
                    break

            if not pruned: break

        self.gen_begin(self.sequence, mask = mask)

    # Prune num_tokens tokens from the left of the sequence and rebuild the cache

    def gen_prune_left(self, num_tokens, mask = None):

        num_tokens = min(num_tokens, self.sequence_actual.shape[-1] - 1)

        if self.in_beam_search:
            self.end_beam_search()
            self.sequence = self.sequence[:, num_tokens:]
            self.gen_begin(self.sequence, mask = mask)
            self.begin_beam_search()
        else:
            self.sequence = self.sequence[:, num_tokens:]
            self.gen_begin(self.sequence, mask = mask)

    def gen_num_tokens(self):

        return self.sequence_actual.shape[-1]

    # Tokenize a prompt and generate up to max_new_tokens with the current settings, stopping early
    # once every sequence in the batch has emitted an EOS token

    def generate_simple(self, prompt, max_new_tokens = 128):

        self.end_beam_search()

        ids, mask = self.tokenizer.encode(prompt, return_mask = True, max_seq_len = self.model.config.max_seq_len)
        self.gen_begin(ids, mask = mask)

        max_new_tokens = min(max_new_tokens, self.model.config.max_seq_len - ids.shape[1])

        eos = torch.zeros((ids.shape[0],), dtype = torch.bool)
        for i in range(max_new_tokens):
            token = self.gen_single_token(mask = mask)
            for j in range(token.shape[0]):
                if token[j, 0].item() == self.tokenizer.eos_token_id: eos[j] = True
            if eos.all(): break

        text = self.tokenizer.decode(self.sequence[0] if self.sequence.shape[0] == 1 else self.sequence)
        return text

    def apply_rep_penalty(self, logits):

        cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence,
                                                self.settings.token_repetition_penalty_max,
                                                self.settings.token_repetition_penalty_sustain,
                                                self.settings.token_repetition_penalty_decay,
                                                logits)

    # Generate one token with the current settings and append it to the sequence. Optionally restrict
    # sampling to a list of allowed token ids (constraints)

    def gen_single_token(self, constraints = None, mask = None):

        self.end_beam_search()

        # Simple sampling case:

        if self.sequence is not None:

            logits = self.model.forward(self.sequence[:, -1:], self.cache, lora = self.lora, input_mask = mask)
            self.apply_rep_penalty(logits)

            logits[:, :, self.tokenizer.bos_token_id] = -10000.0

            if constraints is not None:

                # Raise the constraint tokens by 10000, then lower all logits by 10000, so that only
                # the constraint tokens remain viable candidates

                for c in constraints: logits[:, :, c] += 10000.0
                logits[:, :, :] -= 10000.0

            token, _ = self.batched_sample(logits,
                                           self.settings.temperature,
                                           self.settings.top_k,
                                           self.settings.top_p,
                                           self.settings.min_p + 0.01 if constraints is not None else 0.0,
                                           self.settings.typical)

        else:

            # No context yet: emit the first constraint token, or a BOS token

            if constraints is not None:
                token = constraints[0]
            else:
                token = torch.Tensor([[self.tokenizer.bos_token_id]]).long()

        self.gen_accept_token(token)
        return token

    # One beam during beam search: a short speculative extension of the sequence, together with its
    # token probabilities and its own small cache covering the beam tokens

    class Beam:

        sequence: torch.Tensor
        probs: torch.Tensor
        cache: ExLlamaCache
        current_seq_pos: int
        settings = None
        generator = None

        sampled_tokens: torch.Tensor
        sampled_probs: torch.Tensor
        moved: bool = False

        def __init__(self, settings, generator, first_token = None, first_prob = None, seq_pos = None):

            self.settings = settings
            self.generator = generator

            self.sequence = first_token.unsqueeze(0).unsqueeze(0) if first_token is not None else None
            self.probs = first_prob.unsqueeze(0).unsqueeze(0) if first_prob is not None else None
            self.cache = ExLlamaCache(self.generator.model, max_seq_len = self.settings.beam_length)

            self.current_seq_pos = seq_pos

        def __len__(self):

            return self.sequence.shape[-1]

        def clone(self):

            new = ExLlamaGenerator.Beam(self.settings, self.generator)
            new.sequence = self.sequence.clone()
            new.probs = self.probs.clone()
            new.cache = self.cache.clone()
            new.current_seq_pos = self.current_seq_pos
            new.sampled_tokens = self.sampled_tokens.clone()
            new.sampled_probs = self.sampled_probs.clone()
            new.moved = self.moved
            return new

        # Slide the beam window one token to the left, dropping its oldest token and cache column

        def advance(self):

            self.cache.roll_left()
            self.sequence = self.sequence[:, 1:]
            self.probs = self.probs[:, 1:]
            self.current_seq_pos += 1

        def cum_log_probs(self):

            cum_log_prob = torch.sum(torch.log(self.probs))
            return cum_log_prob


        def sampled_cum_log_probs(self):

            cum_log_prob = torch.sum(torch.log(self.probs))
            return torch.log(self.sampled_probs) + cum_log_prob

        # Materialize this beam in the generator's sequence and cache so the model can be run on its
        # last token

        def to_sequence(self):

            # Extend the generator's sequence if the beam reaches past its current end

            new_tokens = 0
            added_tokens = 0
            slen = self.generator.sequence.shape[-1]
            tlen = self.current_seq_pos + len(self)
            if tlen > slen:
                new_tokens = tlen - slen
                added_tokens = new_tokens
                self.generator.sequence = torch.cat((self.generator.sequence, self.sequence[:, -new_tokens:]), dim = 1)
                self.generator.cache.current_seq_len = tlen - 1

            # Find how far back the generator's sequence diverges from this beam

            new_tokens_ = new_tokens
            for i in range(new_tokens_, len(self)):
                if self.generator.sequence[0, -i - 1] != self.sequence[0, -i - 1]: new_tokens = i + 1

            # Overwrite the diverging tail with this beam's tokens

            if new_tokens > added_tokens:
                self.generator.sequence[0, -new_tokens:] = self.sequence[0, -new_tokens:]

            # Copy the corresponding cache columns from the beam's cache into the generator's cache

            if new_tokens > len(self) - 1: new_tokens = len(self) - 1
            if new_tokens > 0:
                self.cache.copy_states(self.generator.cache,
                                       len(self) - 1 - new_tokens, new_tokens,
                                       self.generator.cache.current_seq_len - new_tokens, new_tokens,
                                       0, 1, 0, 1)

        # Copy the last column of the generator's cache into this beam's cache

        def record_last_cache_column(self):

            self.generator.cache.copy_states(self.cache,
                                             self.generator.cache.current_seq_len - 1, 1,
                                             len(self) - 1, 1,
                                             0, 1, 0, 1)

    def begin_beam_search(self):

        self.beams = None
        if self.settings.beams == 1 and self.settings.beam_length == 1: return

        self.in_beam_search = True

    # Generate one token using beam search. Maintains up to settings.beams candidate extensions of up
    # to settings.beam_length tokens, then commits the first token of the most probable beam

    def beam_search(self):

        if self.settings.beams == 1 and self.settings.beam_length == 1: return self.gen_single_token()
        assert self.in_beam_search

        if self.sequence is None: return self.gen_single_token()

        c_cache_len = self.cache.current_seq_len
        c_seq_len = self.sequence_actual.shape[-1]

        # Grow beams until they reach the maximum beam length

        max_beam_length = min(self.model.config.max_seq_len - self.settings.beam_length, self.settings.beam_length)
        while self.beams is None or len(self.beams[0]) < max_beam_length:

            if self.beams is None:

                # No beams yet: sample the initial token of each beam from the current sequence

                logits = self.model.forward(self.sequence[:, -1:], self.cache, lora = self.lora)

                cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence,
                                                        self.settings.token_repetition_penalty_max,
                                                        self.settings.token_repetition_penalty_sustain,
                                                        self.settings.token_repetition_penalty_decay,
                                                        logits)

                tokens, probs = self.sample(logits,
                                            self.settings.temperature,
                                            self.settings.top_k,
                                            self.settings.top_p,
                                            self.settings.min_p,
                                            self.settings.typical,
                                            num = self.settings.beams)

                self.beams = []
                while len(self.beams) < min(self.settings.beams, tokens.shape[-1]):

                    beam = ExLlamaGenerator.Beam(self.settings, self, tokens[0, len(self.beams)], probs[0, len(self.beams)], c_seq_len)
                    self.beams.append(beam)

            else:

                # Sample candidate continuations for each existing beam

                for beam in self.beams:

                    beam.to_sequence()

                    logits = self.model.forward(self.sequence[:, -1:], self.cache, lora = self.lora)

                    cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence,
                                                            self.settings.token_repetition_penalty_max,
                                                            self.settings.token_repetition_penalty_sustain,
                                                            self.settings.token_repetition_penalty_decay,
                                                            logits)

                    tokens, probs = self.sample(logits,
                                                self.settings.temperature,
                                                self.settings.top_k,
                                                self.settings.top_p,
                                                self.settings.min_p,
                                                self.settings.typical,
                                                num = -1)

                    beam.sampled_tokens = tokens
                    beam.sampled_probs = probs

                    beam.record_last_cache_column()
                    self.cache.current_seq_len -= 1

                # Collect all candidate (beam, token) pairs

                tokens_ = []
                probs_ = []
                cum_log_probs_ = []
                beams_ = []
                for i, beam in enumerate(self.beams):
                    tokens_.append(beam.sampled_tokens.squeeze(0))
                    probs_.append(beam.sampled_probs.squeeze(0))
                    cum_log_probs_.append(beam.sampled_cum_log_probs().squeeze(0))
                    beams_.append(torch.Tensor([i] * beam.sampled_tokens.shape[-1]).to(torch.int))

                tokens_all = torch.cat(tokens_, dim = 0)
                probs_all = torch.cat(probs_, dim = 0)
                cum_log_probs_all = torch.cat(cum_log_probs_, dim = 0)
                beams_all = torch.cat(beams_, dim = 0)

                # Sort candidates by cumulative log probability and keep the best settings.beams of them

                cum_log_probs_all, ind = cum_log_probs_all.sort(descending = True)
                probs_all = probs_all[ind]
                tokens_all = tokens_all[ind]
                beams_all = beams_all[ind]

                cum_log_probs_all = cum_log_probs_all[:self.settings.beams]
                probs_all = probs_all[:self.settings.beams]
                tokens_all = tokens_all[:self.settings.beams]
                beams_all = beams_all[:self.settings.beams]

                # Regroup the surviving candidates by their parent beam

                beams_all, ind = beams_all.sort()
                cum_log_probs_all = cum_log_probs_all[ind]
                tokens_all = tokens_all[ind]
                probs_all = probs_all[ind]

                # Extend each parent beam with its first surviving candidate; clone it for any further
                # candidates that share the same parent

                for beam in self.beams: beam.moved = False
                beams_new = []

                for i in range(len(beams_all)):

                    new_token = tokens_all[i]
                    new_prob = probs_all[i]
                    beam_idx = beams_all[i].item()

                    if not self.beams[beam_idx].moved:

                        self.beams[beam_idx].sequence = torch.cat((self.beams[beam_idx].sequence, new_token.unsqueeze(0).unsqueeze(0)), dim = 1)
                        self.beams[beam_idx].probs = torch.cat((self.beams[beam_idx].probs, new_prob.unsqueeze(0).unsqueeze(0)), dim = 1)
                        self.beams[beam_idx].moved = True
                        beams_new.append(self.beams[beam_idx])

                    else:

                        nbeam = self.beams[beam_idx].clone()
                        nbeam.sequence[:, -1] = new_token
                        nbeam.probs[:, -1] = new_prob
                        beams_new.append(nbeam)

                self.beams = beams_new

        # Beams are fully grown; commit the first token of the most probable beam

        max_log_probs = float("-inf")
        best_beam = None
        best_beam_idx = -1
        for beam_idx, beam in enumerate(self.beams):
            beam_log_probs = beam.cum_log_probs()
            if beam_log_probs > max_log_probs:
                max_log_probs = beam_log_probs
                best_beam = beam
                best_beam_idx = beam_idx

        best_token = best_beam.sequence[:, 0]

        self.sequence[0, c_seq_len] = best_token
        self.sequence_actual = torch.cat((self.sequence_actual, best_token.unsqueeze(0)), dim = 1)

        best_beam.to_sequence()

        # Discard beams that do not begin with the committed token

        beams_new = [best_beam]

        for idx, beam in enumerate(self.beams):
            if idx != best_beam_idx and beam.sequence[:, 0] == best_token:
                beams_new.append(beam)

        self.beams = beams_new

        # Slide the surviving beams one step forward

        for beam in self.beams: beam.advance()

        return best_token

    def end_beam_search(self):

        if not self.in_beam_search: return

        self.sequence = self.sequence_actual.clone()
        self.cache.current_seq_len = self.sequence.shape[-1] - 1
        self.in_beam_search = False

    def replace_last_token(self, token, seq = False):

        self.sequence_actual[:, -1] = token
        if seq: self.sequence[:, -1] = token

    def sequence_ends_with(self, tokens):

        if self.sequence_actual.shape[-1] < tokens.shape[-1] + 1: return False
        for i in range(tokens.shape[-1]):
            if self.sequence_actual[0, -i - 1] != tokens[0, -i - 1]: return False
        return True
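
# ---------------------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes the ExLlamaConfig and
# ExLlamaTokenizer classes from this repo's model.py / tokenizer.py, and the file paths below are
# placeholders; adjust both for an actual model before running.
#
#   from model import ExLlama, ExLlamaCache, ExLlamaConfig
#   from tokenizer import ExLlamaTokenizer
#
#   config = ExLlamaConfig("/path/to/config.json")              # placeholder: model's config.json
#   config.model_path = "/path/to/model.safetensors"            # placeholder: quantized weights
#
#   model = ExLlama(config)
#   tokenizer = ExLlamaTokenizer("/path/to/tokenizer.model")    # placeholder: SentencePiece model
#   cache = ExLlamaCache(model)
#
#   generator = ExLlamaGenerator(model, tokenizer, cache)
#   generator.settings.temperature = 0.7
#   generator.settings.top_p = 0.9
#
#   print(generator.generate_simple("Once upon a time,", max_new_tokens = 64))
# ---------------------------------------------------------------------------------------------------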