ChatGal / app.py
wanicca's picture
fix demo (add gradio requirement and fix download)
0692d85
raw
history blame
6.25 kB
import gradio as gr
import argparse
import os, gc, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
# from pynvml import *
# nvmlInit()
# gpu_h = nvmlDeviceGetHandleByIndex(0)
ctx_limit = 4096
desc = f'''链接:<a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a><a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a><a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a><a href="https://zhuanlan.zhihu.com/p/618011122" target="_blank" style="margin:0 0.5em">知乎教程</a>
'''
parser = argparse.ArgumentParser(prog = 'ChatGal RWKV')
parser.add_argument('--share',action='store_true')
args = parser.parse_args()
os.environ["RWKV_JIT_ON"] = '1'
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename="rwkv-chatgal-v1-3B-ctx4096-epoch2.pth")
if os.environ['ON_COLAB'] == '1':
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
model = RWKV(model=model_path, strategy='cuda bf16')
else:
model = RWKV(model=model_path, strategy='cpu bf16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "20B_tokenizer.json")
def infer(
ctx,
token_count=10,
temperature=0.7,
top_p=1.0,
presencePenalty = 0.05,
countPenalty = 0.05,
):
args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
alpha_frequency = countPenalty,
alpha_presence = presencePenalty,
token_ban = [0], # ban the generation of some tokens
token_stop = []) # stop generation whenever you see any token here
ctx = ctx.strip().split('\n')
for c in range(len(ctx)):
ctx[c] = ctx[c].strip().strip('\u3000').strip('\r')
ctx = list(filter(lambda c: c != '', ctx))
ctx = '\n' + ('\n'.join(ctx)).strip()
if ctx == '':
ctx = '\n'
# gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
# print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}',flush=True)
all_tokens = []
out_last = 0
out_str = ''
occurrence = {}
state = None
for i in range(int(token_count)):
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
for n in args.token_ban:
out[n] = -float('inf')
for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
if token in args.token_stop:
break
all_tokens += [token]
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
tmp = pipeline.decode(all_tokens[out_last:])
if '\ufffd' not in tmp:
out_str += tmp
yield out_str
out_last = i + 1
gc.collect()
torch.cuda.empty_cache()
yield out_str
examples = [
["""女招待: 欢迎光临。您远道而来,想必一定很累了吧?
深见: 不会……空气也清爽,也让我焕然一新呢
女招待: 是吗。那真是太好了
我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。""", 200, 0.7, 1.0, 0.05, 0.05],
["翡翠: 欢迎回来,志贵少爷。", 200, 0.7, 1.0, 0.05, 0.05],
["""莲华: 你的目的,就是这个万华镜吧?
莲华拿出了万华镜。
深见: 啊……
好像被万华镜拽过去了一般,我的腿不由自主地向它迈去
深见: 是这个……就是这个啊……
烨烨生辉的魔法玩具。
连接现实与梦之世界的、诱惑的桥梁。
深见: 请让我好好看看……
我刚想把手伸过去,莲华就一下子把它收了回去。""", 200, 0.7, 1.0, 0.05, 0.05],
["""嘉祥: 偶尔来一次也不错。
我坐到客厅的沙发上,拍了拍自己的大腿。
巧克力&香草: 喵喵?
巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪
巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪
香草: 身为猫娘饲主,这点服务也是应该的对吧?
香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪
我摸摸各自占据住我左右两腿的两颗猫头。
嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 0.7, 1.0, 0.05, 0.05],
]
iface = gr.Interface(
fn=infer,
description=f'''这是纯网文模型,去除了英文和代码能力,但写小白文更强。<b>请点击例子(在页面底部)</b>,可编辑内容。这里只看输入的最后约1200字,请写好,标点规范,无错别字,否则电脑会模仿你的错误。<b>为避免占用资源,每次生成限制长度。可将输出内容复制到输入,然后继续生成</b>。推荐提高temp改善文采,降低topp改善逻辑,提高两个penalty避免重复,具体幅度请自己实验。{desc}''',
allow_flagging="never",
inputs=[
gr.Textbox(lines=10, label="Prompt 输入的前文", value="通过基因改造,修真"), # prompt
gr.Slider(10, 200, step=10, value=200, label="token_count 每次生成的长度"), # token_count
gr.Slider(0.2, 2.0, step=0.1, value=0.7, label="temperature 默认0.7,高则变化丰富,低则保守求稳"), # temperature
gr.Slider(0.0, 1.0, step=0.05, value=1.0, label="top_p 默认1.0,高则标新立异,低则循规蹈矩"), # top_p
gr.Slider(0.0, 1.0, step=0.1, value=0.05, label="presencePenalty 默认0.05,避免写过的类似字"), # presencePenalty
gr.Slider(0.0, 1.0, step=0.1, value=0.05, label="countPenalty 默认0.05,额外避免写过多次的类似字"), # countPenalty
],
outputs=gr.Textbox(label="Output 输出的续写", lines=28),
examples=examples,
cache_examples=False,
).queue()
demo = gr.TabbedInterface(
[iface], ["Generative"]
)
demo.queue(max_size=5)
if args.share:
demo.launch(share=True)
else:
demo.launch(share=False)