geethareddy commited on
Commit
6c9e43a
·
verified ·
1 Parent(s): 88724f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -34
app.py CHANGED
@@ -1,10 +1,12 @@
1
- from flask import Flask, request, jsonify
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import json
4
  import logging
5
  import os
6
  import threading
7
  import time
 
 
8
 
9
  # Set up logging to stdout only
10
  logging.basicConfig(
@@ -54,6 +56,46 @@ def start_background_tasks():
54
  thread.daemon = True
55
  thread.start()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  @app.route("/")
58
  def root():
59
  logger.debug("Root endpoint accessed")
@@ -70,7 +112,7 @@ def health_check():
70
  @app.route("/generate_coaching", methods=["POST"])
71
  def generate_coaching():
72
  logger.debug("Generate coaching endpoint accessed")
73
- # Manual validation of request data (replacing Pydantic)
74
  data = request.get_json()
75
  if not data:
76
  logger.error("Invalid request: No JSON data provided")
@@ -82,15 +124,14 @@ def generate_coaching():
82
  logger.error(f"Missing required fields: {missing_fields}")
83
  return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400
84
 
85
- if model is None or tokenizer is None:
86
- logger.warning("Model or tokenizer not loaded")
87
- # Return a static response if the model isn't loaded yet
88
- response_json = {
89
  "checklist": ["Inspect safety equipment", "Review milestone progress"],
90
  "tips": ["Prioritize team communication", "Check weather updates"],
91
  "quote": "Every step forward counts!"
92
- }
93
- return jsonify(response_json)
94
 
95
  try:
96
  # Prepare input text
@@ -98,10 +139,10 @@ def generate_coaching():
98
  f"Role: {data['role']}, Project: {data['project_id']}, "
99
  f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
100
  )
101
-
102
  # Tokenize input
103
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
104
-
105
  # Generate output
106
  outputs = model.generate(
107
  inputs["input_ids"],
@@ -111,30 +152,23 @@ def generate_coaching():
111
  do_sample=True,
112
  temperature=0.7
113
  )
114
-
115
- # Decode and parse response
116
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
117
-
118
- # Since distilgpt2 may not output JSON, parse the response manually or use fallback
119
- if not response_text.startswith("{"):
120
- checklist = ["Inspect safety equipment", "Review milestone progress"]
121
- tips = ["Prioritize team communication", "Check weather updates"]
122
- quote = "Every step forward counts!"
123
- response_json = {"checklist": checklist, "tips": tips, "quote": quote}
124
- logger.warning("Model output is not JSON, using default response")
125
- else:
126
- try:
127
- response_json = json.loads(response_text)
128
- except json.JSONDecodeError:
129
- response_json = {
130
- "checklist": ["Inspect safety equipment", "Review milestone progress"],
131
- "tips": ["Prioritize team communication", "Check weather updates"],
132
- "quote": "Every step forward counts!"
133
- }
134
- logger.warning("Failed to parse model output as JSON, using default response")
135
-
136
  return jsonify(response_json)
137
-
138
  except Exception as e:
139
  logger.error(f"Error generating coaching response: {str(e)}")
140
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@@ -145,4 +179,5 @@ if __name__ == "__main__":
145
  # Run Flask app with waitress for production-ready WSGI server
146
  from waitress import serve
147
  logger.debug("Starting Flask app with Waitress")
148
- serve(app, host="0.0.0.0", port=7860)
 
 
1
+
2
+
3
  import json
4
  import logging
5
  import os
6
  import threading
7
  import time
8
+ from flask import Flask, request, jsonify
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
  # Set up logging to stdout only
12
  logging.basicConfig(
 
56
  thread.daemon = True
57
  thread.start()
58
 
59
+ # Utility function to wait for model loading with a timeout
60
+ def wait_for_model(timeout=60):
61
+ start_time = time.time()
62
+ while time.time() - start_time < timeout:
63
+ if model_load_status in ["local_model_loaded", "fallback_model_loaded"]:
64
+ return True
65
+ elif "failed" in model_load_status:
66
+ return False
67
+ time.sleep(1)
68
+ return False
69
+
70
+ # Utility function to parse raw text into structured JSON response
71
+ def parse_raw_text_to_json(raw_text):
72
+ lines = raw_text.strip().split("\n")
73
+ checklist = []
74
+ tips = []
75
+ quote = "Every step forward counts!"
76
+
77
+ # Simple parsing logic: look for keywords or assume structure
78
+ for line in lines:
79
+ line = line.strip()
80
+ if line.startswith("- "): # Assume checklist items start with "- "
81
+ checklist.append(line[2:].strip())
82
+ elif line.startswith("* "): # Assume tips start with "* "
83
+ tips.append(line[2:].strip())
84
+ elif line and not checklist and not tips: # If no checklist or tips, treat as a quote
85
+ quote = line
86
+
87
+ # Fallback if parsing fails
88
+ if not checklist:
89
+ checklist = ["Inspect safety equipment", "Review milestone progress"]
90
+ if not tips:
91
+ tips = ["Prioritize team communication", "Check weather updates"]
92
+
93
+ return {
94
+ "checklist": checklist,
95
+ "tips": tips,
96
+ "quote": quote
97
+ }
98
+
99
  @app.route("/")
100
  def root():
101
  logger.debug("Root endpoint accessed")
 
112
  @app.route("/generate_coaching", methods=["POST"])
113
  def generate_coaching():
114
  logger.debug("Generate coaching endpoint accessed")
115
+ # Manual validation of request data
116
  data = request.get_json()
117
  if not data:
118
  logger.error("Invalid request: No JSON data provided")
 
124
  logger.error(f"Missing required fields: {missing_fields}")
125
  return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400
126
 
127
+ # Wait for the model to load (up to 60 seconds)
128
+ if not wait_for_model(timeout=60):
129
+ logger.warning("Model failed to load within timeout")
130
+ return jsonify({
131
  "checklist": ["Inspect safety equipment", "Review milestone progress"],
132
  "tips": ["Prioritize team communication", "Check weather updates"],
133
  "quote": "Every step forward counts!"
134
+ })
 
135
 
136
  try:
137
  # Prepare input text
 
139
  f"Role: {data['role']}, Project: {data['project_id']}, "
140
  f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
141
  )
142
+
143
  # Tokenize input
144
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
145
+
146
  # Generate output
147
  outputs = model.generate(
148
  inputs["input_ids"],
 
152
  do_sample=True,
153
  temperature=0.7
154
  )
155
+
156
+ # Decode response
157
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
158
+
159
+ # Try parsing as JSON first
160
+ try:
161
+ response_json = json.loads(response_text)
162
+ # Validate required fields in response
163
+ if not all(key in response_json for key in ["checklist", "tips", "quote"]):
164
+ raise ValueError("Missing required fields in model output")
165
+ except (json.JSONDecodeError, ValueError):
166
+ # If not JSON or invalid JSON, parse raw text
167
+ logger.warning("Model output is not valid JSON, parsing raw text")
168
+ response_json = parse_raw_text_to_json(response_text)
169
+
 
 
 
 
 
 
 
170
  return jsonify(response_json)
171
+
172
  except Exception as e:
173
  logger.error(f"Error generating coaching response: {str(e)}")
174
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
179
  # Run Flask app with waitress for production-ready WSGI server
180
  from waitress import serve
181
  logger.debug("Starting Flask app with Waitress")
182
+ serve(app, host="0.0.0.0", port=7860)
183
+