import json
import logging
import os
import threading
import time
from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set up logging to stdout only
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler()  # Log to stdout
    ]
)
logger = logging.getLogger(__name__)
# Initialize Flask app
app = Flask(__name__)
# Global variables for model and tokenizer
model = None
tokenizer = None
model_load_status = "not_loaded"
# Define model path and fallback
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"
# Function to load model in the background
def load_model_background():
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info(f"Loading local model from {model_path}")
            model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
            tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
            model_load_status = "local_model_loaded"
        else:
            logger.info(f"Model directory not found: {model_path}. Using pre-trained model: {fallback_model}")
            model = AutoModelForCausalLM.from_pretrained(fallback_model)
            tokenizer = AutoTokenizer.from_pretrained(fallback_model)
            model_load_status = "fallback_model_loaded"
        logger.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {str(e)}")
        model_load_status = f"failed: {str(e)}"

# Start model loading in a background thread
def start_background_tasks():
    logger.debug("Starting background tasks")
    thread = threading.Thread(target=load_model_background)
    thread.daemon = True
    thread.start()
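
# Note: loading in a daemon thread keeps server startup non-blocking, so the
# /health endpoint can answer immediately while model_load_status reports progress.
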
# Utility function to wait for model loading with a timeout
def wait_for_model(timeout=60):
    start_time = time.time()
    while time.time() - start_time < timeout:
        if model_load_status in ["local_model_loaded", "fallback_model_loaded"]:
            return True
        elif "failed" in model_load_status:
            return False
        time.sleep(1)
    return False

# Utility function to parse raw text into structured JSON response
def parse_raw_text_to_json(raw_text):
    lines = raw_text.strip().split("\n")
    checklist = []
    tips = []
    quote = "Every step forward counts!"
    # Simple parsing logic: look for keywords or assume structure
    for line in lines:
        line = line.strip()
        if line.startswith("- "):  # Assume checklist items start with "- "
            checklist.append(line[2:].strip())
        elif line.startswith("* "):  # Assume tips start with "* "
            tips.append(line[2:].strip())
        elif line and not checklist and not tips:  # If no checklist or tips, treat as a quote
            quote = line
    # Fallback if parsing fails
    if not checklist:
        checklist = ["Inspect safety equipment", "Review milestone progress"]
    if not tips:
        tips = ["Prioritize team communication", "Check weather updates"]
    return {
        "checklist": checklist,
        "tips": tips,
        "quote": quote
    }
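
# Illustrative example (hypothetical model output): the raw text
#   "- Check scaffolding anchors\n* Hold a short toolbox talk"
# parses to
#   {"checklist": ["Check scaffolding anchors"],
#    "tips": ["Hold a short toolbox talk"],
#    "quote": "Every step forward counts!"}
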
@app.route("/")
def root():
logger.debug("Root endpoint accessed")
return jsonify({"message": "Supervisor AI Coach is running"})
@app.route("/health")
def health_check():
logger.debug("Health endpoint accessed")
return jsonify({
"status": "healthy" if model_load_status in ["local_model_loaded", "fallback_model_loaded"] else "starting",
"model_load_status": model_load_status
})
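
# Example /health payloads: {"status": "starting", "model_load_status": "not_loaded"}
# during startup, then {"status": "healthy", "model_load_status": "fallback_model_loaded"}
# (or "local_model_loaded") once a model is ready.
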
@app.route("/generate_coaching", methods=["POST"])
def generate_coaching():
logger.debug("Generate coaching endpoint accessed")
# Manual validation of request data
data = request.get_json()
if not data:
logger.error("Invalid request: No JSON data provided")
return jsonify({"error": "Invalid request: JSON data required"}), 400
required_fields = ["role", "project_id", "milestones", "reflection_log"]
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
logger.error(f"Missing required fields: {missing_fields}")
return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400
# Wait for the model to load (up to 60 seconds)
if not wait_for_model(timeout=60):
logger.warning("Model failed to load within timeout")
return jsonify({
"checklist": ["Inspect safety equipment", "Review milestone progress"],
"tips": ["Prioritize team communication", "Check weather updates"],
"quote": "Every step forward counts!"
})
try:
# Prepare input text
input_text = (
f"Role: {data['role']}, Project: {data['project_id']}, "
f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
)
# Tokenize input
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
# Generate output
outputs = model.generate(
inputs["input_ids"],
max_length=200,
num_return_sequences=1,
no_repeat_ngram_size=2,
do_sample=True,
temperature=0.7
)
# Decode response
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Try parsing as JSON first
try:
response_json = json.loads(response_text)
# Validate required fields in response
if not all(key in response_json for key in ["checklist", "tips", "quote"]):
raise ValueError("Missing required fields in model output")
except (json.JSONDecodeError, ValueError):
# If not JSON or invalid JSON, parse raw text
logger.warning("Model output is not valid JSON, parsing raw text")
response_json = parse_raw_text_to_json(response_text)
return jsonify(response_json)
except Exception as e:
logger.error(f"Error generating coaching response: {str(e)}")
return jsonify({"error": f"Internal server error: {str(e)}"}), 500
if __name__ == "__main__":
    # Start background tasks before the app runs
    start_background_tasks()
    # Serve the Flask app with Waitress, a production-ready WSGI server
    from waitress import serve
    logger.debug("Starting Flask app with Waitress")
    serve(app, host="0.0.0.0", port=7860)
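
# Example request once the server is up (hypothetical field values):
#   curl -X POST http://localhost:7860/generate_coaching \
#     -H "Content-Type: application/json" \
#     -d '{"role": "site supervisor", "project_id": "demo-001",
#          "milestones": "foundation poured", "reflection_log": "crew morale high"}'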