m96tkmok committed (verified)
Commit 17bf3db · 1 parent: 40c1dbc

Create app.py

Files changed (1):
  app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
+ from threading import Thread
+
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+
+ TITLE = "<h1><center>Chat with lianghsun/Llama-3.2-Taiwan-3B</center></h1>"
+
+ DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/lianghsun/Llama-3.2-Taiwan-3B' target='_blank'>the model page</a> for details.</center></h3>"
+
+ # System prompt: "You are a chatbot made in Taiwan; answer as a Taiwanese local, in Traditional Chinese."
+ DEFAULT_SYSTEM = "你是一個產自台灣的聊天機器人，你以台灣本地人的身份，使用正體中文回答問題。"
+
+ # Styling for a Space duplicate button (no such button is rendered below).
+ CSS = """
+ .duplicate-button {
+     margin: auto !important;
+     color: white !important;
+     background: green !important;
+     border-radius: 100vh !important;
+ }
+ """
+
+
+ tokenizer = AutoTokenizer.from_pretrained("lianghsun/Llama-3.2-Taiwan-3B")
+ model = AutoModelForCausalLM.from_pretrained("lianghsun/Llama-3.2-Taiwan-3B", torch_dtype="auto", device_map="auto")
+
+
+ def stream_chat(message: str, history: list, system: str, temperature: float, max_new_tokens: int):
+     # Rebuild the conversation: system prompt first, then alternating user/assistant turns.
+     conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
+     for prompt, answer in history:
+         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
+
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(
+         model.device
+     )
+     # Stream decoded tokens as they are produced, skipping the prompt and special tokens.
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         do_sample=True,
+     )
+     # A temperature of 0 is not valid for sampling, so fall back to greedy decoding.
+     if temperature == 0:
+         generate_kwargs["do_sample"] = False
+
+     # Run generation in a background thread so this generator can yield partial output as it streams in.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     output = ""
+     for new_token in streamer:
+         output += new_token
+         yield output  # each yield is the full response so far
+
+
+ chatbot = gr.Chatbot(height=450)
+
+ with gr.Blocks(css=CSS) as demo:
+     gr.HTML(TITLE)
+     gr.HTML(DESCRIPTION)
+     gr.ChatInterface(
+         fn=stream_chat,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Text(
+                 value="",
+                 label="System",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0,
+                 maximum=1,
+                 step=0.1,
+                 value=0.8,
+                 label="Temperature",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=128,
+                 maximum=4096,
+                 step=1,
+                 value=1024,
+                 label="Max new tokens",
+                 render=False,
+             ),
+         ],
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch()
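
Note: this commit adds only app.py. For the Space to run, its dependencies must also be installed; a minimal requirements.txt sketch (hypothetical, not part of this commit) could look like:

    # requirements.txt (hypothetical; not part of this commit)
    gradio
    transformers
    torch
    accelerate  # required by transformers when device_map="auto" is used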
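Because stream_chat is a generator that yields the cumulative response, it can also be exercised outside the Gradio UI; a minimal sketch, where the prompt and parameter values are purely illustrative:

    # Hypothetical sanity check: consume the streaming generator directly.
    last = ""
    # Prompt translates to "Which is Taiwan's highest mountain?"
    for partial in stream_chat("台灣最高的山是哪一座？", history=[], system="", temperature=0.7, max_new_tokens=64):
        last = partial  # each item is the full output so far
    print(last)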