# exllama/generator.py
import cuda_ext
from model import ExLlama, ExLlamaCache
from lora import ExLlamaLora
import torch
import torch.nn.functional as F
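# Stateful token generator: wraps an ExLlama model, tokenizer and cache and provides sampling,
# sequence/cache management and beam search on top of them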
class ExLlamaGenerator:
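    # Default sampling and repetition-penalty settings; an instance is created in reset() as self.settings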
class Settings:
temperature = 0.95
        top_k = 40                          # consider only the top_k most probable tokens, 0 to disable top_k sampling
        top_p = 0.65                        # consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling
min_p = 0.0 # Do not consider tokens with probability less than this
typical = 0.0 # Locally typical sampling threshold, 0.0 to disable typical sampling
token_repetition_penalty_max = 1.15 # Repetition penalty for most recent tokens
        token_repetition_penalty_sustain = 256      # no. of most recent tokens to apply the penalty to, -1 to apply to whole context
token_repetition_penalty_decay = 128 # Gradually decrease penalty over this many tokens
beams = 1
beam_length = 1
model: ExLlama
sequence: torch.Tensor or None
sequence_actual: torch.Tensor or None
settings: Settings
beams: int or None
max_beam_length: int
    in_beam_search: bool
disallowed_tokens: list[int] or None
lora: ExLlamaLora or None
def __init__(self, model, tokenizer, cache):
self.model = model
self.tokenizer = tokenizer
self.cache = cache
self.reset()
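    # Reset generator state: clear the cache and sequence, restore default settings, drop any active LoRA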
def reset(self):
self.cache.current_seq_len = 0
self.sequence = None
self.sequence_actual = None
self.settings = ExLlamaGenerator.Settings()
self.beams = None
self.max_beam_length = 0
self.in_beam_search = False
self.disallowed_tokens = None
self.lora = None
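    # Build a per-token repetition penalty mask over the vocabulary for the current sequence (CPU extension)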
def make_rep_mask(self, penalty_max, sustain, decay):
return cuda_ext.ext_rep_penalty_mask_cpu(self.model.config.vocab_size, self.sequence, penalty_max, sustain, decay)
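    # Sample from a batch of logits, one row at a time; falls back to sample() for batch size 1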
def batched_sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):
if logits.shape[0] == 1: return self.sample(logits, temperature, top_k, top_p, min_p, typical, num)
samples = []
scores = []
for i in range(logits.shape[0]):
            t, s = self.sample(logits[i, :, :], temperature, top_k, top_p, min_p, typical, num)
samples.append(t)
scores.append(s)
return torch.cat(samples, dim = 0), torch.cat(scores, dim = 0)
    # Sample from logits using the current generator settings
    def sample_current(self, logits, num = 1):
        return self.sample(logits,
                           self.settings.temperature,
                           self.settings.top_k,
                           self.settings.top_p,
                           self.settings.min_p,
                           self.settings.typical,
                           num = num)
    # Sample up to num tokens from logits
def sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):
# torch.manual_seed(42)
if logits.dim() == 3: logits = logits[0, -1, :]
elif logits.dim() == 2: logits = logits[-1, :]
else: raise ValueError("Bad logits dimension")
# Disallow tokens
if self.disallowed_tokens is not None:
logits[self.disallowed_tokens] = float("-inf")
# Base probabilities
logits /= temperature
logits += 1e-8
probs = torch.softmax(logits, dim = -1)
# Top K
if top_k == 0:
top_probs, top_indices = torch.sort(probs, descending = True)
else:
top_probs, top_indices = torch.topk(probs, top_k)
top_probs = F.normalize(top_probs, p = 1, dim = -1)
# Top P
if top_p > 0.0:
num_top_p_probs = 0
cum_prob = top_probs[0].item()
while True:
num_top_p_probs += 1
if num_top_p_probs == top_probs.shape[-1]: break
if top_probs[num_top_p_probs].item() < min_p: break
cum_prob += top_probs[num_top_p_probs].item()
if cum_prob > top_p: break
top_probs = top_probs[:num_top_p_probs]
top_probs = F.normalize(top_probs, p = 1, dim = -1)
top_indices = top_indices[:num_top_p_probs]
# Locally typical sampling
if typical > 0.0:
epsilon = 1e-10
log_probs = (top_probs + epsilon).log()
neg_entropy = (top_probs * log_probs).sum()
entropy_dev = (neg_entropy - log_probs).abs()
_, entropy_dev_order = torch.sort(entropy_dev)
top_probs = top_probs.gather(-1, entropy_dev_order)
top_indices = top_indices.gather(-1, entropy_dev_order)
num_typical_probs = 0
cum_prob = top_probs[0].item()
while True:
num_typical_probs += 1
if num_typical_probs == top_probs.shape[-1]: break
cum_prob += top_probs[num_typical_probs].item()
if cum_prob > typical: break
top_probs = top_probs[:num_typical_probs]
top_probs = F.normalize(top_probs, p = 1, dim = -1)
top_indices = top_indices[:num_typical_probs]
# Multinomial sampling from top_probs, kept in same order as top_indices
sampled_ind = torch.multinomial(top_probs, top_probs.shape[-1] if num == -1 else min(num, top_probs.shape[-1]))
sampled_tokens = top_indices[sampled_ind]
sampled_probs = top_probs[sampled_ind] # Return probs before second norm
if sampled_tokens.shape[0] > 1:
sampled_tokens, ind = sampled_tokens.sort()
sampled_probs = sampled_probs[ind]
return sampled_tokens.unsqueeze(0), sampled_probs.unsqueeze(0)
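    # Set a list of token IDs that sample() will mask out with -inf, or None to allow all tokens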
def disallow_tokens(self, tokens):
self.disallowed_tokens = tokens
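    # Start a new sequence from the given tokens and prefill the cache with all but the last token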
def gen_begin(self, in_tokens, mask = None):
self.end_beam_search()
self.sequence = in_tokens.clone()
self.sequence_actual = in_tokens.clone()
self.cache.current_seq_len = 0
self.model.forward(self.sequence[:, :-1], self.cache, preprocess_only = True, lora = self.lora, input_mask = mask)
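    # Start with an empty sequence and an empty cache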
def gen_begin_empty(self):
self.end_beam_search()
self.sequence = None
self.sequence_actual = None
self.cache.current_seq_len = 0
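    # Start a new sequence, reusing as much of the previously cached prefix as matches in_tokens;
    # returns the number of tokens reused (0 if the cache had to be rebuilt)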
def gen_begin_reuse(self, in_tokens, mask = None):
self.end_beam_search()
if self.sequence is None or self.cache.current_seq_len == 0:
self.gen_begin(in_tokens, mask = mask)
return 0
# if in_tokens.shape[-1] < self.sequence.shape[-1]:
# self.sequence = self.sequence[:, :in_tokens.shape[-1]]
reuse = 0
while reuse < self.sequence.shape[-1] and reuse < in_tokens.shape[-1] and self.sequence[0, reuse] == in_tokens[0, reuse]:
reuse += 1
if reuse < 2:
self.gen_begin(in_tokens, mask = mask)
return 0
# print (f"Reusing cache: {reuse} tokens")
self.cache.current_seq_len = reuse - 1
self.sequence = self.sequence[:, :reuse]
self.sequence_actual = self.sequence.clone()
if reuse < in_tokens.shape[-1]: self.gen_feed_tokens(in_tokens[:, reuse:], mask = mask)
return reuse
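    # Append tokens to the sequence and run the model over them (except the last) to extend the cache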
def gen_feed_tokens(self, in_tokens, mask = None):
if self.sequence is None:
self.gen_begin(in_tokens, mask = mask)
return
self.end_beam_search()
start = self.sequence.shape[-1] - 1
if start < 0:
start = 0
self.sequence = in_tokens.clone()
else:
self.sequence = torch.cat((self.sequence, in_tokens), dim = 1)
if start < self.sequence.shape[-1] - 1:
self.model.forward(self.sequence[:, start : -1], self.cache, preprocess_only = True, lora = self.lora, input_mask = mask)
self.sequence_actual = self.sequence
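    # Append an already-sampled token to the sequence without running the model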
def gen_accept_token(self, token):
self.end_beam_search()
if self.sequence is None: self.sequence = token
else: self.sequence = torch.cat((self.sequence, token), dim = 1)
self.sequence_actual = self.sequence
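    # Remove the last num_tokens tokens from the sequence and rewind the cache accordingly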
def gen_rewind(self, num_tokens):
if num_tokens == 0: return
self.end_beam_search()
self.sequence = self.sequence[:, :-num_tokens]
self.cache.current_seq_len -= num_tokens
self.sequence_actual = self.sequence
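    # Keep only the rightmost part of the sequence, dropping the first `tokens` tokens, and re-evaluate it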
def gen_prune_right(self, tokens, mask = None):
self.end_beam_search()
if tokens > self.sequence.shape[-1] - 1: return
self.gen_begin(self.sequence[:, tokens:], mask = mask)
self.sequence_actual = self.sequence
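    # Prune the sequence from the left at occurrences of token_id while it is longer than
    # min_tokens_to_keep, then re-evaluate the shortened sequence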
def gen_prune_to(self, min_tokens_to_keep, token_id, mask = None):
self.end_beam_search()
if self.gen_num_tokens() <= min_tokens_to_keep: return
while self.gen_num_tokens() > min_tokens_to_keep:
pruned = False
for i in range(self.sequence.shape[-1] - 1):
if self.sequence[0, i] == token_id:
self.sequence = self.sequence[:, i + 1:]
pruned = True
break
            if not pruned: break
self.gen_begin(self.sequence, mask = mask)
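    # Drop num_tokens tokens from the left of the sequence; restarts beam search or re-evaluates the
    # remaining sequence as needed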
def gen_prune_left(self, num_tokens, mask = None):
num_tokens = min(num_tokens, self.sequence_actual.shape[-1] - 1)
if self.in_beam_search:
self.end_beam_search() # TODO: Try to avoid restarting beam search when generating past chunk boundary
self.sequence = self.sequence[:, num_tokens:]
self.begin_beam_search()
else:
self.sequence = self.sequence[:, num_tokens:]
self.gen_begin(self.sequence, mask = mask)
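    # Current sequence length in tokens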
def gen_num_tokens(self):
return self.sequence_actual.shape[-1]
# Simple generator function
def generate_simple(self, prompt, max_new_tokens = 128):
self.end_beam_search()
ids, mask = self.tokenizer.encode(prompt, return_mask = True, max_seq_len = self.model.config.max_seq_len)
self.gen_begin(ids, mask = mask)
max_new_tokens = min(max_new_tokens, self.model.config.max_seq_len - ids.shape[1])
eos = torch.zeros((ids.shape[0],), dtype = torch.bool)
for i in range(max_new_tokens):
token = self.gen_single_token(mask = mask)
for j in range(token.shape[0]):
if token[j, 0].item() == self.tokenizer.eos_token_id: eos[j] = True
if eos.all(): break
text = self.tokenizer.decode(self.sequence[0] if self.sequence.shape[0] == 1 else self.sequence)
return text
# Apply repetition penalty with current settings
def apply_rep_penalty(self, logits):
cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence,
self.settings.token_repetition_penalty_max,
self.settings.token_repetition_penalty_sustain,
self.settings.token_repetition_penalty_decay,
logits)
# Generate a single token with the current settings, append to sequence
def gen_single_token(self, constraints = None, mask = None):
self.end_beam_search()
# Simple sampling case:
if self.sequence is not None:
logits = self.model.forward(self.sequence[:, -1:], self.cache, lora = self.lora, input_mask = mask)
self.apply_rep_penalty(logits)
logits[:, :, self.tokenizer.bos_token_id] = -10000.0
if constraints is not None:
for c in constraints: logits[:, :, c] += 10000.0
logits[:, :, :] -= 10000.0
token, _ = self.batched_sample(logits,
self.settings.temperature,
self.settings.top_k,
self.settings.top_p,
self.settings.min_p + 0.01 if constraints is not None else 0.0,
self.settings.typical)
else:
# bos = torch.Tensor([[self.tokenizer.bos_token_id]]).long()
# logits = self.model.forward(bos, self.cache)
# self.cache.current_seq_len = 0
if constraints is not None:
token = constraints[0]
else:
token = torch.Tensor([[self.tokenizer.bos_token_id]]).long()
self.gen_accept_token(token)
return token
# Beam search
class Beam:
sequence: torch.Tensor # tokens generated in beam
probs: torch.Tensor # probability score per token
cache: ExLlamaCache # cached keys/values for this beam
current_seq_pos: int # position of beam in current sequence
settings = None
generator = None
sampled_tokens: torch.Tensor
sampled_probs: torch.Tensor
moved: bool = False
def __init__(self, settings, generator, first_token = None, first_prob = None, seq_pos = None):
self.settings = settings
self.generator = generator
self.sequence = first_token.unsqueeze(0).unsqueeze(0) if first_token is not None else None
self.probs = first_prob.unsqueeze(0).unsqueeze(0) if first_prob is not None else None
self.cache = ExLlamaCache(self.generator.model, max_seq_len = self.settings.beam_length)
self.current_seq_pos = seq_pos
def __len__(self):
return self.sequence.shape[-1]
def clone(self):
new = ExLlamaGenerator.Beam(self.settings, self.generator)
new.sequence = self.sequence.clone()
new.probs = self.probs.clone()
new.cache = self.cache.clone()
new.current_seq_pos = self.current_seq_pos
new.sampled_tokens = self.sampled_tokens.clone()
new.sampled_probs = self.sampled_probs.clone()
new.moved = self.moved
return new
        # Advance the beam one position: drop the oldest token and probability and roll the cache left
def advance(self):
self.cache.roll_left()
self.sequence = self.sequence[:, 1:]
self.probs = self.probs[:, 1:]
self.current_seq_pos += 1
        # Cumulative log probabilities
def cum_log_probs(self):
cum_log_prob = torch.sum(torch.log(self.probs))
return cum_log_prob
def sampled_cum_log_probs(self):
cum_log_prob = torch.sum(torch.log(self.probs))
return torch.log(self.sampled_probs) + cum_log_prob
# Insert current beam in sequence
def to_sequence(self):
# Extend generator sequence and cache if needed
new_tokens = 0
added_tokens = 0
slen = self.generator.sequence.shape[-1]
tlen = self.current_seq_pos + len(self)
if tlen > slen:
new_tokens = tlen - slen
added_tokens = new_tokens
self.generator.sequence = torch.cat((self.generator.sequence, self.sequence[:, -new_tokens:]), dim = 1)
self.generator.cache.current_seq_len = tlen - 1
# Determine how much of generator sequence needs to be updated
new_tokens_ = new_tokens
for i in range(new_tokens_, len(self)):
if self.generator.sequence[0, -i - 1] != self.sequence[0, -i - 1]: new_tokens = i + 1
# Update sequence and cache
if new_tokens > added_tokens:
self.generator.sequence[0, -new_tokens:] = self.sequence[0, -new_tokens:]
if new_tokens > len(self) - 1: new_tokens = len(self) - 1
if new_tokens > 0:
self.cache.copy_states(self.generator.cache,
len(self) - 1 - new_tokens, new_tokens,
self.generator.cache.current_seq_len - new_tokens, new_tokens,
0, 1, 0, 1)
# Copy last column of cache to this beam (after generation)
def record_last_cache_column(self):
self.generator.cache.copy_states(self.cache,
self.generator.cache.current_seq_len - 1, 1,
len(self) - 1, 1,
0, 1, 0, 1)
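    # Enter beam search mode; a no-op when beams == 1 and beam_length == 1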
def begin_beam_search(self):
self.beams = None
if self.settings.beams == 1 and self.settings.beam_length == 1: return
self.in_beam_search = True
# self.testl = []
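    # Generate one token with beam search: maintain up to settings.beams candidate continuations of length
    # settings.beam_length, then emit the first token of the best beam and keep compatible beams alive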
def beam_search(self):
if self.settings.beams == 1 and self.settings.beam_length == 1: return self.gen_single_token()
assert self.in_beam_search
# Kludge: The first token returned with an empty context is generated without beam search
if self.sequence is None: return self.gen_single_token()
c_cache_len = self.cache.current_seq_len
c_seq_len = self.sequence_actual.shape[-1]
# Begin here
max_beam_length = min(self.model.config.max_seq_len - self.settings.beam_length, self.settings.beam_length)
while self.beams is None or len(self.beams[0]) < max_beam_length:
if self.beams is None:
# Initial tokens for initial beams
# self.cache.debug()
logits = self.model.forward(self.sequence[:, -1:], self.cache, lora = self.lora)
cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence,
self.settings.token_repetition_penalty_max,
self.settings.token_repetition_penalty_sustain,
self.settings.token_repetition_penalty_decay,
logits)
tokens, probs = self.sample(logits,
self.settings.temperature,
self.settings.top_k,
self.settings.top_p,
self.settings.min_p,
self.settings.typical,
num = self.settings.beams)
# self.cache is updated with k/v for last token
# Setup initial beams
self.beams = []
while len(self.beams) < min(self.settings.beams, tokens.shape[-1]):
beam = ExLlamaGenerator.Beam(self.settings, self, tokens[0, len(self.beams)], probs[0, len(self.beams)], c_seq_len)
self.beams.append(beam)
else:
# Sample from each beam
# print(len(self.beams), end = "")
for beam in self.beams:
beam.to_sequence()
# self.cache.debug()
logits = self.model.forward(self.sequence[:, -1:], self.cache, lora = self.lora)
cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence,
self.settings.token_repetition_penalty_max,
self.settings.token_repetition_penalty_sustain,
self.settings.token_repetition_penalty_decay,
logits)
tokens, probs = self.sample(logits,
self.settings.temperature,
self.settings.top_k,
self.settings.top_p,
self.settings.min_p,
self.settings.typical,
num = -1)
beam.sampled_tokens = tokens
beam.sampled_probs = probs
beam.record_last_cache_column()
self.cache.current_seq_len -= 1
# Collect options for all beams
tokens_ = []
probs_ = []
cum_log_probs_ = []
beams_ = []
for i, beam in enumerate(self.beams):
tokens_.append(beam.sampled_tokens.squeeze(0))
probs_.append(beam.sampled_probs.squeeze(0))
cum_log_probs_.append(beam.sampled_cum_log_probs().squeeze(0))
beams_.append(torch.Tensor([i] * beam.sampled_tokens.shape[-1]).to(torch.int))
tokens_all = torch.cat(tokens_, dim = 0)
probs_all = torch.cat(probs_, dim = 0)
cum_log_probs_all = torch.cat(cum_log_probs_, dim = 0)
beams_all = torch.cat(beams_, dim = 0)
# Sort by cumulative probability
cum_log_probs_all, ind = cum_log_probs_all.sort(descending = True)
probs_all = probs_all[ind]
tokens_all = tokens_all[ind]
beams_all = beams_all[ind]
# Reduce to beam limit
cum_log_probs_all = cum_log_probs_all[:self.settings.beams]
probs_all = probs_all[:self.settings.beams]
tokens_all = tokens_all[:self.settings.beams]
beams_all = beams_all[:self.settings.beams]
# Re-sort by beam index
beams_all, ind = beams_all.sort()
cum_log_probs_all = cum_log_probs_all[ind]
tokens_all = tokens_all[ind]
probs_all = probs_all[ind]
# test = [self.tokenizer.decode(beam.sequence) for beam in self.beams]
# Rebuild beams/caches
for beam in self.beams: beam.moved = False
beams_new = []
for i in range(len(beams_all)):
new_token = tokens_all[i]
new_prob = probs_all[i]
beam_idx = beams_all[i].item()
if not self.beams[beam_idx].moved:
self.beams[beam_idx].sequence = torch.cat((self.beams[beam_idx].sequence, new_token.unsqueeze(0).unsqueeze(0)), dim = 1)
self.beams[beam_idx].probs = torch.cat((self.beams[beam_idx].probs, new_prob.unsqueeze(0).unsqueeze(0)), dim = 1)
self.beams[beam_idx].moved = True
beams_new.append(self.beams[beam_idx])
else:
nbeam = self.beams[beam_idx].clone()
nbeam.sequence[:, -1] = new_token
nbeam.probs[:, -1] = new_prob
beams_new.append(nbeam)
self.beams = beams_new
# Beam length is filled up, select winning beam
max_log_probs = float("-inf")
best_beam = None
best_beam_idx = -1
for beam_idx, beam in enumerate(self.beams):
beam_log_probs = beam.cum_log_probs()
if beam_log_probs > max_log_probs:
max_log_probs = beam_log_probs
best_beam = beam
best_beam_idx = beam_idx
best_token = best_beam.sequence[:, 0]
# Insert in sequence
self.sequence[0, c_seq_len] = best_token
self.sequence_actual = torch.cat((self.sequence_actual, best_token.unsqueeze(0)), dim = 1)
# Copy cache state for winning beam
best_beam.to_sequence()
# Prune other beams that don't begin with the winning token
beams_new = [best_beam]
for idx, beam in enumerate(self.beams):
if idx != best_beam_idx and beam.sequence[:, 0] == best_token:
beams_new.append(beam)
self.beams = beams_new
# Advance all remaining beams and caches
for beam in self.beams: beam.advance()
# Done
return best_token
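    # Leave beam search mode, reverting the working sequence to the actual sequence and truncating the cache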
def end_beam_search(self):
if not self.in_beam_search: return
self.sequence = self.sequence_actual.clone()
self.cache.current_seq_len = self.sequence.shape[-1] - 1
self.in_beam_search = False
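    # Overwrite the last token of the actual sequence (and optionally the working sequence)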
def replace_last_token(self, token, seq = False):
self.sequence_actual[:, -1] = token
if seq: self.sequence[:, -1] = token
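    # True if the actual sequence ends with the given tokens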
def sequence_ends_with(self, tokens):
if self.sequence_actual.shape[-1] < tokens.shape[-1] + 1: return False
for i in range(tokens.shape[-1]):
if self.sequence_actual[0, -i - 1] != tokens[0, -i - 1]: return False
return True
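# Minimal usage sketch (kept as a comment; for reference only). It assumes ExLlamaConfig and
# ExLlamaTokenizer from the repo's model.py and tokenizer.py, constructed as in the example scripts;
# the "<model_dir>" paths are placeholders.
#
#   config = ExLlamaConfig("<model_dir>/config.json")
#   config.model_path = "<model_dir>/model.safetensors"
#   model = ExLlama(config)
#   tokenizer = ExLlamaTokenizer("<model_dir>/tokenizer.model")
#   cache = ExLlamaCache(model)
#
#   generator = ExLlamaGenerator(model, tokenizer, cache)
#   generator.settings.temperature = 0.7
#   text = generator.generate_simple("Once upon a time", max_new_tokens = 64)
#   print(text)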