File size: 4,159 Bytes
c49a56c
e0dec7a
c49a56c
 
 
4e6c5df
802ccb7
c49a56c
 
 
 
 
 
 
 
b1e3e8e
c49a56c
 
 
d4a0c70
c49a56c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0dec7a
c49a56c
 
802ccb7
c49a56c
 
 
 
 
 
 
 
 
 
960904c
 
c49a56c
 
 
960904c
 
c49a56c
 
960904c
 
 
 
c49a56c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4a0c70
 
c49a56c
 
ff72155
 
 
 
 
c49a56c
 
 
 
ff72155
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from threading import Thread
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import subprocess
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# HTML header rendered at the top of the demo page: project banner image, a link
# to the GitHub project, and a usage notice naming the default model.
BANNER_HTML = """
<p align="center">
    <a href="https://github.com/ymcui/Chinese-LLaMA-Alpaca-3">
        <img src="https://ymcui.com/images/chinese-llama-alpaca-3-banner.png" width="600"/>
    </a>
</p>
<h3>
    <center>Check our <a href='https://github.com/ymcui/Chinese-LLaMA-Alpaca-3' target='_blank'>Chinese-LLaMA-Alpaca-3 GitHub Project</a> for more information.
    </center>
</h3>
<p>
    <center><em>The demo is mainly for academic purposes. Illegal usages are prohibited. Default model: <a href="https://huggingface.co/hfl/llama-3-chinese-8b-instruct-v3">hfl/llama-3-chinese-8b-instruct-v3</a></em></center>
</p>
"""

# Bilingual (English + Chinese) default system prompt. The UI pre-fills the
# System Prompt field with it, and stream_chat falls back to it when that field is empty.
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. 你是一个乐于助人的助手。"

# Hugging Face checkpoint for each selectable instruct-model version.
MODEL_NAMES = {
    "v1": "hfl/llama-3-chinese-8b-instruct",
    "v2": "hfl/llama-3-chinese-8b-instruct-v2",
    "v3": "hfl/llama-3-chinese-8b-instruct-v3",
}

def load_model(version):
    """Load the tokenizer and model for the given instruct-model *version*.

    The loaded objects are stored in the module-level globals ``tokenizer`` and
    ``model``, which ``stream_chat`` reads when generating replies.

    Args:
        version: One of the keys of ``MODEL_NAMES`` (``"v1"``, ``"v2"``, ``"v3"``).

    Returns:
        A human-readable status message naming the loaded checkpoint.

    Raises:
        ValueError: If *version* is not a known model version.
    """
    global tokenizer, model
    # Validate before touching the network/GPU; previously an unknown version
    # left model_name unbound and crashed with UnboundLocalError.
    try:
        model_name = MODEL_NAMES[version]
    except KeyError:
        raise ValueError(
            f"Unknown model version {version!r}; expected one of {sorted(MODEL_NAMES)}"
        ) from None

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # bf16 weights, automatic device placement, FlashAttention-2 kernels.
    # NOTE(review): attn_implementation="flash_attention_2" needs the flash-attn
    # package; the install command near the imports is commented out, so confirm
    # the runtime environment provides it.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    return f"Model {model_name} loaded."

@spaces.GPU(duration=60)
def stream_chat(message: str, history: list, system_prompt: str, model_version: str, temperature: float, max_new_tokens: int):
    """Stream a chat completion for *message* given the prior *history*.

    Generator used as the Gradio ChatInterface callback: each yield is the full
    assistant reply accumulated so far, so the UI can re-render it in place.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs (assumes Gradio's
            tuple-style history — TODO confirm against the installed Gradio version).
        system_prompt: System prompt; ``DEFAULT_SYSTEM_PROMPT`` is used when empty.
        model_version: Selected model version. Not used here: the model is the
            one already loaded into the module-level ``model``/``tokenizer``.
        temperature: Sampling temperature; ``0`` switches to greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The cumulative decoded reply text.
    """
    conversation = [{"role": "system", "content": system_prompt or DEFAULT_SYSTEM_PROMPT}]
    for prompt, answer in history:
        # A turn may lack one side (e.g. no reply recorded); skip it rather than
        # passing None content into the chat template.
        if prompt is not None:
            conversation.append({"role": "user", "content": prompt})
        if answer is not None:
            conversation.append({"role": "assistant", "content": answer})

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    # With timeout=10.0, iterating the streamer raises if no new text arrives for 10 s.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # Stop on either the tokenizer's EOS token or the "<|eot_id|>" end-of-turn token.
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    generate_kwargs = {
        "input_ids": input_ids,
        # A single unpadded sequence: every position is real. Passing the mask
        # explicitly avoids generate() inferring it from pad_token_id.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        "eos_token_id": terminators,
        "pad_token_id": tokenizer.eos_token_id,
        # Slider values may arrive as floats; generate() expects an integer budget.
        "max_new_tokens": int(max_new_tokens),
        "temperature": temperature,
        "top_k": 40,
        "top_p": 0.9,
        "num_beams": 1,
        "repetition_penalty": 1.1,
        # Temperature 0 selects greedy decoding.
        "do_sample": temperature != 0,
    }

    # generate() blocks until completion, so run it in a worker thread and
    # consume the streamer here to yield partial text as it is produced.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output

    # The streamer is exhausted once generate() has finished. Join so the worker
    # has fully exited before this GPU-decorated call returns.
    # NOTE(review): if the consumer stops iterating early (presumably the UI Stop
    # button), generation keeps running in the background until max_new_tokens.
    # Cancelling it would need a stopping criterion.
    generation_thread.join()

# Chat transcript widget, created up front and handed to ChatInterface below.
chatbot = gr.Chatbot(height=500)

# Page layout: banner on top, chat interface beneath it.
with gr.Blocks() as demo:
    gr.HTML(BANNER_HTML)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        # Extra inputs are built with render=False so they are not rendered where
        # they are created; ChatInterface places them in this collapsed accordion.
        # Their values are passed to stream_chat after (message, history), in list order.
        additional_inputs_accordion=gr.Accordion(label="Parameters / 参数设置", open=False, render=False),
        additional_inputs=[
            # -> system_prompt
            gr.Text(value=DEFAULT_SYSTEM_PROMPT, label="System Prompt / 系统提示词", render=False),
            # -> model_version (display only: interactive=False, and stream_chat ignores it)
            gr.Radio(choices=["v1", "v2", "v3"], label="Model Version / 模型版本", value="v3", interactive=False, render=False),
            # -> temperature (0 makes stream_chat decode greedily)
            gr.Slider(minimum=0, maximum=1.5, step=0.1, value=0.6, label="Temperature / 温度系数", render=False),
            # -> max_new_tokens
            gr.Slider(minimum=128, maximum=2048, step=1, value=512, label="Max new tokens / 最大生成长度", render=False),
        ],
        # No examples are configured, so there is nothing to cache.
        cache_examples=False,
        submit_btn="Send / 发送",
        stop_btn="Stop / 停止",
        retry_btn="🔄 Retry / 重试",
        undo_btn="↩️ Undo / 撤销",
        clear_btn="🗑️ Clear / 清空",
    )

def _main():
    """Load the default ("v3") model into the module globals, then serve the demo."""
    load_model("v3")
    demo.launch()


if __name__ == "__main__":
    _main()