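# Minimal RWKV-RNN chat REPL: load a small RWKV checkpoint, keep the
# recurrent state between turns, and stream the bot's reply token by token.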
print('Loading...')
from src.model_run import RWKV_RNN
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
np.set_printoptions(precision=4, suppress=True, linewidth=200)
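# Tokenizer setup: src.utils.TOKENIZER takes a (word-vocab, char-vocab) pair;
# passing the GPT-NeoX 20B tokenizer json for both (with UNKNOWN_CHAR=None)
# appears to select plain BPE decoding with no unknown-char substitution.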
WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"]
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
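# Run configuration: CPU inference in full fp32; vocab_size/n_layer/n_embd
# must match the 'zrwkv-37fifth' checkpoint (a 12-layer, 768-wide RWKV).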
args = types.SimpleNamespace()
args.RUN_DEVICE = "cpu"
args.FLOAT_MODE = "fp32"
args.vocab_size = 50277
args.MODEL_NAME = 'zrwkv-37fifth'
args.n_layer = 12
args.n_embd = 768
args.ctx_len = 1024
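# Dialogue formatting: turns are rendered as "User: ..." / "Daniel: ...",
# with `interface` as the separator between speaker label and text.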
user = "User"
bot = "Daniel"
interface = ":"
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME
print(f'loading... {MODEL_NAME}')
model = RWKV_RNN(args)
model_tokens = []
current_state = None
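# model_tokens accumulates the full token history of the conversation;
# current_state carries the RWKV recurrent state between forward calls,
# so generation never has to re-read the whole context.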
def run_rnn(tokens, newline_adj=0):
    global model_tokens, current_state
    for i in range(len(tokens)):
        model_tokens += [int(tokens[i])]
        if i == len(tokens) - 1:
            out, current_state = model.forward(model_tokens, current_state)
        else:
            # Intermediate tokens only advance the state; logits are
            # needed for the last token alone.
            current_state = model.forward(model_tokens, current_state, preprocess_only=True)
    out[0] = -999999999  # forbid token 0 (<|endoftext|> in the 20B vocab)
    out[187] += newline_adj  # token 187 is '\n'; bias the newline logit
    return out
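# Named snapshots of (last logits, RNN state, token history). 'chat' is
# reloaded at the start of every turn; 'chat_pre' is saved just before
# generation so a reply could be rolled back and regenerated.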
all_state = {}
def save_all_stat(name, last_out):
    all_state[name] = {}
    all_state[name]['out'] = last_out
    all_state[name]['rnn'] = copy.deepcopy(current_state)
    all_state[name]['token'] = copy.deepcopy(model_tokens)

def load_all_stat(name):
    global model_tokens, current_state
    current_state = copy.deepcopy(all_state[name]['rnn'])
    model_tokens = copy.deepcopy(all_state[name]['token'])
    return all_state[name]['out']
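# No init prompt is fed to the model here (out stays an empty string), so
# the 'chat_init' and 'chat' snapshots both capture the blank starting state.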
print(f'\nRun prompt...')
out = ""
gc.collect()
save_all_stat('chat_init', out)
save_all_stat('chat', out) # ensure that 'chat' key is added to all_state
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
def reply_msg(msg):
    print(f'{bot}{interface} {msg}\n')
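# One chat turn: append "User: <msg>\nDaniel:" to the context, then sample
# up to 8000 tokens at temperature 1.0 / top-p 0.85, streaming text as it
# decodes cleanly and stopping once the reply hits a newline or role marker.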
def on_message(message):
    global model_tokens, current_state

    msg = message.replace('\\n', '\n').strip()
    if len(msg) > 10000:
        reply_msg('your message is too long (max 10000 characters)')
        return

    out = load_all_stat('chat')  # restore state/tokens; logits are replaced below
    new = f"{user}{interface} {msg}\n{bot}{interface}"
    out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
    save_all_stat('chat_pre', out)

    begin = len(model_tokens)
    out_last = begin
    print(f'{bot}{interface}', end='', flush=True)
    for i in range(8000):
        token = tokenizer.sample_logits(
            out,
            model_tokens,
            args.ctx_len,
            temperature=1.0,
            top_p_usual=0.85,
            top_p_newline=0.85,
        )
        out = run_rnn([token], newline_adj=1)

        # Stream the newly decoded text, but hold back partial UTF-8
        # ('\ufffd'), role echoes, and bare separators.
        xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
        if '\ufffd' not in xxx and 'user' not in str(xxx).lower() and '\n' not in xxx and str(xxx) != ':' and str(xxx) != '\n\n' and len(str(xxx)) > 0:
            print(xxx, end='', flush=True)
        else:
            print('\n', end='', flush=True)
        out_last = begin + i + 1

        # Stop once the reply contains a newline or trails into a role marker,
        # then strip those markers from the finished message.
        send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
        if '\ufffd' in send_msg or send_msg.endswith(f'{user}{interface}') or send_msg.endswith(f'{bot}{interface}') or '\n' in send_msg:
            send_msg = send_msg.strip()
            send_msg = send_msg.replace(f'{user}{interface}', '')
            send_msg = send_msg.replace(f'{bot}{interface}', '')
            send_msg = send_msg.replace('\n', '')
            break

    save_all_stat('chat', out)
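# Blocking REPL: read one line per turn and hand it to on_message.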
print('Start chatting with Daniel!')
while True:
    msg = input(f'{user}{interface} ')
    if len(msg.strip()) > 0:
        on_message(msg)
    else:
        print('Error: please say something')