Hazzzardous commited on
Commit
5e303ea
·
1 Parent(s): b549737

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -73
app.py CHANGED
@@ -2,7 +2,19 @@
2
  RWKV RNN Model - Gradio Space for HuggingFace
3
  YT - Mean Gene Hacks - https://www.youtube.com/@MeanGeneHacks
4
  (C) Gene Ruebsamen - 2/7/2023
5
- License: GPL3
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
 
8
  import gradio as gr
@@ -10,12 +22,18 @@ import codecs
10
  from ast import literal_eval
11
  from datetime import datetime
12
  from rwkvstic.load import RWKV
13
- from rwkvstic.agnostic.backends import TORCH, TORCH_QUANT, TORCH_STREAM
14
  import torch
15
  import gc
16
 
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
 
 
 
 
 
 
 
19
  def to_md(text):
20
  return text.replace("\n", "<br />")
21
 
@@ -23,22 +41,21 @@ def to_md(text):
23
  def get_model():
24
  model = None
25
  model = RWKV(
26
- "https://huggingface.co/BlinkDL/rwkv-4-pile-7b/resolve/main/RWKV-4-Pile-7B-Instruct-test1-20230124.pth",
27
- "pytorch(cpu/gpu)",
28
- runtimedtype=torch.bfloat16,
29
- useGPU=torch.cuda.is_available(),
30
- dtype=torch.bfloat16
31
  )
32
  return model
33
 
 
34
  model = None
35
 
 
36
  def infer(
37
  prompt,
38
- mode = "generative",
39
  max_new_tokens=10,
40
  temperature=0.1,
41
  top_p=1.0,
 
42
  stop="<|endoftext|>",
43
  seed=42,
44
  ):
@@ -49,27 +66,27 @@ def infer(
49
  if (DEVICE == "cuda"):
50
  torch.cuda.empty_cache()
51
  model = get_model()
52
-
53
  max_new_tokens = int(max_new_tokens)
54
  temperature = float(temperature)
 
55
  top_p = float(top_p)
56
- stop = [x.strip(' ') for x in stop.split(',')]
57
  seed = seed
58
 
59
  assert 1 <= max_new_tokens <= 384
60
  assert 0.0 <= temperature <= 1.0
61
  assert 0.0 <= top_p <= 1.0
62
 
63
- if temperature == 0.0:
64
- temperature = 0.01
65
  if prompt == "":
66
  prompt = " "
67
 
68
  # Clear model state for generative mode
69
  model.resetState()
70
  if (mode == "Q/A"):
71
- prompt = f"Q:\n{prompt}\n\nExpert Long Detailed Response:"
72
-
73
  print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
74
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
75
  # Load prompt
@@ -78,11 +95,12 @@ def infer(
78
  done = False
79
  with torch.no_grad():
80
  for _ in range(max_new_tokens):
81
- char = model.forward(stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
 
82
  print(char, end='', flush=True)
83
  generated_text += char
84
  generated_text = generated_text.lstrip("\n ")
85
-
86
  for stop_word in stop:
87
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
88
  if stop_word != '' and stop_word in generated_text:
@@ -93,13 +111,13 @@ def infer(
93
  print("<stopped>\n")
94
  break
95
 
96
- #print(f"{generated_text}")
97
-
98
  for stop_word in stop:
99
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
100
  if stop_word != '' and stop_word in generated_text:
101
  generated_text = generated_text[:generated_text.find(stop_word)]
102
-
103
  gc.collect()
104
  yield generated_text
105
 
@@ -107,137 +125,161 @@ def infer(
107
  def chat(
108
  prompt,
109
  history,
 
110
  max_new_tokens=10,
111
  temperature=0.1,
112
  top_p=1.0,
113
- stop="<|endoftext|>",
114
  seed=42,
115
  ):
116
  global model
117
  history = history or []
118
-
 
 
119
  if model == None:
120
  gc.collect()
121
  if (DEVICE == "cuda"):
122
  torch.cuda.empty_cache()
123
  model = get_model()
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  if len(history) == 0:
126
  # no history, so lets reset chat state
127
  model.resetState()
128
- history = [[],model.emptyState]
 
129
  else:
130
- model.setState(history[1])
 
 
 
 
 
 
 
131
  max_new_tokens = int(max_new_tokens)
132
  temperature = float(temperature)
133
  top_p = float(top_p)
134
- stop = [x.strip(' ') for x in stop.split(',')]
135
  seed = seed
136
 
137
  assert 1 <= max_new_tokens <= 384
138
  assert 0.0 <= temperature <= 1.0
139
  assert 0.0 <= top_p <= 1.0
140
 
141
- if temperature == 0.0:
142
- temperature = 0.01
143
- if prompt == "":
144
- prompt = " "
145
-
146
  print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
147
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
148
  # Load prompt
149
- model.loadContext(newctx="\nUser: "+prompt+"\n\nBot: ")
150
- generated_text = ""
151
- done = False
152
- gen = model.forward(number=max_new_tokens, stopStrings=stop+["User"],temp=temperature,top_p_usual=top_p)
153
- generated_text = gen["output"]
154
- generated_text = generated_text.lstrip("\n ")
 
 
155
  print(f"{generated_text}")
156
-
157
- for stop_word in stop:
158
- stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
159
- if stop_word != '' and stop_word in generated_text:
160
- generated_text = generated_text[:generated_text.find(stop_word)]
161
-
162
  gc.collect()
163
  history[0].append((prompt, generated_text))
164
- return history[0],[history[0],gen["state"]]
165
 
166
 
167
  examples = [
168
  [
169
  # Question Answering
170
- '''What is the capital of Germany?''',"Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
171
  [
172
  # Question Answering
173
- '''Are humans good or bad?''',"Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
 
 
 
174
  [
175
  # Chatbot
176
  '''This is a conversation between two AI large language models named Alex and Fritz. They are exploring each other's capabilities, and trying to ask interesting questions of one another to explore the limits of each others AI.
 
177
  Conversation:
178
  Alex: Good morning, Fritz, what type of LLM are you based upon?
179
  Fritz: Morning Alex, I am an RNN with transformer level performance. My language model is 100% attention free.
180
  Alex:''', "generative", 220, 0.9, 0.9, "\\n\\n,<|endoftext|>"],
181
  [
182
  # Generate List
183
- '''Q. Give me list of fiction books.
184
- 1. Harry Potter
185
- 2. Lord of the Rings
186
- 3. Game of Thrones
187
- Q. Give me a list of vegetables.
188
- 1. Broccoli
189
- 2. Celery
190
- 3. Tomatoes
191
- Q. Give me a list of car manufacturers.''', "generative", 80, 0.2, 1.0, "\\n\\n,<|endoftext|>"],
192
  [
193
  # Natural Language Interface
194
- '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest horror novel, set to be completed in the spring of 2024. Create a summary of this work.
195
- Summary:''',"generative", 200, 0.85, 0.8, "<|endoftext|>"]
196
  ]
197
 
198
 
199
  iface = gr.Interface(
200
  fn=infer,
201
- description='''<p>RNN With Transformer-level LLM Performance. (<a href='https://github.com/BlinkDL/RWKV-LM'>github</a>)
202
- According to the author: "It combines the best of RNN and transformers - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding"
203
- <p>Thanks to https://huggingface.co/spaces/yahma/rwkv for the basis for this space</p>''',
204
  allow_flagging="never",
205
  inputs=[
206
  gr.Textbox(lines=20, label="Prompt"), # prompt
207
- gr.Radio(["generative","Q/A"], value="generative", label="Choose Mode"),
 
208
  gr.Slider(1, 256, value=40), # max_tokens
209
  gr.Slider(0.0, 1.0, value=0.8), # temperature
210
  gr.Slider(0.0, 1.0, value=0.85), # top_p
211
- gr.Textbox(lines=1, value="<|endoftext|>") # stop
 
212
  ],
213
- outputs=gr.Textbox(lines=25),
214
  examples=examples,
215
  cache_examples=False,
216
  ).queue()
217
 
218
  chatiface = gr.Interface(
219
  fn=chat,
220
- description='''<p>RNN With Transformer-level LLM Performance. (<a href='https://github.com/BlinkDL/RWKV-LM'>github</a>)
221
- According to the author: "It combines the best of RNN and transformers - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding"
222
- <p>Thanks to https://huggingface.co/spaces/yahma/rwkv for the basis for this space</p>''',
223
  allow_flagging="never",
224
  inputs=[
225
  gr.Textbox(lines=5, label="Message"), # prompt
226
  "state",
 
 
227
  gr.Slider(1, 256, value=60), # max_tokens
228
  gr.Slider(0.0, 1.0, value=0.8), # temperature
229
- gr.Slider(0.0, 1.0, value=0.85), # top_p
230
- gr.Textbox(lines=1, value="<|endoftext|>") # stop
231
  ],
232
- outputs=[gr.Chatbot(color_map=("green", "pink")),"state"],
 
233
  ).queue()
234
 
235
  demo = gr.TabbedInterface(
236
 
237
- [iface,chatiface],["Generative","Chatbot"],
238
- title="RWKV-4 (7b Instruct)",
239
-
240
- )
241
 
242
  demo.queue()
243
- demo.launch(share=False)
 
2
  RWKV RNN Model - Gradio Space for HuggingFace
3
  YT - Mean Gene Hacks - https://www.youtube.com/@MeanGeneHacks
4
  (C) Gene Ruebsamen - 2/7/2023
5
+
6
+ This program is free software: you can redistribute it and/or modify
7
+ it under the terms of the GNU General Public License as published by
8
+ the Free Software Foundation, either version 3 of the License, or
9
+ (at your option) any later version.
10
+
11
+ This program is distributed in the hope that it will be useful,
12
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ GNU General Public License for more details.
15
+
16
+ You should have received a copy of the GNU General Public License
17
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
18
  """
19
 
20
  import gradio as gr
 
22
  from ast import literal_eval
23
  from datetime import datetime
24
  from rwkvstic.load import RWKV
25
+ from config import config, title
26
  import torch
27
  import gc
28
 
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
+ desc = '''<p>RNN with Transformer-level LLM Performance (<a href='https://github.com/BlinkDL/RWKV-LM'>github</a>).
32
+ According to the author: "It combines the best of RNN and transformers - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding."'''
33
+
34
+ thanks = '''<p>Thanks to <a href='https://www.rftcapital.com'>RFT Capital</a> for donating compute capability for our experiments. Additional thanks to the author of the <a href="https://github.com/harrisonvanderbyl/rwkvstic">rwkvstic</a> library.</p>'''
35
+
36
+
37
  def to_md(text):
38
  return text.replace("\n", "<br />")
39
 
 
41
  def get_model():
42
  model = None
43
  model = RWKV(
44
+ **config
 
 
 
 
45
  )
46
  return model
47
 
48
+
49
  model = None
50
 
51
+
52
  def infer(
53
  prompt,
54
+ mode="generative",
55
  max_new_tokens=10,
56
  temperature=0.1,
57
  top_p=1.0,
58
+ end_adj=0.0,
59
  stop="<|endoftext|>",
60
  seed=42,
61
  ):
 
66
  if (DEVICE == "cuda"):
67
  torch.cuda.empty_cache()
68
  model = get_model()
69
+
70
  max_new_tokens = int(max_new_tokens)
71
  temperature = float(temperature)
72
+ end_adj = float(end_adj)
73
  top_p = float(top_p)
74
+ stop = [x.strip(' ') for x in stop.split(',')]
75
  seed = seed
76
 
77
  assert 1 <= max_new_tokens <= 384
78
  assert 0.0 <= temperature <= 1.0
79
  assert 0.0 <= top_p <= 1.0
80
 
81
+ temperature = max(0.05, temperature)
 
82
  if prompt == "":
83
  prompt = " "
84
 
85
  # Clear model state for generative mode
86
  model.resetState()
87
  if (mode == "Q/A"):
88
+ prompt = f"Ask Expert\n\nQuestion:\n{prompt}\n\nExpert Full Answer:\n"
89
+
90
  print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
91
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
92
  # Load prompt
 
95
  done = False
96
  with torch.no_grad():
97
  for _ in range(max_new_tokens):
98
+ char = model.forward(stopStrings=stop, temp=temperature, top_p_usual=top_p)[
99
+ "output"]
100
  print(char, end='', flush=True)
101
  generated_text += char
102
  generated_text = generated_text.lstrip("\n ")
103
+
104
  for stop_word in stop:
105
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
106
  if stop_word != '' and stop_word in generated_text:
 
111
  print("<stopped>\n")
112
  break
113
 
114
+ # print(f"{generated_text}")
115
+
116
  for stop_word in stop:
117
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
118
  if stop_word != '' and stop_word in generated_text:
119
  generated_text = generated_text[:generated_text.find(stop_word)]
120
+
121
  gc.collect()
122
  yield generated_text
123
 
 
125
  def chat(
126
  prompt,
127
  history,
128
+ username,
129
  max_new_tokens=10,
130
  temperature=0.1,
131
  top_p=1.0,
 
132
  seed=42,
133
  ):
134
  global model
135
  history = history or []
136
+
137
+ intro = ""
138
+
139
  if model == None:
140
  gc.collect()
141
  if (DEVICE == "cuda"):
142
  torch.cuda.empty_cache()
143
  model = get_model()
144
+
145
+ username = username.strip()
146
+ username = username or "USER"
147
+
148
+ intro = f'''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
149
+
150
+ {username}: What year was the french revolution?
151
+ FRITZ: The French Revolution started in 1789, and lasted 10 years until 1799.
152
+ {username}: 3+5=?
153
+ FRITZ: The answer is 8.
154
+ {username}: What year did the Berlin Wall fall?
155
+ FRITZ: The Berlin wall stood for 28 years and fell in 1989.
156
+ {username}: solve for a: 9-a=2
157
+ FRITZ: The answer is a=7, because 9-7 = 2.
158
+ {username}: wat is lhc
159
+ FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
160
+ {username}: Tell me about yourself.
161
+ FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
162
+ '''
163
+
164
  if len(history) == 0:
165
  # no history, so lets reset chat state
166
  model.resetState()
167
+ history = [[], model.emptyState]
168
+ print("reset chat state")
169
  else:
170
+ if (history[0][0][0].split(':')[0] != username):
171
+ model.resetState()
172
+ history = [[], model.emptyState]
173
+ print("username changed, reset state")
174
+ else:
175
+ model.setState(history[1])
176
+ intro = ""
177
+
178
  max_new_tokens = int(max_new_tokens)
179
  temperature = float(temperature)
180
  top_p = float(top_p)
 
181
  seed = seed
182
 
183
  assert 1 <= max_new_tokens <= 384
184
  assert 0.0 <= temperature <= 1.0
185
  assert 0.0 <= top_p <= 1.0
186
 
187
+ temperature = max(0.05, temperature)
188
+
189
+ prompt = f"{username}: " + prompt + "\n"
 
 
190
  print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
191
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
192
  # Load prompt
193
+
194
+ model.loadContext(newctx=intro+prompt)
195
+
196
+ out = model.forward(number=max_new_tokens, stopStrings=[
197
+ "<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
198
+
199
+ generated_text = out["output"].lstrip("\n ")
200
+ generated_text = generated_text.rstrip("USER:")
201
  print(f"{generated_text}")
202
+
 
 
 
 
 
203
  gc.collect()
204
  history[0].append((prompt, generated_text))
205
+ return history[0], [history[0], out["state"]]
206
 
207
 
208
  examples = [
209
  [
210
  # Question Answering
211
+ '''What is the capital of Germany?''', "Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
212
  [
213
  # Question Answering
214
+ '''Are humans good or bad?''', "Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
215
+ [
216
+ # Question Answering
217
+ '''What is the purpose of Vitamin A?''', "Q/A", 50, 0.2, 0.8, "<|endoftext|>"],
218
  [
219
  # Chatbot
220
  '''This is a conversation between two AI large language models named Alex and Fritz. They are exploring each other's capabilities, and trying to ask interesting questions of one another to explore the limits of each others AI.
221
+
222
  Conversation:
223
  Alex: Good morning, Fritz, what type of LLM are you based upon?
224
  Fritz: Morning Alex, I am an RNN with transformer level performance. My language model is 100% attention free.
225
  Alex:''', "generative", 220, 0.9, 0.9, "\\n\\n,<|endoftext|>"],
226
  [
227
  # Generate List
228
+ '''Task given:
229
+
230
+ Please Write a Short story about a cat learning python
231
+
232
+ Best Full Response:
233
+ ''', "generative", 140, 0.85, 0.8, "<|endoftext|>"],
 
 
 
234
  [
235
  # Natural Language Interface
236
+ '''Here is a short story (in the style of Tolkien) in which Aiden attacks a robot with a sword:
237
+ ''', "generative", 140, 0.85, 0.8, "<|endoftext|>"]
238
  ]
239
 
240
 
241
  iface = gr.Interface(
242
  fn=infer,
243
+ description=f'''<h3>Generative and Question/Answer</h3>{desc}{thanks}''',
 
 
244
  allow_flagging="never",
245
  inputs=[
246
  gr.Textbox(lines=20, label="Prompt"), # prompt
247
+ gr.Radio(["generative", "Q/A"],
248
+ value="generative", label="Choose Mode"),
249
  gr.Slider(1, 256, value=40), # max_tokens
250
  gr.Slider(0.0, 1.0, value=0.8), # temperature
251
  gr.Slider(0.0, 1.0, value=0.85), # top_p
252
+ gr.Slider(-999, 0.0, value=0.0) # end_adj
253
+ gr.Textbox(lines=1, value="<|endoftext|>") # stop
254
  ],
255
+ outputs=gr.Textbox(label="Generated Output", lines=25),
256
  examples=examples,
257
  cache_examples=False,
258
  ).queue()
259
 
260
  chatiface = gr.Interface(
261
  fn=chat,
262
+ description=f'''<h3>Chatbot</h3><h4>Refresh page or change name to reset memory context</h4>{desc}{thanks}''',
 
 
263
  allow_flagging="never",
264
  inputs=[
265
  gr.Textbox(lines=5, label="Message"), # prompt
266
  "state",
267
+ gr.Text(lines=1, value="USER", label="Your Name",
268
+ placeholder="Enter your Name"),
269
  gr.Slider(1, 256, value=60), # max_tokens
270
  gr.Slider(0.0, 1.0, value=0.8), # temperature
271
+ gr.Slider(0.0, 1.0, value=0.85) # top_p
 
272
  ],
273
+ outputs=[gr.Chatbot(label="Chat Log", color_map=(
274
+ "green", "pink")), "state"],
275
  ).queue()
276
 
277
  demo = gr.TabbedInterface(
278
 
279
+ [iface, chatiface], ["Generative", "Chatbot"],
280
+ title=title,
281
+
282
+ )
283
 
284
  demo.queue()
285
+ demo.launch(share=False)