gururise committed
Commit d5f9f96 · Parent: ffde006

Add application file

Files changed (2)
  1. app.py  +229 -0
  2. requirements.txt  +5 -0
app.py ADDED
@@ -0,0 +1,229 @@
import gradio as gr
import codecs
from ast import literal_eval
from datetime import datetime
from rwkvstic.load import RWKV
from rwkvstic.agnostic.backends import TORCH, TORCH_QUANT, TORCH_STREAM
import torch
import gc

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def to_md(text):
    return text.replace("\n", "<br />")


def get_model():
    # Download and load the 1.5B-parameter RWKV-4 instruct checkpoint.
    model = RWKV(
        "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
        "pytorch(cpu/gpu)",
        runtimedtype=torch.float32,
        useGPU=torch.cuda.is_available(),
        dtype=torch.float32
    )
    return model


# Loaded lazily on the first request so the Space starts quickly.
model = None

def infer(
    prompt,
    mode="generative",
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    stop="<|endoftext|>",
    seed=42,
):
    global model

    if model is None:
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        model = get_model()

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    stop = [x.strip(' ') for x in stop.split(',')]
    # seed is accepted for interface compatibility but is not currently used.

    assert 1 <= max_new_tokens <= 384
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    if mode == "generative":
        # Clear model state for generative mode
        model.resetState()
    else:  # Q/A
        model.resetState()
        prompt = f"Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\n{prompt}\n\nFull Answer:"

    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
    print(f"OUTPUT ({datetime.now()}):\n-------\n")
    # Load the prompt into the RNN state
    model.loadContext(newctx=prompt)
    generated_text = ""
    done = False
    with torch.no_grad():
        # Generate one step at a time so partial output can be streamed to the UI.
        for _ in range(max_new_tokens):
            char = model.forward(stopStrings=stop, temp=temperature, top_p_usual=top_p)["output"]
            print(char, end='', flush=True)
            generated_text += char
            generated_text = generated_text.lstrip("\n ")

            # Stop words arrive escaped from the textbox (e.g. "\\n\\n"), so decode them first.
            for stop_word in stop:
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in generated_text:
                    done = True
                    break
            yield generated_text
            if done:
                print("<stopped>\n")
                break

    print(generated_text)

    # Trim the output at the first stop word, if one was generated.
    for stop_word in stop:
        stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
        if stop_word != '' and stop_word in generated_text:
            generated_text = generated_text[:generated_text.find(stop_word)]

    gc.collect()
    yield generated_text


def chat(
    prompt,
    history,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    stop="<|endoftext|>",
    seed=42,
):
    global model
    history = history or []

    if model is None:
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        model = get_model()

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    stop = [x.strip(' ') for x in stop.split(',')]
    # seed is accepted for interface compatibility but is not currently used.

    assert 1 <= max_new_tokens <= 384
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
    print(f"OUTPUT ({datetime.now()}):\n-------\n")
    # Load the prompt; unlike infer(), the model state is not reset,
    # so the conversation accumulates across turns.
    model.loadContext(newctx=prompt)
    # Single-shot generation here (no streaming) since the Chatbot output
    # is returned once per message.
    generated_text = model.forward(number=max_new_tokens, stopStrings=stop, temp=temperature, top_p_usual=top_p)["output"]

    generated_text = generated_text.lstrip("\n ")
    print(generated_text)

    # Trim the output at the first stop word, if one was generated.
    for stop_word in stop:
        stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
        if stop_word != '' and stop_word in generated_text:
            generated_text = generated_text[:generated_text.find(stop_word)]

    gc.collect()
    history.append((prompt, generated_text))
    return history, history


examples = [
    [
        # Question Answering
        '''What is the capital of Germany?''', "Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
    [
        # Question Answering
        '''Are humans good or bad?''', "Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
    [
        # Chatbot
        '''This is a conversation between two AI large language models named Alex and Fritz. They are exploring each other's capabilities, and trying to ask interesting questions of one another to explore the limits of each other's AI.

Conversation:
Alex: Good morning, Fritz!
Fritz:''', "generative", 200, 0.9, 0.9, "\\n\\n,<|endoftext|>"],
    [
        # Generate List
        '''Q. Give me a list of fiction books.
1. Harry Potter
2. Lord of the Rings
3. Game of Thrones

Q. Give me a list of vegetables.
1. Broccoli
2. Celery
3. Tomatoes

Q. Give me a list of car manufacturers.''', "generative", 80, 0.2, 1.0, "\\n\\n,<|endoftext|>"],
    [
        # Natural Language Interface
        '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''', "generative", 250, 0.85, 0.85, "<|endoftext|>"]
]

iface = gr.Interface(
    fn=infer,
    description='''<p><a href='https://github.com/BlinkDL/RWKV-LM'>RWKV Language Model</a> - RNN With Transformer-level LLM Performance</p>
<p>Big thank you to <a href='https://www.rftcapital.com'>RFT Capital</a> for providing compute capability for our experiments.</p>''',
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=20, label="Prompt"),  # prompt
        gr.Radio(["generative", "Q/A"], value="generative", label="Choose Mode"),
        gr.Slider(1, 384, value=20),  # max_new_tokens
        gr.Slider(0.0, 1.0, value=0.2),  # temperature
        gr.Slider(0.0, 1.0, value=0.9),  # top_p
        gr.Textbox(lines=1, value="<|endoftext|>")  # stop
    ],
    outputs=gr.Textbox(lines=25),
    examples=examples,
)

chatiface = gr.Interface(
    fn=chat,
    description='''<p><a href='https://github.com/BlinkDL/RWKV-LM'>RWKV Language Model</a> - RNN With Transformer-level LLM Performance</p>
<p>Big thank you to <a href='https://www.rftcapital.com'>RFT Capital</a> for providing compute capability for our experiments.</p>''',
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=5, label="Message"),  # prompt
        "state",  # conversation history
        gr.Slider(1, 384, value=20),  # max_new_tokens
        gr.Slider(0.0, 1.0, value=0.2),  # temperature
        gr.Slider(0.0, 1.0, value=0.9),  # top_p
        gr.Textbox(lines=1, value="<|endoftext|>,\\n")  # stop
    ],
    outputs=[gr.Chatbot(color_map=("green", "pink")), "state"],
)

demo = gr.TabbedInterface(
    [iface, chatiface], ["Generative", "Chatbot"],
    title="RWKV-4 (1.5b Instruct)",
).queue()

demo.launch(share=False)
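
As a usage sketch (not part of the file above): infer is a generator that yields the full text generated so far on each step, so the last yielded value is the final completion. Calling it directly, outside the Gradio UI, would look roughly like this, reusing the first example's argument values:

    final = ""
    for partial in infer("What is the capital of Germany?", mode="Q/A",
                         max_new_tokens=25, temperature=0.2, top_p=1.0,
                         stop="<|endoftext|>"):
        final = partial  # each yield is the full generated text so far
    print(final)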
requirements.txt ADDED
@@ -0,0 +1,5 @@
transformers
scipy
torch
inquirer
rwkvstic
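
Installing these with pip install -r requirements.txt and then running python app.py should start the demo on a local URL; demo.launch(share=False) serves it locally without creating a public share link.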