import gradio as gr import argparse import os, gc, torch from datetime import datetime from huggingface_hub import hf_hub_download # from pynvml import * # nvmlInit() # gpu_h = nvmlDeviceGetHandleByIndex(0) ctx_limit = 4096 desc = f'''链接:太慢了?用Colab自己部署吧
ChatRWKVRWKV-LMRWKV pip package知乎教程 ''' parser = argparse.ArgumentParser(prog = 'ChatGal RWKV') parser.add_argument('--share',action='store_true') args = parser.parse_args() os.environ["RWKV_JIT_ON"] = '1' from rwkv.model import RWKV model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename="rwkv-chatgal-v1-3B-ctx4096-epoch2.pth") if 'ON_COLAB' in os.environ and os.environ['ON_COLAB'] == '1': os.environ["RWKV_JIT_ON"] = '0' os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster) model = RWKV(model=model_path, strategy='cuda bf16') else: model = RWKV(model=model_path, strategy='cpu bf16') from rwkv.utils import PIPELINE, PIPELINE_ARGS pipeline = PIPELINE(model, "20B_tokenizer.json") def infer( ctx, token_count=10, temperature=0.7, top_p=1.0, presencePenalty = 0.05, countPenalty = 0.05, ): args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p), alpha_frequency = countPenalty, alpha_presence = presencePenalty, token_ban = [0], # ban the generation of some tokens token_stop = []) # stop generation whenever you see any token here # ctx = ctx.strip().split('\n') # for c in range(len(ctx)): # ctx[c] = ctx[c].strip().strip('\u3000').strip('\r') # ctx = list(filter(lambda c: c != '', ctx)) # ctx = '\n' + ('\n'.join(ctx)).strip() # if ctx == '': # ctx = '\n' # gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) # print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}',flush=True) all_tokens = [] out_last = 0 out_str = '' occurrence = {} state = None for i in range(int(token_count)): out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state) for n in args.token_ban: out[n] = -float('inf') for n in occurrence: out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p) if token in args.token_stop: break all_tokens += [token] if token not in occurrence: occurrence[token] = 1 else: occurrence[token] += 1 tmp = pipeline.decode(all_tokens[out_last:]) if '\ufffd' not in tmp: out_str += tmp yield out_str out_last = i + 1 gc.collect() torch.cuda.empty_cache() yield out_str examples = [ ["""女招待: 欢迎光临。您远道而来,想必一定很累了吧? 深见: 不会……空气也清爽,也让我焕然一新呢 女招待: 是吗。那真是太好了 我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。""", 200, 2.0, 0.4, 0.1, 0.1], ["翡翠: 欢迎回来,志贵少爷。", 200, 2.0, 0.4, 0.1, 0.1], ["""莲华: 你的目的,就是这个万华镜吧? 莲华拿出了万华镜。 深见: 啊…… 好像被万华镜拽过去了一般,我的腿不由自主地向它迈去 深见: 是这个……就是这个啊…… 烨烨生辉的魔法玩具。 连接现实与梦之世界的、诱惑的桥梁。 深见: 请让我好好看看…… 我刚想把手伸过去,莲华就一下子把它收了回去。""", 200, 2.0, 0.4, 0.1, 0.1], ["""嘉祥: 偶尔来一次也不错。 我坐到客厅的沙发上,拍了拍自己的大腿。 巧克力&香草: 喵喵? 巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪ 巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪ 香草: 身为猫娘饲主,这点服务也是应该的对吧? 香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪ 我摸摸各自占据住我左右两腿的两颗猫头。 嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 2.0, 0.4, 0.1, 0.1], ] iface = gr.Interface( fn=infer, description=f'''这是GalGame剧本续写模型(实验性质,不保证效果)。请点击例子(在页面底部),可编辑内容。这里只看输入的最后约1200字,请写好,标点规范,无错别字,否则电脑会模仿你的错误。为避免占用资源,每次生成限制长度。可将输出内容复制到输入,然后继续生成。推荐提高temp改善文采,降低topp改善逻辑,提高两个penalty避免重复,具体幅度请自己实验。
{desc}''', allow_flagging="never", inputs=[ gr.Textbox(lines=10, label="Prompt 输入的前文", value="""嘉祥: 偶尔来一次也不错。 我坐到客厅的沙发上,拍了拍自己的大腿。 巧克力&香草: 喵喵? 巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪ 巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪ 香草: 身为猫娘饲主,这点服务也是应该的对吧? 香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪ 我摸摸各自占据住我左右两腿的两颗猫头。 嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。"""), # prompt gr.Slider(10, 200, step=10, value=200, label="token_count 每次生成的长度"), # token_count gr.Slider(0.2, 2.0, step=0.1, value=2, label="temperature 默认0.7,高则变化丰富,低则保守求稳"), # temperature gr.Slider(0.0, 1.0, step=0.05, value=0.4, label="top_p 默认1.0,高则标新立异,低则循规蹈矩"), # top_p gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="presencePenalty 默认0.0,避免写过的类似字"), # presencePenalty gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="countPenalty 默认0.0,额外避免写过多次的类似字"), # countPenalty ], outputs=gr.Textbox(label="Output 输出的续写", lines=28), examples=examples, cache_examples=False, ).queue() demo = gr.TabbedInterface( [iface], ["Generative"] ) demo.queue(max_size=5) if args.share: demo.launch(share=True) else: demo.launch(share=False)