File size: 899 Bytes
561ca81
ab85003
561ca81
2432281
ab85003
b74218d
2432281
b74218d
561ca81
 
 
 
 
 
 
 
d903d0d
f27b8ce
561ca81
 
 
 
c2a02de
 
561ca81
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from rwkvstic.load import RWKV
import torch
# Build the shared RWKV model once at import time. The checkpoint is fetched
# from the Hugging Face URL below; both `dtype` and `runtimedtype` are float32,
# and the GPU is requested only when torch reports CUDA as available.
# NOTE(review): "pytorch(cpu/gpu)" presumably selects rwkvstic's PyTorch
# backend -- confirm against the rwkvstic documentation.
_CHECKPOINT_URL = "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth"
_BACKEND_MODE = "pytorch(cpu/gpu)"
model = RWKV(
    _CHECKPOINT_URL,
    _BACKEND_MODE,
    dtype=torch.float32,
    runtimedtype=torch.float32,
    useGPU=torch.cuda.is_available(),
)
import gradio as gr


def predict(input, history=None):
    """Produce the bot's reply to a single user message.

    Args:
        input: Text submitted by the user.
        history: State object from the previous turn; it is passed to
            ``model.setState`` before the new message is fed in.

    Returns:
        A ``(pairs, state)`` tuple. ``pairs`` is a one-element list holding
        the ``(input, reply)`` pair, and ``state`` is the ``"state"`` entry of
        the generation result.
    """
    # Put the model back into the caller-supplied state first.
    model.setState(history)

    # Append the user's text followed by the "Bot: " cue to the context.
    prompt = f"{input}\n\nBot: "
    model.loadContext(newctx=prompt)

    # NOTE(review): number=100 presumably caps the generated tokens, and
    # generation presumably halts once "User: " appears -- confirm in rwkvstic.
    generation = model.forward(number=100, stopStrings=["User: "])
    reply = generation["output"]
    return [(input, reply)], generation["state"]

# Assemble the chat UI: a chatbot pane, a per-session state holder and a
# single-line textbox whose submit event drives `predict`.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    # Feed the opening "User: " prompt to the model. Only the resulting state
    # is kept; the returned context is not needed.
    _, initial_state = model.loadContext(newctx="User: ")
    # Wrap the primed state so it can be passed to and returned from `predict`.
    state = gr.State(initial_state)
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)

    # On enter, call predict(txt, state); its two return values are routed to
    # the chatbot and back into the state holder.
    txt.submit(predict, [txt, state], [chatbot, state])

# Start the Gradio server for the assembled interface.
demo.launch()