mike23415 committed
Commit 59219bf · verified · 1 Parent(s): b283c96

Update app.py

Files changed (1): app.py (+30 -10)
app.py CHANGED
@@ -6,10 +6,6 @@ from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import gradio as gr
 
-# Initialize Flask app
-app = Flask(__name__)
-CORS(app)
-
 # Global variables
 MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
 MAX_LENGTH = 2048
@@ -18,15 +14,24 @@ TEMPERATURE = 0.7
 TOP_P = 0.9
 THINKING_STEPS = 3  # Number of thinking steps
 
-# Load model and tokenizer
-@app.before_first_request
-def load_model():
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+
+# Function to load model and tokenizer
+def load_model_and_tokenizer():
     global model, tokenizer
 
+    if model is not None and tokenizer is not None:
+        return
+
     print(f"Loading model: {MODEL_ID}")
 
     # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_ID,
+        use_fast=True,
+    )
 
     # Load model with optimizations for limited resources
     model = AutoModelForCausalLM.from_pretrained(
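The guarded lazy load above replaces @app.before_first_request, which was deprecated in Flask 2.2 and removed in Flask 2.3. One caveat: Flask can serve requests on several threads, so two concurrent first requests could both pass the None check and load the model twice. A minimal thread-safe variant of the same pattern (a sketch, not part of this commit; the lock and the simplified from_pretrained call are illustrative assumptions):

import threading

_load_lock = threading.Lock()

def load_model_and_tokenizer():
    global model, tokenizer
    with _load_lock:
        # Re-check inside the lock so a second waiting thread skips the load
        if model is not None and tokenizer is not None:
            return
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID)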
@@ -38,6 +43,10 @@ def load_model():
 
     print("Model and tokenizer loaded successfully!")
 
+# Initialize Flask app
+app = Flask(__name__)
+CORS(app)
+
 # Helper function for step-by-step thinking
 def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
     # Initialize conversation with prompt
@@ -100,6 +109,10 @@ def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
 @app.route('/api/chat', methods=['POST'])
 def chat():
     try:
+        # Ensure model is loaded
+        if model is None or tokenizer is None:
+            load_model_and_tokenizer()
+
         data = request.json
         prompt = data.get('prompt', '')
         include_thinking = data.get('include_thinking', False)
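With the guard in place, the first POST to /api/chat triggers the load and later calls reuse the cached globals. For reference, a request might look like this (a sketch; the localhost:5000 address is an assumption about how the Flask app is run, and the response keys depend on what chat() puts into result):

import requests

# Hypothetical local call; host and port are assumptions, not part of this commit
resp = requests.post(
    "http://localhost:5000/api/chat",
    json={"prompt": "Why is the sky blue?", "include_thinking": True},
)
print(resp.status_code)  # 200 on success, 500 if the handler raised
print(resp.json())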
@@ -123,6 +136,9 @@ def chat():
         return jsonify(result)
 
     except Exception as e:
+        import traceback
+        print(f"Error in chat endpoint: {str(e)}")
+        print(traceback.format_exc())
         return jsonify({'error': str(e)}), 500
 
 # Simple health check endpoint
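Printing the traceback keeps the 500 response intact while surfacing the failure in the server logs. An equivalent using the standard logging module, which records the message and the full traceback in one call (a sketch of an alternative, not what the commit uses):

import logging

logger = logging.getLogger(__name__)

@app.route('/api/chat', methods=['POST'])
def chat():
    try:
        ...  # request handling as above
    except Exception as e:
        logger.exception("Error in chat endpoint")  # message + full traceback
        return jsonify({'error': str(e)}), 500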
@@ -157,6 +173,10 @@ def create_ui():
         if not question.strip():
             return "", "Please enter a question"
 
+        # Ensure model is loaded
+        if model is None or tokenizer is None:
+            load_model_and_tokenizer()
+
         response = generate_with_thinking(question)
 
         if show_thinking:
@@ -180,8 +200,8 @@ def create_ui():
 
 # Create Gradio UI and launch the app
 if __name__ == "__main__":
-    # Load model at startup for Gradio
-    load_model()
+    # Load model at startup
+    load_model_and_tokenizer()
 
     # Create and launch Gradio interface
     demo = create_ui()
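Because load_model_and_tokenizer() returns early once the globals are set, the eager call here and the per-request guards in chat() and the Gradio handler coexist safely: whichever runs first pays the load cost, and every later call is a no-op. A quick illustration (a sketch):

load_model_and_tokenizer()  # first call: loads model and tokenizer
load_model_and_tokenizer()  # second call: hits the early return, nothing reloads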
 