File size: 8,865 Bytes
0eeee8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import version as torch_version
from modules import shared
from modules.logging_colors import logger
from modules.models import clear_torch_cache
from modules.text_generation import get_max_prompt_length
# Import the ExLlama classes. The installed `exllama` package is tried first;
# if that fails for any reason, fall back to a checkout located in
# repositories/exllama.
#
# `except Exception` is used instead of a bare `except` so that
# KeyboardInterrupt / SystemExit are not swallowed into the fallback path or
# reported with the misleading "Could not find repositories/exllama" message.
try:
    from exllama.generator import ExLlamaGenerator
    from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
    from exllama.tokenizer import ExLlamaTokenizer
except Exception:
    logger.warning('exllama module failed to import. Will attempt to import from repositories/.')
    try:
        from modules.relative_imports import RelativeImport

        # Inside this context the exllama modules are imported by their bare names.
        with RelativeImport("repositories/exllama"):
            from generator import ExLlamaGenerator
            from model import ExLlama, ExLlamaCache, ExLlamaConfig
            from tokenizer import ExLlamaTokenizer
    except Exception:
        logger.error(
            "Could not find repositories/exllama. Please ensure that exllama"
            " (https://github.com/turboderp/exllama) is cloned inside repositories/ and is up to date."
        )
        # Re-raise: nothing below can work without these classes.
        raise
class ExllamaModel:
    """Bundle an ExLlama model, cache, tokenizer and generator behind one object.

    Instances are meant to be created with :meth:`from_pretrained`, which
    attaches the ``config``, ``model``, ``cache``, ``tokenizer`` and
    ``generator`` attributes that every other method relies on.
    """

    def __init__(self):
        pass

    @staticmethod
    def _find_checkpoint(model_dir):
        """Return the checkpoint file inside *model_dir*, or ``None`` if there is none.

        Extensions are tried in the order ``.safetensors``, ``.pt``, ``.bin``;
        the first extension with at least one match wins. When several files
        share that extension the last one returned by ``glob`` is used and a
        warning is logged.
        """
        for ext in ('.safetensors', '.pt', '.bin'):
            found = list(model_dir.glob(f"*{ext}"))
            if not found:
                continue

            if len(found) > 1:
                logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')

            return found[-1]

        return None

    @staticmethod
    def _banned_token_ids(state, eos_token_id):
        """Build the list of token ids that must never be sampled.

        Args:
            state: generation settings; reads ``ban_eos_token`` and
                ``custom_token_bans`` (a comma-separated string of token ids,
                possibly empty or ``None``).
            eos_token_id: id to ban when ``ban_eos_token`` is truthy.

        Returns:
            A list containing the EOS id (if banned) followed by the custom
            ids, or ``None`` when nothing is banned.

        Raises:
            ValueError: if a non-blank custom entry is not an integer.
        """
        banned = [eos_token_id] if state['ban_eos_token'] else []

        custom = state['custom_token_bans']
        if custom:
            # Blank entries (e.g. from a trailing comma) are skipped rather
            # than crashing on int('').
            banned.extend(int(piece) for piece in custom.split(',') if piece.strip())

        return banned or None

    @classmethod
    def from_pretrained(cls, path_to_model):
        """Load the model stored in ``<shared.args.model_dir>/<path_to_model>``.

        The directory is expected to contain ``config.json``,
        ``tokenizer.model`` and one ``.safetensors``/``.pt``/``.bin``
        checkpoint.

        Returns:
            A ``(result, result)`` tuple holding the same new instance twice
            (presumably so the caller can unpack it as model and tokenizer —
            confirm against the caller).

        Raises:
            FileNotFoundError: if the directory contains no checkpoint file.
        """
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        tokenizer_model_path = path_to_model / "tokenizer.model"
        model_config_path = path_to_model / "config.json"

        # Find the model checkpoint. Fail early with a clear message instead
        # of handing the literal string "None" to ExLlama.
        model_path = cls._find_checkpoint(path_to_model)
        if model_path is None:
            raise FileNotFoundError(
                f'No .safetensors, .pt or .bin checkpoint was found in {path_to_model}'
            )

        config = ExLlamaConfig(str(model_config_path))
        config.model_path = str(model_path)
        config.max_seq_len = shared.args.max_seq_len
        config.compress_pos_emb = shared.args.compress_pos_emb
        if shared.args.gpu_split:
            config.set_auto_map(shared.args.gpu_split)
            config.gpu_peer_fix = True

        # alpha_value is only applied when no explicit rope_freq_base is given;
        # an explicit rope_freq_base takes precedence.
        if shared.args.alpha_value > 1 and shared.args.rope_freq_base == 0:
            config.alpha_value = shared.args.alpha_value
            config.calculate_rotary_embedding_base()
        elif shared.args.rope_freq_base > 0:
            config.rotary_embedding_base = shared.args.rope_freq_base

        # On HIP (ROCm) builds of torch, turn on all the *_no_half2 switches.
        if torch_version.hip:
            config.rmsnorm_no_half2 = True
            config.rope_no_half2 = True
            config.matmul_no_half2 = True
            config.silu_no_half2 = True

        model = ExLlama(config)
        tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
        cache = ExLlamaCache(model)
        generator = ExLlamaGenerator(model, tokenizer, cache)

        result = cls()
        result.config = config
        result.model = model
        result.cache = cache
        result.tokenizer = tokenizer
        result.generator = generator
        return result, result

    def encode(self, string, **kwargs):
        """Tokenize *string* with a BOS token, capped at the model's max_seq_len.

        Extra keyword arguments are accepted and ignored.
        """
        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)

    def decode(self, ids, **kwargs):
        """Decode token ids into a string.

        A plain list is wrapped into a one-row tensor and a single-element
        tensor is reshaped to one row before decoding; the first decoded row
        is returned. Extra keyword arguments are accepted and ignored.
        """
        if isinstance(ids, list):
            ids = torch.tensor([ids])
        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
            ids = ids.view(1, -1)

        return self.tokenizer.decode(ids)[0]

    def get_logits(self, token_ids, **kwargs):
        """Return the logits for the last position of *token_ids* as a float CPU tensor.

        The cache is rewound to position 0, every token except the last is
        fed with ``preprocess_only=True``, then the last token is run through
        the model with ``**kwargs`` forwarded to ``forward``.
        """
        self.cache.current_seq_len = 0
        if token_ids.shape[-1] > 1:
            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)

        return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()

    def _ensure_cache_batch_size(self, batch_size):
        """Rebuild the cache and generator if the cache batch size differs from *batch_size*."""
        if self.cache.batch_size == batch_size:
            return

        # Drop the generator along with the cache: it was constructed with the
        # old cache and presumably still references it, which would keep that
        # memory alive while the replacement is being allocated.
        del self.cache
        del self.generator
        clear_torch_cache()
        self.cache = ExLlamaCache(self.model, batch_size=batch_size)
        self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)

    def _apply_sampling_settings(self, state):
        """Copy the sampling parameters and token bans from *state* onto the generator."""
        settings = self.generator.settings
        settings.temperature = state['temperature']
        settings.top_p = state['top_p']
        settings.top_k = state['top_k']
        settings.typical = state['typical_p']
        settings.token_repetition_penalty_max = state['repetition_penalty']
        settings.token_repetition_penalty_sustain = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']

        # Register every ban with a single call so that the custom bans do not
        # replace the EOS ban when both are requested.
        self.generator.disallow_tokens(self._banned_token_ids(state, self.tokenizer.eos_token_id))

    def _first_piece_has_leading_space(self, token):
        """Tell whether the piece for *token* starts with the '▁' word-boundary marker."""
        return self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁')

    def _decode_reply(self, initial_len, has_leading_space):
        """Decode the generator's sequence from *initial_len* onwards.

        A single space is prepended when *has_leading_space* is true.
        """
        decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
        if has_leading_space:
            decoded_text = ' ' + decoded_text

        return decoded_text

    def _stream_plain(self, prompt, state):
        """Case 1: no CFG. Yield the cumulative reply after every generated token."""
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
        if state['add_bos_token']:
            ids = torch.cat(
                [torch.tensor([[self.tokenizer.bos_token_id]]).to(ids.device),
                 ids], dim=1
            ).to(torch.int64)

        # Keep only the rightmost tokens that fit in the prompt budget.
        ids = ids[:, -get_max_prompt_length(state):]
        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False

        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            if i == 0 and self._first_piece_has_leading_space(token):
                has_leading_space = True

            yield self._decode_reply(initial_len, has_leading_space)
            if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                break

    def _stream_with_cfg(self, prompt, state):
        """Case 2: CFG. Yield the cumulative reply while mixing positive and negative logits.

        Copied from https://github.com/turboderp/exllama/blob/master/example_cfg.py
        """
        alpha = state['guidance_scale']
        prompts = [prompt, state['negative_prompt'] or '']

        ids, mask = self.tokenizer.encode(
            prompts,
            return_mask=True,
            max_seq_len=self.model.config.max_seq_len,
            add_bos=state['add_bos_token']
        )
        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        self.generator.gen_begin(ids, mask=mask)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False

        for i in range(max_new_tokens):
            logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask)
            self.generator.apply_rep_penalty(logits)

            # Row 0 belongs to the prompt, row 1 to the negative prompt.
            logits = F.log_softmax(logits, dim=-1)
            logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1]

            token, _ = self.generator.sample_current(logits_mixed)
            if i == 0 and self._first_piece_has_leading_space(token):
                has_leading_space = True

            # NOTE(review): the text is decoded before the sampled token is
            # accepted below, so each yield appears to lag one token behind —
            # confirm whether this is intended.
            yield self._decode_reply(initial_len, has_leading_space)
            if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
                break

            # Feed the same token to both rows of the batch.
            batch_token = token.repeat(2, 1)
            self.generator.gen_accept_token(batch_token)

    def generate_with_streaming(self, prompt, state):
        """Generate a reply for *prompt*, yielding the cumulative text after each step.

        Args:
            prompt: the prompt string.
            state: generation settings. ``guidance_scale == 1`` selects plain
                sampling; any other value selects CFG using
                ``negative_prompt``.

        Yields:
            The reply decoded so far (not just the newest token).
        """
        use_cfg = state['guidance_scale'] != 1

        # The cache batch size must be 2 for CFG and 1 otherwise
        self._ensure_cache_batch_size(2 if use_cfg else 1)
        self._apply_sampling_settings(state)

        if use_cfg:
            yield from self._stream_with_cfg(prompt, state)
        else:
            yield from self._stream_plain(prompt, state)

    def generate(self, prompt, state):
        """Run :meth:`generate_with_streaming` to completion and return the final text.

        Returns an empty string when the stream yields nothing.
        """
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output
|