Chrisneverdie committed
Commit b4f6221 · verified · 1 Parent(s): 26ef706

Upload 3 files

Files changed (3):
  1. OnlySportsLM.pth +3 -0
  2. RWKV_v6_demo.py +269 -0
  3. rwkv_vocab_v20230424.txt +0 -0
OnlySportsLM.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4db460f0fcd2f5ee4f8fa9c12f6a3ef03cce42b5823e50873a6f8255f66bcac
size 392730658
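
The .pth entry above is a Git LFS pointer, not the weights themselves; the oid/size lines identify the ~393 MB checkpoint stored out-of-band. A minimal sketch for fetching the real file, assuming the commit lives in a Hugging Face model repo named Chrisneverdie/OnlySportsLM (the repo id is an assumption here):

    # hedged sketch: resolve the LFS pointer to the actual checkpoint bytes
    from huggingface_hub import hf_hub_download

    ckpt_path = hf_hub_download(
        repo_id="Chrisneverdie/OnlySportsLM",  # hypothetical repo id
        filename="OnlySportsLM.pth",
    )
    print(ckpt_path)  # local cache path; this is the file RWKV_v6_demo.py torch.load()s
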
RWKV_v6_demo.py ADDED
@@ -0,0 +1,269 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
# The OnlySports Collection - https://huggingface.co/collections/Chrisneverdie/onlysports-66b3e5cf595eb81220cc27a6
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F
import os

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method

class RWKV_TOKENIZER():
    table: list[list[list[bytes]]]
    good: list[set[int]]
    wlen: list[int]

    def __init__(self, file_name):
        self.idx2token = {}
        sorted = [] # must be already sorted
        lines = open(file_name, "r", encoding="utf-8").readlines()
        for l in lines:
            # each line: "<id> <python literal for the token> <byte length>"
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            sorted += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        # precompute some tables for fast matching
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
            s = sorted[i]
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1] += [s]
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    try:
                        # greedy longest match among tokens starting with (s0, s1)
                        s = next(filter(sss.startswith, self.table[s0][s1]))
                    except StopIteration:
                        pass # no multi-byte token matches; fall back to the single byte
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                pass # keep raw bytes for tokens that are not valid utf-8 on their own
            print(f'{repr(s)}{i}', end=' ')
        print()

########################################################################################################

def sample_logits(out, temperature=1.0, top_p=0.8):
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        probs = probs**(1.0 / temperature)
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out
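
# Worked example: with probs [0.6, 0.3, 0.1] and top_p = 0.8, the descending
# cumulative sums are [0.6, 0.9, 1.0]; the first position above 0.8 is index 1,
# so cutoff = 0.3 and the 0.1 tail is zeroed before renormalizing and sampling.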

########################################################################################################

tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")

args = types.SimpleNamespace()
args.MODEL_NAME = 'OnlySportsLM'
args.n_layer = 20
args.n_embd = 640
args.vocab_size = 65536

context = """Kobe Bryant"""

NUM_TRIALS = 1
LENGTH_PER_TRIAL = 120
TEMPERATURE = 0.5
TOP_P = 0.7

class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')

        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @MyFunction
    def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):
        i0 = (2+self.head_size)*i+0 # row holding this layer's channel-mix token shift
        sx = state[i0] - x
        xk = x + sx * time_maa_k
        xr = x + sx * time_maa_r
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @MyFunction
    def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1 # row holding this layer's time-mix token shift
        sx = state[i1] - x
        state[i1] = x
        xxx = x + sx * x_maa
        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + sx * (w_maa + mw)
        xk = x + sx * (k_maa + mk)
        xv = x + sx * (v_maa + mv)
        xr = x + sx * (r_maa + mr)
        xg = x + sx * (g_maa + mg)

        w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
        w = torch.exp(-torch.exp(w.float())) # data-dependent decay in (0, 1)

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        a = k @ v
        x = r @ (time_first * a + s)
        s = a + w * s

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        with torch.no_grad():
            if state is None:
                # per layer: 1 row of channel-mix shift + 1 row of time-mix shift
                # + head_size rows of wkv state, each row of width n_embd
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
                    att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_maa_k, ffn.time_maa_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

#print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context):
    init_out, init_state = model.forward(token, init_state)

for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0
    out_str = ''

    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # only print once we have a valid utf-8 string
                print(tmp, end="", flush=True)
                out_str += tmp
                out_last = i + 1
        except: # decoding can fail mid multi-byte sequence; wait for more tokens
            pass
        out, state = model.forward(token, state)

print('\n')
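
The script above runs end-to-end at import time (there is no __main__ guard), so python RWKV_v6_demo.py is the whole demo: load OnlySportsLM.pth on CPU, feed the prompt, and stream 120 sampled tokens. A hedged sketch of reusing just the tokenizer, assuming all three committed files sit in the working directory:

    # note: importing RWKV_v6_demo also triggers one demo run as a side effect
    from RWKV_v6_demo import RWKV_TOKENIZER

    tok = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
    ids = tok.encode("Kobe Bryant")
    assert tok.decode(ids) == "Kobe Bryant"  # greedy longest-match round-trips losslessly
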
rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
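
Although the vocabulary diff is not rendered here, the parser in RWKV_v6_demo.py fixes its format: each line is "<id> <python literal for the token> <byte length>". A small sketch parsing one such line (the example entry is illustrative, not copied from the file):

    line = "33 ' ' 1"  # hypothetical entry: id 33, a single space, length 1
    idx = int(line[:line.index(' ')])                   # leading token id
    tok = eval(line[line.index(' '):line.rindex(' ')])  # str or bytes literal
    tok = tok.encode("utf-8") if isinstance(tok, str) else tok
    assert len(tok) == int(line[line.rindex(' '):])     # trailing byte length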