File size: 6,992 Bytes
7464087
 
 
 
 
290c9c9
 
 
7464087
cf5b288
7464087
 
 
 
 
 
 
 
0692d85
cf5b288
 
23694dc
7464087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e997b3
 
 
 
 
 
 
7464087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b62f67
 
7464087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b62f67
7464087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b62f67
7464087
 
 
 
3e997b3
7464087
 
7b62f67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7464087
7b62f67
 
 
 
7464087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gradio as gr
import argparse
import os, gc, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
# from pynvml import *
# nvmlInit()
# gpu_h = nvmlDeviceGetHandleByIndex(0)
ctx_limit = 4096
desc = f'''链接:<a href='https://colab.research.google.com/drive/1J1gLMMMA8GbD9JuQt6OKmwCTl9mWU0bb?usp=sharing'>太慢了?用Colab自己部署吧</a> <br /> <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a><a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a><a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a><a href="https://zhuanlan.zhihu.com/p/618011122" target="_blank" style="margin:0 0.5em">知乎教程</a>
'''

parser = argparse.ArgumentParser(prog = 'ChatGal RWKV')
parser.add_argument('--share',action='store_true') 
args = parser.parse_args()
os.environ["RWKV_JIT_ON"] = '1'

from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename="rwkv-chatgal-v1-3B-ctx4096-epoch2.pth")
if 'ON_COLAB' in os.environ and os.environ['ON_COLAB'] == '1':
    os.environ["RWKV_JIT_ON"] = '0'
    os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
    model = RWKV(model=model_path, strategy='cuda bf16')
else:
    model = RWKV(model=model_path, strategy='cpu bf16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "20B_tokenizer.json")

def infer(
        ctx,
        token_count=10,
        temperature=0.7,
        top_p=1.0,
        presencePenalty = 0.05,
        countPenalty = 0.05,
):
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                     alpha_frequency = countPenalty,
                     alpha_presence = presencePenalty,
                     token_ban = [0], # ban the generation of some tokens
                     token_stop = []) # stop generation whenever you see any token here

    # ctx = ctx.strip().split('\n')
    # for c in range(len(ctx)):
    #     ctx[c] = ctx[c].strip().strip('\u3000').strip('\r')
    # ctx = list(filter(lambda c: c != '', ctx))
    # ctx = '\n' + ('\n'.join(ctx)).strip()
    # if ctx == '':
    #     ctx = '\n'

    # gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    # print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}',flush=True)
    
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
    for i in range(int(token_count)):
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        for n in args.token_ban:
            out[n] = -float('inf')
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1
        
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str
            out_last = i + 1
    gc.collect()
    torch.cuda.empty_cache()
    yield out_str

examples = [
    ["""女招待: 欢迎光临。您远道而来,想必一定很累了吧?

深见: 不会……空气也清爽,也让我焕然一新呢

女招待: 是吗。那真是太好了

我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。""", 200, 2.0, 0.4, 0.1, 0.1],
    ["翡翠: 欢迎回来,志贵少爷。", 200, 2.0, 0.4, 0.1, 0.1],
    ["""莲华: 你的目的,就是这个万华镜吧?

莲华拿出了万华镜。

深见: 啊……

好像被万华镜拽过去了一般,我的腿不由自主地向它迈去

深见: 是这个……就是这个啊……

烨烨生辉的魔法玩具。
连接现实与梦之世界的、诱惑的桥梁。

深见: 请让我好好看看……

我刚想把手伸过去,莲华就一下子把它收了回去。""", 200, 2.0, 0.4, 0.1, 0.1],
    ["""嘉祥: 偶尔来一次也不错。

我坐到客厅的沙发上,拍了拍自己的大腿。

巧克力&香草: 喵喵?

巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪

巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪

香草: 身为猫娘饲主,这点服务也是应该的对吧?

香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪

我摸摸各自占据住我左右两腿的两颗猫头。

嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 2.0, 0.4, 0.1, 0.1],
]

iface = gr.Interface(
    fn=infer,
    description=f'''这是GalGame剧本续写模型(实验性质,不保证效果)。<b>请点击例子(在页面底部)</b>,可编辑内容。这里只看输入的最后约1200字,请写好,标点规范,无错别字,否则电脑会模仿你的错误。<b>为避免占用资源,每次生成限制长度。可将输出内容复制到输入,然后继续生成</b>。推荐提高temp改善文采,降低topp改善逻辑,提高两个penalty避免重复,具体幅度请自己实验。<br /> {desc}''',
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=10, label="Prompt 输入的前文", value="""嘉祥: 偶尔来一次也不错。

我坐到客厅的沙发上,拍了拍自己的大腿。

巧克力&香草: 喵喵?

巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪

巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪

香草: 身为猫娘饲主,这点服务也是应该的对吧?

香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪

我摸摸各自占据住我左右两腿的两颗猫头。

嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。"""),  # prompt
        gr.Slider(10, 200, step=10, value=200, label="token_count 每次生成的长度"),  # token_count
        gr.Slider(0.2, 2.0, step=0.1, value=2, label="temperature 默认0.7,高则变化丰富,低则保守求稳"),  # temperature
        gr.Slider(0.0, 1.0, step=0.05, value=0.4, label="top_p 默认1.0,高则标新立异,低则循规蹈矩"),  # top_p
        gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="presencePenalty 默认0.0,避免写过的类似字"),  # presencePenalty
        gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="countPenalty 默认0.0,额外避免写过多次的类似字"),  # countPenalty
    ],
    outputs=gr.Textbox(label="Output 输出的续写", lines=28),
    examples=examples,
    cache_examples=False,
).queue()

demo = gr.TabbedInterface(
    [iface], ["Generative"]
)

demo.queue(max_size=5)
if args.share:
    demo.launch(share=True)
else:
    demo.launch(share=False)