fexeak committed on
Commit 8cfcc01 · 1 Parent(s): fc145ea

feat: add a SmolLM2-135M-based Gradio chat interface


Implements a complete chat assistant interface with the following features:
- Model loading on a background thread (see the sketch after this list)
- Adjustable generation parameters (temperature, max_length, top_p)
- Chat history
- Error handling and status messages
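
The non-blocking startup described in the first bullet follows a common pattern: start the slow `from_pretrained` call on a worker thread and gate incoming requests on a readiness flag. A minimal, self-contained sketch of that pattern (the `ready` event, `answer` function, and the sleep are illustrative stand-ins, not part of this commit):

```python
import threading
import time

ready = threading.Event()

def load_model():
    """Stand-in for the slow AutoModelForCausalLM.from_pretrained(...) call."""
    time.sleep(5)   # simulate weight loading
    ready.set()     # signal that requests can now be served

threading.Thread(target=load_model, daemon=True).start()

def answer(prompt: str) -> str:
    """Respond immediately; refuse politely until the model is ready."""
    if not ready.is_set():
        return "Model is still loading, please wait..."
    return f"(reply to: {prompt})"

print(answer("hello"))  # right after startup: "Model is still loading, ..."
```

The committed app.py implements the same idea with a plain `model_loaded` boolean instead of an `Event`.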

Files changed (1)
  1. app.py +135 -8
app.py CHANGED
@@ -1,10 +1,137 @@
-# pip install transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
 checkpoint = "HuggingFaceTB/SmolLM2-135M"
-device = "cuda" # for GPU usage or "cpu" for CPU usage
-tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-# for multiple GPUs install accelerate and do `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")`
-model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
-inputs = tokenizer.encode("Gravity is", return_tensors="pt").to(device)
-outputs = model.generate(inputs)
-print(tokenizer.decode(outputs[0]))
+import torch
+import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import threading
+import time
+
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+model_loaded = False
 checkpoint = "HuggingFaceTB/SmolLM2-135M"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def load_model():
+    """Load the model and tokenizer"""
+    global model, tokenizer, model_loaded
+    try:
+        print("Loading model...")
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+        model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
+        model_loaded = True
+        print("Model loaded successfully!")
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        model_loaded = False
+
+def generate_response(message, history, temperature, max_length, top_p):
+    """Generate a response from the model"""
+    global model, tokenizer, model_loaded
+
+    if not model_loaded:
+        return "Model is still loading, please wait..."
+
+    try:
+        # Tokenize input
+        inputs = tokenizer.encode(message, return_tensors="pt").to(device)
+
+        # Generate
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_length=max_length,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        # Decode only the newly generated tokens so the reply does not echo the prompt
+        response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
+        return response.strip()
+
+    except Exception as e:
+        return f"Error generating response: {str(e)}"
+
+def chat_interface(message, history, temperature, max_length, top_p):
+    """Chat interface callback for Gradio"""
+    response = generate_response(message, history, temperature, max_length, top_p)
+    history.append([message, response])
+    return "", history
+
+# Load model in the background
+loading_thread = threading.Thread(target=load_model)
+loading_thread.start()
+
+# Create Gradio interface
+with gr.Blocks(title="AI Chat Assistant") as demo:
+    gr.Markdown("# 🤖 AI Chat Assistant")
+    gr.Markdown("A chat assistant powered by the SmolLM2-135M model")
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(
+                value=[],
+                height=500,
+                show_label=False
+            )
+
+            with gr.Row():
+                msg = gr.Textbox(
+                    placeholder="Type your message...",
+                    show_label=False,
+                    scale=4
+                )
+                send_btn = gr.Button("Send", scale=1)
+
+            clear_btn = gr.Button("Clear chat")
+
+        with gr.Column(scale=1):
+            gr.Markdown("### Generation Settings")
+            temperature = gr.Slider(
+                minimum=0.1,
+                maximum=2.0,
+                value=0.7,
+                step=0.1,
+                label="Temperature"
+            )
+            max_length = gr.Slider(
+                minimum=100,
+                maximum=2000,
+                value=1000,
+                step=100,
+                label="Max length"
+            )
+            top_p = gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                value=0.95,
+                step=0.05,
+                label="Top-p"
+            )
+
+    # Event handlers
+    send_btn.click(
+        chat_interface,
+        inputs=[msg, chatbot, temperature, max_length, top_p],
+        outputs=[msg, chatbot]
+    )
+
+    msg.submit(
+        chat_interface,
+        inputs=[msg, chatbot, temperature, max_length, top_p],
+        outputs=[msg, chatbot]
+    )
+
+    clear_btn.click(
+        lambda: ([], ""),
+        outputs=[chatbot, msg]
+    )
+
+if __name__ == "__main__":
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=True,
+        show_error=True
+    )
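
One detail worth noting: for decoder-only models like SmolLM2, `model.generate` returns the prompt tokens followed by the continuation, so decoding `outputs[0]` verbatim would echo the user's message back into the chat. The handler above therefore slices off the first `inputs.shape[1]` tokens before decoding. A self-contained illustration with toy token ids (the values are made up):

```python
import torch

# Toy stand-ins shaped like the tensors generate_response() works with:
# a (1, 4) prompt and a (1, 7) generate() output = prompt + 3 new tokens.
inputs = torch.tensor([[101, 102, 103, 104]])
outputs = torch.tensor([[101, 102, 103, 104, 7, 8, 9]])

prompt_len = inputs.shape[1]          # number of prompt tokens (4)
new_tokens = outputs[0][prompt_len:]  # keep only the continuation
print(new_tokens.tolist())            # -> [7, 8, 9]
```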