# Hugging Face Spaces status: Runtime error
from threading import Thread

import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
# Load the base Baichuan-13B model, attach the Genshin-World-Model PEFT adapter,
# and place the result on the GPU when one is available.

# Choose the device first so the base weights can be loaded directly in the
# precision they will run in (avoids materialising 13B params in float32 on GPU hosts).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.device never compares equal to a plain str, so the check must use
# `.type`; a bare `device == "cuda"` is always False and left the model on the CPU.
use_cuda = device.type == "cuda"
model_dtype = torch.float16 if use_cuda else torch.float32

config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Base",
    torch_dtype=model_dtype,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(
    model,
    "Junity/Genshin-World-Model",
    torch_dtype=model_dtype,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
# NOTE(review): not referenced by the code visible in this file.
history = []

if use_cuda:
    # .half() also covers any adapter weights that may have been loaded in
    # float32; nn.Module.cuda() moves in place and returns the module.
    model = model.cuda().half()
def respond(role_name, msg, textbox):
    """Append the user's line to the transcript and stream the model's continuation.

    Gradio generator callback wired to the Submit button with
    ``outputs=[msg, textbox]``: every yielded value clears the input box and
    carries the transcript so far, growing as generated text arrives.

    Args:
        role_name: Name of the character the user plays; prefixed to the line.
        msg: The user's line. Blank or whitespace-only input is ignored.
        textbox: The current transcript (``""`` or ``None`` when empty).

    Yields:
        ``["", transcript]`` lists.
    """
    role_name = role_name or ""
    msg = (msg or "").strip()
    textbox = textbox or ""
    if not msg:
        # Nothing to append; indexing msg[-1] below would raise on "".
        yield ["", textbox]
        return

    # Close the line with a full stop unless it already ends with
    # sentence-final punctuation (full-width or ASCII).
    if msg[-1] not in "。！？!?":
        msg += "。"
    separator = "\n" if textbox else ""
    textbox = f"{textbox}{separator}{role_name}:{msg}\n"
    yield ["", textbox]

    max_new_tokens = 256
    # Keep the prompt plus the generated tokens within the 4096-token cap the
    # original code used (assumed to be the model's context length — confirm).
    prompt_budget = 4096 - max_new_tokens
    prompt_ids = tokenizer.encode(textbox)[-prompt_budget:]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        temperature=1.5,
        top_p=0.7,
        top_k=50,
        repetition_penalty=1.0,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    # model.generate blocks until finished, so run it in a worker thread and
    # drain the streamer here to push partial output to the UI.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]
    thread.join()
# Header text shown above the chat tab (same literal as the inline triple-quoted string).
_INTRO_MARKDOWN = (
    "\n"
    "    ## Genshin-World-Model\n"
    "    - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)\n"
    "    - 此模型不支持要求对方回答什么,只支持续写。\n"
    "    "
)

# Build the single-tab chat UI: role name + message inputs, Clear/Submit
# buttons, and a read-only transcript that respond() streams into.
with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Tab("聊天") as chat:
        role_name = gr.Textbox(label="你将扮演的角色")
        msg = gr.Textbox(label="输入")
        with gr.Row():
            clear = gr.Button("Clear")
            sub = gr.Button("Submit")
        textbox = gr.Textbox(interactive=False)

        sub.click(
            fn=respond,
            inputs=[role_name, msg, textbox],
            outputs=[msg, textbox],
        )
        # Returning None resets the transcript box; runs outside the queue.
        clear.click(fn=lambda: None, inputs=None, outputs=textbox, queue=False)

# queue() is required for generator callbacks like respond(); it returns the
# same Blocks object, so enabling it and launching can be separate statements.
demo.queue()
demo.launch(server_port=6006)