"""Gradio chat demo for the Chinese-LLaMA-Alpaca-3 instruct models.

At start-up the script tries to install flash-attn, then serves a streaming
chat UI backed by one of the ``hfl/llama-3-chinese-8b-instruct`` checkpoints.
"""
import importlib.util
import os
import subprocess
import sys
from threading import Thread

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Best-effort install of flash-attn into the *running* interpreter's environment.
# The environment is extended (not replaced) so PATH, HOME, CUDA variables etc.
# stay visible to pip; a failed install is tolerated (check=False) and
# load_model() then falls back to SDPA attention.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Packages installed after interpreter start-up may be hidden by stale finder
# caches; refresh them so the flash_attn lookup below sees a fresh install.
importlib.invalidate_caches()

BANNER_HTML = """

Check our Chinese-LLaMA-Alpaca-3 GitHub Project for more information.

The demo is mainly for academic purposes. Illegal usages are prohibited. Default model: hfl/llama-3-chinese-8b-instruct-v3

"""

DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. 你是一个乐于助人的助手。"
DEFAULT_MODEL_VERSION = "v3"

# Model version (as offered in the UI radio) -> Hugging Face Hub checkpoint.
MODEL_NAMES = {
    "v1": "hfl/llama-3-chinese-8b-instruct",
    "v2": "hfl/llama-3-chinese-8b-instruct-v2",
    "v3": "hfl/llama-3-chinese-8b-instruct-v3",
}

# Set by load_model(). They start as None so stream_chat can detect that no
# model was loaded (e.g. when this file is imported instead of run as __main__).
tokenizer = None
model = None


def _attn_implementation():
    """Return "flash_attention_2" if flash_attn is importable, else "sdpa"."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


def load_model(version):
    """Load the tokenizer and model for *version* into the module globals.

    Args:
        version: One of the keys of ``MODEL_NAMES`` ("v1", "v2" or "v3").

    Returns:
        A status message naming the loaded checkpoint.

    Raises:
        ValueError: If *version* is not a known model version.
    """
    global tokenizer, model
    try:
        model_name = MODEL_NAMES[version]
    except KeyError:
        raise ValueError(
            f"Unknown model version {version!r}; expected one of {sorted(MODEL_NAMES)}"
        ) from None
    new_tokenizer = AutoTokenizer.from_pretrained(model_name)
    new_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=_attn_implementation(),
    )
    # Swap both globals together so a failed load never leaves a mismatched pair.
    tokenizer, model = new_tokenizer, new_model
    return f"Model {model_name} loaded."


@spaces.GPU(duration=50)
def stream_chat(message: str, history: list, system_prompt: str, model_version: str, temperature: float, max_new_tokens: int):
    """Stream the assistant's reply to *message*, yielding the growing text.

    Args:
        message: The latest user message.
        history: Previous turns, iterated as (user, assistant) pairs.
        system_prompt: System prompt; DEFAULT_SYSTEM_PROMPT is used when empty.
        model_version: Value of the (non-interactive) version radio. Currently
            unused: generation always uses the model loaded by load_model().
        temperature: Sampling temperature; 0 disables sampling (greedy decoding).
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    if model is None or tokenizer is None:
        load_model(DEFAULT_MODEL_VERSION)

    conversation = [{"role": "system", "content": system_prompt or DEFAULT_SYSTEM_PROMPT}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Stop on either the tokenizer's EOS token or the <|eot_id|> end-of-turn token.
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "eos_token_id": terminators,
        "pad_token_id": tokenizer.eos_token_id,
        # UI sliders may hand over numeric values as floats.
        "max_new_tokens": int(max_new_tokens),
        "temperature": temperature,
        "top_k": 40,
        "top_p": 0.9,
        "num_beams": 1,
        "repetition_penalty": 1.1,
        "do_sample": temperature != 0,
    }
    # generate() runs in a background thread; the streamer hands decoded text
    # back to this generator as it is produced.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output
    generation_thread.join()


chatbot = gr.Chatbot(height=500)

with gr.Blocks() as demo:
    gr.HTML(BANNER_HTML)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="Parameters / 参数设置", open=False, render=False),
        additional_inputs=[
            gr.Text(value=DEFAULT_SYSTEM_PROMPT, label="System Prompt / 系统提示词", render=False),
            gr.Radio(choices=list(MODEL_NAMES), label="Model Version / 模型版本", value=DEFAULT_MODEL_VERSION, interactive=False, render=False),
            gr.Slider(minimum=0, maximum=1.5, step=0.1, value=0.6, label="Temperature / 温度系数", render=False),
            gr.Slider(minimum=128, maximum=2048, step=1, value=512, label="Max new tokens / 最大生成长度", render=False),
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    load_model(DEFAULT_MODEL_VERSION)  # Load the default model before serving.
    demo.launch()