File size: 7,404 Bytes
d5f9f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73fc793
10895fa
d5f9f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a5f13b
d5f9f96
 
 
 
 
 
3a5f13b
d5f9f96
3a5f13b
d5f9f96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import gradio as gr
import codecs
from ast import literal_eval
from datetime import datetime
from rwkvstic.load import RWKV
from rwkvstic.agnostic.backends import TORCH, TORCH_QUANT, TORCH_STREAM
import torch
import gc

# Device label used when deciding whether CUDA caches need clearing:
# "cuda" when PyTorch reports an available GPU, otherwise "cpu".
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

def to_md(text):
    """Return *text* with every newline replaced by an HTML ``<br />`` tag."""
    lines = text.split("\n")
    return "<br />".join(lines)


def get_model():
    """Build and return a fresh RWKV model instance.

    The checkpoint is the RWKV-4 Pile 1.5B "Instruct-test1" weights, referenced
    by its Hugging Face URL. Both the storage dtype and the runtime dtype are
    float32, and the GPU is used whenever PyTorch reports one as available.
    """
    weights_url = "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth"
    backend_mode = "pytorch(cpu/gpu)"
    return RWKV(
        weights_url,
        backend_mode,
        runtimedtype=torch.float32,
        useGPU=torch.cuda.is_available(),
        dtype=torch.float32,
    )


# Process-wide model slot; stays None until the first request loads the model.
model = None

# Largest number of new tokens a single request may ask for.
_MAX_NEW_TOKENS_LIMIT = 384


def _ensure_model():
    """Return the shared model, loading it into the module-level ``model`` on first use."""
    global model
    if model is None:
        # Release as much memory as possible before the large allocation.
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        model = get_model()
    return model


def _coerce_sampling_params(max_new_tokens, temperature, top_p):
    """Convert the raw UI values to numbers and validate their ranges.

    Raises:
        ValueError: if a value is outside its allowed range. Explicit raises
            are used instead of ``assert`` so the checks survive ``python -O``.
    """
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)

    if not 1 <= max_new_tokens <= _MAX_NEW_TOKENS_LIMIT:
        raise ValueError(
            f"max_new_tokens must be between 1 and {_MAX_NEW_TOKENS_LIMIT}, got {max_new_tokens}"
        )
    if not 0.0 <= temperature <= 1.0:
        raise ValueError(f"temperature must be between 0.0 and 1.0, got {temperature}")
    if not 0.0 <= top_p <= 1.0:
        raise ValueError(f"top_p must be between 0.0 and 1.0, got {top_p}")

    # A temperature of exactly zero is replaced by a small positive value.
    if temperature == 0.0:
        temperature = 0.01
    return max_new_tokens, temperature, top_p


def _unescape_stop_word(word):
    """Interpret backslash escapes (``\\n``, ``\\t``, ...) typed into the stop box.

    Going through Latin-1 with ``backslashreplace`` keeps non-ASCII characters
    intact; handing a ``str`` straight to the ``unicode_escape`` decoder would
    turn them into mojibake so they could never match generated text. A
    malformed escape (e.g. a trailing backslash) falls back to the literal
    word instead of aborting the request.
    """
    try:
        return word.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except UnicodeDecodeError:
        return word


def _parse_stop_words(stop):
    """Split the comma-separated *stop* string.

    Returns:
        ``(raw, decoded)``: the stripped words exactly as typed (these are what
        is handed to the model) and the same words with escapes interpreted
        (these are matched against the generated text).
    """
    raw = [part.strip(' ') for part in (stop or "").split(',')]
    decoded = [_unescape_stop_word(part) for part in raw]
    return raw, decoded


def _truncate_at_stop_words(text, stop_words):
    """Cut *text* at the first occurrence of each non-empty stop word, in order."""
    for word in stop_words:
        if not word:
            continue
        cut = text.find(word)
        if cut != -1:
            text = text[:cut]
    return text


def infer(
        prompt,
        mode = "generative",
        max_new_tokens=10,
        temperature=0.1,
        top_p=1.0,
        stop="<|endoftext|>",
        seed=42,
):
    """Stream a completion for *prompt*, yielding the text generated so far.

    Args:
        prompt: Input text. An empty prompt is replaced by a single space.
        mode: ``"generative"`` feeds the prompt unchanged; any other value
            wraps it in the question/answer template.
        max_new_tokens: Number of forward steps to run (1-384).
        temperature: Sampling temperature in [0, 1]; 0 is bumped to 0.01.
        top_p: Nucleus sampling threshold in [0, 1].
        stop: Comma-separated stop words; backslash escapes such as ``\\n``
            are interpreted before matching.
        seed: Accepted for interface compatibility; currently unused.

    Yields:
        The accumulated output after every step (leading newlines/spaces
        stripped), followed by one final value truncated at the stop words.

    Raises:
        ValueError: if a numeric argument is out of range (raised when the
            generator is first advanced).
    """
    llm = _ensure_model()
    max_new_tokens, temperature, top_p = _coerce_sampling_params(
        max_new_tokens, temperature, top_p
    )
    raw_stops, stop_words = _parse_stop_words(stop)

    if prompt == "":
        prompt = " "

    # Every request starts from a clean model state, whatever the mode.
    llm.resetState()
    if mode != "generative":  # Q/A
        prompt = f"Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\n{prompt}\n\nFull Answer:"

    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
    print(f"OUTPUT ({datetime.now()}):\n-------\n")
    # Load prompt
    llm.loadContext(newctx=prompt)
    generated_text = ""
    for _ in range(max_new_tokens):
        # Keep no_grad scoped to the model call only: holding the context open
        # across ``yield`` would leave gradients disabled for whoever consumes
        # this generator while it is suspended.
        with torch.no_grad():
            piece = llm.forward(stopStrings=raw_stops, temp=temperature, top_p_usual=top_p)["output"]
        print(piece, end='', flush=True)
        generated_text = (generated_text + piece).lstrip("\n ")

        yield generated_text
        if any(word and word in generated_text for word in stop_words):
            print("<stopped>\n")
            break

    print(f"{generated_text}")

    generated_text = _truncate_at_stop_words(generated_text, stop_words)

    gc.collect()
    yield generated_text


def chat(
        prompt,
        history,
        max_new_tokens=10,
        temperature=0.1,
        top_p=1.0,
        stop="<|endoftext|>",
        seed=42,
):
    """Generate one chat reply and append it to the conversation history.

    Unlike :func:`infer`, the model state is not reset here, and the whole
    reply is produced by a single ``forward`` call.

    Args:
        prompt: The user's message. An empty message is replaced by a space.
        history: List of ``(prompt, reply)`` tuples from earlier turns, or
            ``None``/empty for a new conversation. It is not modified.
        max_new_tokens: Passed to the model as ``number`` (1-384).
        temperature: Sampling temperature in [0, 1]; 0 is bumped to 0.01.
        top_p: Nucleus sampling threshold in [0, 1].
        stop: Comma-separated stop words; backslash escapes are interpreted.
        seed: Accepted for interface compatibility; currently unused.

    Returns:
        ``(history, history)``: the updated history twice, once for the
        chatbot display and once for the session state.

    Raises:
        ValueError: if a numeric argument is out of range.
    """
    # Work on a copy so the caller's list is never mutated.
    history = list(history) if history else []

    llm = _ensure_model()
    max_new_tokens, temperature, top_p = _coerce_sampling_params(
        max_new_tokens, temperature, top_p
    )
    raw_stops, stop_words = _parse_stop_words(stop)

    if prompt == "":
        prompt = " "

    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
    print(f"OUTPUT ({datetime.now()}):\n-------\n")
    # Load prompt
    llm.loadContext(newctx=prompt)
    with torch.no_grad():
        generated_text = llm.forward(number=max_new_tokens, stopStrings=raw_stops, temp=temperature, top_p_usual=top_p)["output"]

    generated_text = generated_text.lstrip("\n ")
    print(f"{generated_text}")

    generated_text = _truncate_at_stop_words(generated_text, stop_words)

    gc.collect()
    history.append((prompt, generated_text))
    return history, history


# Canned inputs offered in the "Generative" tab. Each row follows the order of
# that tab's input widgets:
# [prompt, mode, max_new_tokens, temperature, top_p, stop].
examples = [
    # Question answering: short factual reply.
    ["What is the capital of Germany?", "Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
    # Question answering: open-ended reply with more sampling freedom.
    ["Are humans good or bad?", "Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
    # Dialogue continuation; generation stops at the first blank line.
    [
        "This is a conversation two AI large language models named Alex and Fritz."
        " They are exploring each other's capabilities, and trying to ask interesting"
        " questions of one another to explore the limits of each others AI.\n"
        "\n"
        "Conversation:\n"
        "Alex: Good morning, Fritz!\n"
        "Fritz:",
        "generative", 200, 0.9, 0.9, "\\n\\n,<|endoftext|>",
    ],
    # Few-shot list generation; generation stops at the first blank line.
    [
        "Q. Give me list of fiction books. \n"
        "1. Harry Potter\n"
        "2. Lord of the Rings\n"
        "3. Game of Thrones\n"
        "\n"
        "Q. Give me a list of vegetables.\n"
        "1. Broccoli\n"
        "2. Celery\n"
        "3. Tomatoes\n"
        "\n"
        "Q. Give me a list of car manufacturers.",
        "generative", 80, 0.2, 1.0, "\\n\\n,<|endoftext|>",
    ],
    # Instruction-style prompt asking for a chapter outline.
    [
        "You are the writing assistant for Stephen King."
        " You have worked in the fiction/horror genre for 30 years."
        " You are a Pulitzer Prize-winning author, and now you are tasked with"
        " developing a skeletal outline for his newest novel, set to be completed"
        " in the spring of 2024. Create a title and brief description for the"
        " first 5 chapters of this work.\n"
        "\n"
        "Title:",
        "generative", 250, 0.85, 0.85, "<|endoftext|>",
    ],
]


# "Generative" tab: free-form / Q&A completion backed by the streaming `infer`
# generator. The input widgets are listed in the order of infer's parameters.
_iface_description = (
    "<p><a href='https://github.com/BlinkDL/RWKV-LM'>RWKV Language Model</a>"
    " - RNN With Transformer-level LLM Performance</p>\n"
    "    <p>Big thank you to <a href='https://www.rftcapital.com'>RFT Capital</a>"
    " for providing compute capability for our experiments.</p>"
)
_iface_inputs = [
    gr.Textbox(lines=20, label="Prompt"),  # prompt
    gr.Radio(["generative", "Q/A"], value="generative", label="Choose Mode"),  # mode
    gr.Slider(1, 384, value=20),  # max_tokens
    gr.Slider(0.0, 1.0, value=0.2),  # temperature
    gr.Slider(0.0, 1.0, value=0.9),  # top_p
    gr.Textbox(lines=1, value="<|endoftext|>"),  # stop
]
_iface_output = gr.Textbox(lines=25)

iface = gr.Interface(
    fn=infer,
    inputs=_iface_inputs,
    outputs=_iface_output,
    description=_iface_description,
    examples=examples,
    cache_examples=False,
    allow_flagging="never",
).queue()

# "Chatbot" tab: one reply per submission via `chat`. The "state" entries carry
# the conversation history between calls; by default a newline also acts as a
# stop word here.
_chat_description = (
    "<p><a href='https://github.com/BlinkDL/RWKV-LM'>RWKV Language Model</a>"
    " - RNN With Transformer-level LLM Performance</p>\n"
    "    <p>Big thank you to <a href='https://www.rftcapital.com'>RFT Capital</a>"
    " for providing compute capability for our experiments.</p>"
)
_chat_inputs = [
    gr.Textbox(lines=5, label="Message"),  # prompt
    "state",  # history
    gr.Slider(1, 384, value=20),  # max_tokens
    gr.Slider(0.0, 1.0, value=0.2),  # temperature
    gr.Slider(0.0, 1.0, value=0.9),  # top_p
    gr.Textbox(lines=1, value="<|endoftext|>,\\n"),  # stop
]
_chat_outputs = [gr.Chatbot(color_map=("green", "pink")), "state"]

chatiface = gr.Interface(
    fn=chat,
    inputs=_chat_inputs,
    outputs=_chat_outputs,
    description=_chat_description,
    allow_flagging="never",
).queue()

# Combine both interfaces into a single tabbed app, enable queuing and start
# the server without a public share link.
_tabs = [iface, chatiface]
_tab_titles = ["Generative", "Chatbot"]
demo = gr.TabbedInterface(_tabs, _tab_titles, title="RWKV-4 (1.5b Instruct)")

demo.queue()
demo.launch(share=False)