fexeak committed
Commit 7d026c2 · 1 Parent(s): cb41d64

refactor: replace the model with SmolLM2 and simplify the code structure


Remove the code for the original NSFW-Flash model and switch to the lighter SmolLM2-135M model.
Simplify the code structure, keeping only basic model loading and inference.

Files changed (2)
  1. app.py +9 -164
  2. app.py.bak +165 -0
app.py CHANGED
@@ -1,165 +1,10 @@
- import torch
- import gradio as gr
+ # pip install transformers
  from transformers import AutoModelForCausalLM, AutoTokenizer
- import threading
- import time
-
- # Global variables for model and tokenizer
- model = None
- tokenizer = None
- model_loaded = False
-
- def load_model():
-     """Load the model and tokenizer"""
-     global model, tokenizer, model_loaded
-     try:
-         print("Loading model...")
-         model = AutoModelForCausalLM.from_pretrained(
-             "UnfilteredAI/NSFW-Flash",
-             trust_remote_code=True,
-             torch_dtype=torch.bfloat16
-         ).to("cuda" if torch.cuda.is_available() else "cpu")
-
-         tokenizer = AutoTokenizer.from_pretrained(
-             "UnfilteredAI/NSFW-Flash",
-             trust_remote_code=True
-         )
-
-         model_loaded = True
-         print("Model loaded successfully!")
-     except Exception as e:
-         print(f"Error loading model: {e}")
-         model_loaded = False
-
- def generate_response(message, history, temperature, max_length, top_p):
-     """Generate response from the model"""
-     global model, tokenizer, model_loaded
-
-     if not model_loaded:
-         return "模型尚未加载完成,请稍等..."
-
-     try:
-         # Build conversation history
-         chat = [
-             {"role": "system", "content": "You are NSFW-Flash, an AI assistant. Respond helpfully and appropriately."}
-         ]
-
-         # Add conversation history
-         for user_msg, bot_msg in history:
-             chat.append({"role": "user", "content": user_msg})
-             if bot_msg:
-                 chat.append({"role": "assistant", "content": bot_msg})
-
-         # Add current message
-         chat.append({"role": "user", "content": message})
-
-         # Apply chat template
-         chat_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-
-         # Tokenize
-         inputs = tokenizer(chat_text, return_tensors="pt", return_attention_mask=False)
-         if torch.cuda.is_available():
-             inputs = inputs.to("cuda")
-
-         # Generate
-         with torch.no_grad():
-             generated = model.generate(
-                 **inputs,
-                 max_length=max_length,
-                 temperature=temperature,
-                 top_p=top_p,
-                 do_sample=True,
-                 use_cache=False,
-                 eos_token_id=tokenizer.eos_token_id,
-                 pad_token_id=tokenizer.eos_token_id
-             )
-
-         # Decode response
-         response = tokenizer.decode(generated[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-         return response.strip()
-
-     except Exception as e:
-         return f"生成回复时出错: {str(e)}"
-
- def chat_interface(message, history, temperature, max_length, top_p):
-     """Chat interface for Gradio"""
-     response = generate_response(message, history, temperature, max_length, top_p)
-     history.append([message, response])
-     return "", history
-
- # Load model in background
- loading_thread = threading.Thread(target=load_model)
- loading_thread.start()
-
- # Create Gradio interface
- with gr.Blocks(title="AI Chat Assistant") as demo:
-     gr.Markdown("# 🤖 AI Chat Assistant")
-     gr.Markdown("基于 NSFW-Flash 模型的聊天助手")
-
-     with gr.Row():
-         with gr.Column(scale=3):
-             chatbot = gr.Chatbot(
-                 value=[],
-                 height=500,
-                 show_label=False
-             )
-
-             with gr.Row():
-                 msg = gr.Textbox(
-                     placeholder="输入您的消息...",
-                     show_label=False,
-                     scale=4
-                 )
-                 send_btn = gr.Button("发送", scale=1)
-
-             clear_btn = gr.Button("清空对话")
-
-         with gr.Column(scale=1):
-             gr.Markdown("### 参数设置")
-             temperature = gr.Slider(
-                 minimum=0.1,
-                 maximum=2.0,
-                 value=0.7,
-                 step=0.1,
-                 label="Temperature"
-             )
-             max_length = gr.Slider(
-                 minimum=100,
-                 maximum=2000,
-                 value=1000,
-                 step=100,
-                 label="最大长度"
-             )
-             top_p = gr.Slider(
-                 minimum=0.1,
-                 maximum=1.0,
-                 value=0.95,
-                 step=0.05,
-                 label="Top-p"
-             )
-
-     # Event handlers
-     send_btn.click(
-         chat_interface,
-         inputs=[msg, chatbot, temperature, max_length, top_p],
-         outputs=[msg, chatbot]
-     )
-
-     msg.submit(
-         chat_interface,
-         inputs=[msg, chatbot, temperature, max_length, top_p],
-         outputs=[msg, chatbot]
-     )
-
-     clear_btn.click(
-         lambda: ([], ""),
-         outputs=[chatbot, msg]
-     )
-
- if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=True,
-         show_error=True
-     )
+ checkpoint = "HuggingFaceTB/SmolLM2-135M"
+ device = "cuda" # for GPU usage or "cpu" for CPU usage
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+ # for multiple GPUs install accelerate and do `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")`
+ model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
+ inputs = tokenizer.encode("Gravity is", return_tensors="pt").to(device)
+ outputs = model.generate(inputs)
+ print(tokenizer.decode(outputs[0]))
 
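Note that the simplified app.py pins device = "cuda", so the .to(device) calls will raise on a machine without a GPU (for example a CPU-only Space). Below is a minimal defensive sketch, not part of commit 7d026c2; the max_new_tokens and sampling values are illustrative assumptions, and only torch/transformers calls already used elsewhere in this repository are relied on.

# Sketch: defensive variant of the simplified app.py (illustrative, not part of the commit).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-135M"
# Fall back to CPU when CUDA is unavailable instead of hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# For multiple GPUs, `device_map="auto"` (with accelerate installed) could replace .to(device),
# as the comment in the new app.py already suggests.
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

inputs = tokenizer.encode("Gravity is", return_tensors="pt").to(device)
# max_new_tokens and the sampling settings are illustrative; without them generate() stops at its short default length.
outputs = model.generate(inputs, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.95)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))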
app.py.bak ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import threading
+ import time
+
+ # Global variables for model and tokenizer
+ model = None
+ tokenizer = None
+ model_loaded = False
+
+ def load_model():
+     """Load the model and tokenizer"""
+     global model, tokenizer, model_loaded
+     try:
+         print("Loading model...")
+         model = AutoModelForCausalLM.from_pretrained(
+             "UnfilteredAI/NSFW-Flash",
+             trust_remote_code=True,
+             torch_dtype=torch.bfloat16
+         ).to("cuda" if torch.cuda.is_available() else "cpu")
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             "UnfilteredAI/NSFW-Flash",
+             trust_remote_code=True
+         )
+
+         model_loaded = True
+         print("Model loaded successfully!")
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         model_loaded = False
+
+ def generate_response(message, history, temperature, max_length, top_p):
+     """Generate response from the model"""
+     global model, tokenizer, model_loaded
+
+     if not model_loaded:
+         return "模型尚未加载完成,请稍等..."
+
+     try:
+         # Build conversation history
+         chat = [
+             {"role": "system", "content": "You are NSFW-Flash, an AI assistant. Respond helpfully and appropriately."}
+         ]
+
+         # Add conversation history
+         for user_msg, bot_msg in history:
+             chat.append({"role": "user", "content": user_msg})
+             if bot_msg:
+                 chat.append({"role": "assistant", "content": bot_msg})
+
+         # Add current message
+         chat.append({"role": "user", "content": message})
+
+         # Apply chat template
+         chat_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+
+         # Tokenize
+         inputs = tokenizer(chat_text, return_tensors="pt", return_attention_mask=False)
+         if torch.cuda.is_available():
+             inputs = inputs.to("cuda")
+
+         # Generate
+         with torch.no_grad():
+             generated = model.generate(
+                 **inputs,
+                 max_length=max_length,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True,
+                 use_cache=False,
+                 eos_token_id=tokenizer.eos_token_id,
+                 pad_token_id=tokenizer.eos_token_id
+             )
+
+         # Decode response
+         response = tokenizer.decode(generated[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+         return response.strip()
+
+     except Exception as e:
+         return f"生成回复时出错: {str(e)}"
+
+ def chat_interface(message, history, temperature, max_length, top_p):
+     """Chat interface for Gradio"""
+     response = generate_response(message, history, temperature, max_length, top_p)
+     history.append([message, response])
+     return "", history
+
+ # Load model in background
+ loading_thread = threading.Thread(target=load_model)
+ loading_thread.start()
+
+ # Create Gradio interface
+ with gr.Blocks(title="AI Chat Assistant") as demo:
+     gr.Markdown("# 🤖 AI Chat Assistant")
+     gr.Markdown("基于 NSFW-Flash 模型的聊天助手")
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(
+                 value=[],
+                 height=500,
+                 show_label=False
+             )
+
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="输入您的消息...",
+                     show_label=False,
+                     scale=4
+                 )
+                 send_btn = gr.Button("发送", scale=1)
+
+             clear_btn = gr.Button("清空对话")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### 参数设置")
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=0.7,
+                 step=0.1,
+                 label="Temperature"
+             )
+             max_length = gr.Slider(
+                 minimum=100,
+                 maximum=2000,
+                 value=1000,
+                 step=100,
+                 label="最大长度"
+             )
+             top_p = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.95,
+                 step=0.05,
+                 label="Top-p"
+             )
+
+     # Event handlers
+     send_btn.click(
+         chat_interface,
+         inputs=[msg, chatbot, temperature, max_length, top_p],
+         outputs=[msg, chatbot]
+     )
+
+     msg.submit(
+         chat_interface,
+         inputs=[msg, chatbot, temperature, max_length, top_p],
+         outputs=[msg, chatbot]
+     )
+
+     clear_btn.click(
+         lambda: ([], ""),
+         outputs=[chatbot, msg]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True,
+         show_error=True
+     )
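The backed-up Gradio app built a chat-template prompt before generating. If that chat-style flow is ever wanted on the lighter model, an instruction-tuned checkpoint would be the natural fit. A minimal sketch follows, assuming the HuggingFaceTB/SmolLM2-135M-Instruct checkpoint (an assumption; this commit only references the base SmolLM2-135M) and the same apply_chat_template pattern the old code used.

# Sketch: chat-style generation, assuming the SmolLM2 instruct checkpoint (not part of this commit).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"  # assumed instruction-tuned variant
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Same chat-template pattern the backed-up app.py used for NSFW-Flash.
chat = [{"role": "user", "content": "What is gravity?"}]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(device)

outputs = model.generate(input_ids, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.95)
# Decode only the newly generated tokens, mirroring the old generate_response().
print(tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True))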