import asyncio
import websockets
import json
from sentencepiece import SentencePieceProcessor

from model import ExLlama, ExLlamaCache, ExLlamaConfig
from lora import ExLlamaLora
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import argparse
import torch
import sys
import os
import glob
import model_init

model: ExLlama
cache: ExLlamaCache
config: ExLlamaConfig
generator: ExLlamaGenerator
tokenizer: ExLlamaTokenizer

# Small cache for tokenized strings
max_cached_strings = 100
tokenizer_cache = {}

# Per-request streaming state
prompt_ids: torch.Tensor
stop_strings: list
stop_tokens: list
held_text: str
max_stop_string: int
remaining_tokens: int

full_prompt: str
utilized_prompt: str
built_response: str


def cached_tokenize(text: str):
    global model, cache, config, generator, tokenizer
    global max_cached_strings, tokenizer_cache

    if text in tokenizer_cache:
        return tokenizer_cache[text]

    # Evict the oldest cached encodings until there is room for one more entry
    while len(tokenizer_cache) >= max_cached_strings:
        del tokenizer_cache[next(iter(tokenizer_cache))]

    new_enc = tokenizer.encode(text)
    tokenizer_cache[text] = new_enc
    return new_enc
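
# Since Python dicts preserve insertion order, the eviction above is FIFO: the string
# tokenized longest ago is dropped first. A minimal sketch, assuming a hypothetical
# cache limit of 2 purely for illustration:
#
#   max_cached_strings = 2
#   cached_tokenize("a"); cached_tokenize("b"); cached_tokenize("c")
#   assert "a" not in tokenizer_cache and "c" in tokenizer_cache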


def begin_stream(prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: ExLlamaGenerator.Settings):
    global model, cache, config, generator, tokenizer
    global stop_strings, stop_tokens, prompt_ids, held_text, max_stop_string, remaining_tokens
    global full_prompt, utilized_prompt, built_response

    # Tokenize the prompt and truncate it from the left so it fits alongside the requested completion
    max_input_tokens = model.config.max_seq_len - max_new_tokens
    input_ids = cached_tokenize(prompt)
    input_ids = input_ids[:, -max_input_tokens:]
    prompt_ids = input_ids

    full_prompt = prompt
    utilized_prompt = tokenizer.decode(prompt_ids)[0]
    built_response = ""

    remaining_tokens = max_new_tokens

    # Split the stop conditions into token IDs and stop strings
    stop_strings = []
    stop_tokens = []
    for t in stop_conditions:
        if isinstance(t, int): stop_tokens += [t]
        if isinstance(t, str): stop_strings += [t]

    held_text = ""

    # Longest stop string, in tokens, plus a little slack for tokenization boundary effects
    max_stop_string = 2
    for ss in stop_strings:
        max_stop_string = max(max_stop_string, get_num_tokens(ss) + 2)

    generator.settings = gen_settings

    # Start generation, reusing cached keys/values for any prefix shared with the previous request
    generator.gen_begin_reuse(input_ids)
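
# Context budget sketch: prompt tokens plus the requested completion must fit in the model's
# sequence length. With a hypothetical max_seq_len of 4096 and max_new_tokens of 512, only
# the last 4096 - 512 = 3584 prompt tokens are kept; anything older is trimmed away.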


def stream():
    global model, cache, config, generator, tokenizer
    global stop_strings, stop_tokens, prompt_ids, held_text, max_stop_string, remaining_tokens
    global full_prompt, utilized_prompt, built_response

    # Stop once the requested number of new tokens has been produced
    if remaining_tokens == 0:
        return held_text, True, full_prompt + built_response, utilized_prompt + built_response, built_response
    remaining_tokens -= 1

    # Decode the tail of the sequence before and after sampling so the newly added text can be isolated
    old_tail = tokenizer.decode(generator.sequence_actual[:, -max_stop_string:])[0]
    next_token = generator.gen_single_token()

    # End on a stop token
    if next_token in stop_tokens:
        return held_text, True, full_prompt + built_response, utilized_prompt + built_response, built_response

    new_tail = tokenizer.decode(generator.sequence_actual[:, -(max_stop_string + 1):])[0]
    added_text = new_tail[len(old_tail):]
    held_text += added_text

    partial_ss = False
    for ss in stop_strings:

        # End if the held text contains a full stop string
        position = held_text.find(ss)
        if position != -1:
            built_response += held_text[:position]
            return held_text[:position], True, full_prompt + built_response, utilized_prompt + built_response, built_response

        # Check for overlap between the end of the held text and the start of a stop string
        overlap = 0
        for j in range(1, min(len(held_text), len(ss)) + 1):
            if held_text[-j:] == ss[:j]: overlap = j
        if overlap > 0: partial_ss = True

    # If a partial stop string is being held back, emit nothing yet
    if partial_ss:
        return "", False, full_prompt + built_response, utilized_prompt + built_response, built_response

    # Otherwise release the held text
    stream_text = held_text
    held_text = ""
    built_response += stream_text
    return stream_text, False, full_prompt, utilized_prompt, built_response
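
# stream() returns a 5-tuple: (new_text, eos, full_context, utilized_context, response).
# A minimal consumer loop, assuming begin_stream() has already been called (sketch only):
#
#   while True:
#       chunk, eos, _, _, _ = stream()
#       print(chunk, end="", flush=True)
#       if eos: break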


def leftTrimTokens(text: str, desiredLen: int):
    # Trim the text from the left so that it fits within desiredLen tokens
    encodedText = tokenizer.encode(text)
    if encodedText.shape[-1] <= desiredLen:
        return text
    else:
        return tokenizer.decode(encodedText[:, -desiredLen:])[0]


def oneshot_generation(prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: ExLlamaGenerator.Settings):
    # Run the streaming generator to completion and return the accumulated contexts and response
    begin_stream(prompt, stop_conditions, max_new_tokens, gen_settings)
    response = ""
    while True:
        _, eos, _, _, _ = stream()
        if eos: break

    return full_prompt + built_response, utilized_prompt + built_response, built_response


def get_num_tokens(text: str):
    return cached_tokenize(text).shape[-1]


async def estimateToken(request, ws):
    text = request["text"]
    numTokens = get_num_tokens(text)
    return numTokens


async def oneShotInfer(request, ws):
    stopToken = request["stopToken"]
    fullContext = request["text"]
    maxNew = int(request["maxNew"])
    top_p = float(request["top_p"])
    top_k = int(request["top_k"])
    temp = float(request["temp"])
    rep_pen = float(request["rep_pen"])

    # Stop on the model's EOS token as well as the caller-supplied stop string
    sc = [tokenizer.eos_token_id]
    sc.append(stopToken)

    gs = ExLlamaGenerator.Settings()
    gs.top_k = top_k
    gs.top_p = top_p
    gs.temperature = temp
    gs.token_repetition_penalty_max = rep_pen

    full_ctx, util_ctx, response = oneshot_generation(prompt=fullContext, stop_conditions=sc, max_new_tokens=maxNew, gen_settings=gs)

    return full_ctx, util_ctx, response


async def streamInfer(request, ws):
    # EOS token plus any comma-separated stop strings supplied by the client
    stopToken = [tokenizer.eos_token_id]
    stopToken += request["stopToken"].split(',')

    prompt = request["text"]
    maxNew = int(request["maxNew"])
    top_p = float(request["top_p"])
    top_k = int(request["top_k"])
    temp = float(request["temp"])
    rep_pen = float(request["rep_pen"])

    gs = ExLlamaGenerator.Settings()
    gs.top_k = top_k
    gs.top_p = top_p
    gs.temperature = temp
    gs.token_repetition_penalty_max = rep_pen

    begin_stream(prompt, stopToken, maxNew, gs)
    while True:
        chunk, eos, x, y, builtResp = stream()
        # Push each partial response to the client as soon as it is available
        await ws.send(json.dumps({'action': request["action"],
                                  'request_id': request['request_id'],
                                  'utilContext': utilized_prompt + builtResp,
                                  'response': builtResp}))
        if eos: break

    return utilized_prompt + built_response, builtResp
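
# Each streamed frame is a JSON object; with hypothetical values it looks roughly like:
#
#   {"action": "streamInfer", "request_id": "42",
#    "utilContext": "<prompt so far><partial response>", "response": "<partial response>"}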


async def main(websocket, path):
    async for message in websocket:

        request = json.loads(message)
        reqID = request["request_id"]
        action = request["action"]

        # Dispatch on the requested action
        if action == "estimateToken":
            response = await estimateToken(request, websocket)
            await websocket.send(json.dumps({'action': action, 'request_id': reqID, 'response': response}))

        elif action == "echo":
            await websocket.send(json.dumps({'action': action, 'request_id': reqID}))

        elif action == "oneShotInfer":
            fctx, utlctx, res = await oneShotInfer(request, websocket)
            await websocket.send(json.dumps({'action': action, 'request_id': reqID, 'utilContext': utlctx, 'response': res}))

        elif action == "leftTrim":
            prompt = request["text"]
            desiredLen = int(request["desiredLen"])
            processedPrompt = leftTrimTokens(prompt, desiredLen)
            await websocket.send(json.dumps({'action': action, 'request_id': reqID, 'response': processedPrompt}))

        # Any other action is treated as a streaming inference request
        else:
            utlctx, builtResp = await streamInfer(request, websocket)
            await websocket.send(json.dumps({'action': action, 'request_id': reqID, 'utilContext': utlctx, 'response': builtResp + '</s>'}))
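
# Example request payloads (illustrative only; all field values here are hypothetical):
#
#   {"action": "estimateToken", "request_id": "1", "text": "How many tokens is this?"}
#
#   {"action": "streamInfer", "request_id": "2", "text": "USER: Hello\nASSISTANT:",
#    "stopToken": "USER:", "maxNew": "256", "top_p": "0.9", "top_k": "40",
#    "temp": "0.7", "rep_pen": "1.15"}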


# Initialize model, tokenizer, cache and generator

model_directory = "./models/Llama-2-70B-chat-GPTQ/"

tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)[0]
esTokenizer = SentencePieceProcessor(model_file = tokenizer_path)   # not used below

config = ExLlamaConfig(model_config_path)
config.set_auto_map('17.615,18.8897')   # GPU split: VRAM (GB) to allocate per device
config.model_path = model_path

model = ExLlama(config)
print(f"Model loaded: {model_path}")

tokenizer = ExLlamaTokenizer(tokenizer_path)
cache = ExLlamaCache(model)
generator = ExLlamaGenerator(model, tokenizer, cache)

# Start the WebSocket server
start_server = websockets.serve(main, "0.0.0.0", 8080)

asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
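
# A minimal client sketch (assumes the server above is reachable at ws://localhost:8080
# and uses the same `websockets` package; shown only as an illustration):
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:8080") as ws:
#           await ws.send(json.dumps({"action": "estimateToken", "request_id": "1",
#                                     "text": "Hello there"}))
#           print(await ws.recv())
#
#   asyncio.run(demo())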