# AISupervisor / app.py
# Latest commit by geethareddy: "Update app.py" (95f7d07, verified)
import json
import logging
import os
import sys
import threading
import time

from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
# Send all log records to stdout only. logging.StreamHandler() with no
# argument writes to sys.stderr, so the stream is passed explicitly.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# Flask application object: the @app.route decorators below register their
# handlers on it, and the __main__ block hands it to the WSGI server.
app = Flask(__name__)
# Where the weights come from: a local checkpoint directory is preferred, and
# the pre-trained model name is used when that directory does not exist.
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"

# Shared state filled in by the background loader thread.
# model_load_status starts as "not_loaded" and becomes "local_model_loaded",
# "fallback_model_loaded" or "failed: <reason>".
model = tokenizer = None
model_load_status = "not_loaded"
def load_model_background():
    """Load the model and tokenizer and publish them through module globals.

    If ``model_path`` is an existing directory the checkpoint is loaded from
    it with ``local_files_only=True``; otherwise ``fallback_model`` is loaded
    by name. On success ``model``, ``tokenizer`` and ``model_load_status`` are
    updated together; on any failure ``model`` and ``tokenizer`` are left
    untouched and ``model_load_status`` becomes ``"failed: <message>"``.

    Intended to run on a background thread; it returns nothing.
    """
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info(f"Loading local model from {model_path}")
            source = model_path
            load_kwargs = {"local_files_only": True}
            loaded_status = "local_model_loaded"
        else:
            logger.info(f"Model directory not found: {model_path}. Using pre-trained model: {fallback_model}")
            source = fallback_model
            load_kwargs = {}
            loaded_status = "fallback_model_loaded"
        # Load into locals first so a failure half-way through never leaves
        # the globals with a model but no tokenizer.
        loaded_model = AutoModelForCausalLM.from_pretrained(source, **load_kwargs)
        loaded_tokenizer = AutoTokenizer.from_pretrained(source, **load_kwargs)
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() drops.
        logger.exception(f"Failed to load model or tokenizer: {e}")
        model_load_status = f"failed: {e}"
        return

    # Publish the objects before flipping the status: readers poll the status
    # and only then touch model/tokenizer.
    model, tokenizer = loaded_model, loaded_tokenizer
    model_load_status = loaded_status
    logger.info("Model and tokenizer loaded successfully")
def start_background_tasks():
    """Start ``load_model_background`` on a daemon thread and return at once."""
    logger.debug("Starting background tasks")
    # daemon=True: the loader thread does not keep the process alive on exit.
    threading.Thread(target=load_model_background, daemon=True).start()
def wait_for_model(timeout=60, poll_interval=1.0):
    """Block until the model has finished loading, failed, or time runs out.

    Args:
        timeout: Maximum number of seconds to wait. ``0`` performs a single
            non-blocking check.
        poll_interval: Seconds to sleep between checks of
            ``model_load_status``; the final sleep is shortened so the call
            does not overrun ``timeout``.

    Returns:
        True if ``model_load_status`` reports a loaded model, False if it
        reports a failure or the timeout elapsed first.
    """
    # monotonic() is immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while True:
        status = model_load_status
        if status in ("local_model_loaded", "fallback_model_loaded"):
            return True
        if "failed" in status:
            return False
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        time.sleep(max(0.0, min(poll_interval, remaining)))
def parse_raw_text_to_json(raw_text):
    """Parse free-form model output into a checklist/tips/quote dictionary.

    Recognised structure (headers are matched case-insensitively):

    * ``Checklist:`` and ``Tips:`` open a section; every following line that
      starts with ``"- "`` or ``"* "`` is added to that section.
    * ``Quote: <text>`` sets the quote and closes the current section.
    * Outside any section, ``"- "`` lines go to the checklist and ``"* "``
      lines go to the tips.
    * A non-bullet line seen before any item has been collected is kept as an
      inferred quote; it is used only when no explicit ``Quote:`` was given.

    Args:
        raw_text: Text to parse. ``None`` is treated as an empty string.

    Returns:
        A dict with keys ``"checklist"`` (list of str), ``"tips"`` (list of
        str) and ``"quote"`` (str). Missing parts are filled with built-in
        defaults so the result is always complete.
    """
    default_quote = "Every step forward counts!"
    items = {"checklist": [], "tips": []}
    # Which list a bullet belongs to when no section header is active.
    marker_default = {"- ": "checklist", "* ": "tips"}
    explicit_quote = None
    inferred_quote = None
    current_section = None

    for raw_line in (raw_text or "").splitlines():
        line = raw_line.strip()
        if not line:
            continue

        lowered = line.lower()
        if lowered.startswith("checklist:"):
            current_section = "checklist"
            continue
        if lowered.startswith("tips:"):
            current_section = "tips"
            continue
        if lowered.startswith("quote:"):
            current_section = None
            # An empty "Quote:" line keeps whatever quote was already found.
            explicit_quote = line[6:].strip() or explicit_quote
            continue

        marker = line[:2]
        if marker in marker_default:
            # An explicit header wins over the bullet style, so "Tips:"
            # followed by "- ..." lines still yields tips.
            target = current_section or marker_default[marker]
            items[target].append(line[2:].strip())
        elif not items["checklist"] and not items["tips"]:
            inferred_quote = line

    # Fallbacks keep the response shape stable when parsing finds nothing.
    checklist = items["checklist"] or ["Inspect safety equipment", "Review milestone progress"]
    tips = items["tips"] or ["Prioritize team communication", "Check weather updates"]
    return {
        "checklist": checklist,
        "tips": tips,
        "quote": explicit_quote or inferred_quote or default_quote,
    }
@app.route("/")
def root():
    """Return a fixed JSON banner confirming the service is running."""
    logger.debug("Root endpoint accessed")
    banner = {"message": "Supervisor AI Coach is running"}
    return jsonify(banner)
@app.route("/health")
def health_check():
    """Report service health together with the raw model load status.

    "status" is "healthy" once a local or fallback model has been loaded and
    "starting" in every other case. NOTE: a failed load is therefore also
    reported as "starting"; "model_load_status" carries the failure text.
    """
    logger.debug("Health endpoint accessed")
    if model_load_status in ("local_model_loaded", "fallback_model_loaded"):
        status = "healthy"
    else:
        status = "starting"
    return jsonify({"status": status, "model_load_status": model_load_status})
@app.route("/debug", methods=["POST"])
def debug():
    """Validate a coaching request and echo how it would be fed to the model.

    Expects a JSON object with the keys "role", "project_id", "milestones"
    and "reflection_log". Responds with 400 and a JSON error when the body is
    missing, is not a JSON object, or lacks a required key. Otherwise returns
    the current model load status, the formatted input text and whether both
    the model and the tokenizer are available. No generation is performed.
    """
    logger.debug("Debug endpoint accessed")
    # silent=True makes malformed bodies and wrong content types return None
    # instead of aborting, so the client always gets the JSON error below.
    data = request.get_json(silent=True)
    # A JSON list or string would pass the membership test further down and
    # then fail on data['role'], so only objects are accepted.
    if not data or not isinstance(data, dict):
        return jsonify({"error": "Invalid request: JSON data required"}), 400

    required_fields = ["role", "project_id", "milestones", "reflection_log"]
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400

    input_text = (
        f"Role: {data['role']}, Project: {data['project_id']}, "
        f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
    )
    return jsonify({
        "model_load_status": model_load_status,
        "input_text": input_text,
        "model_ready": model is not None and tokenizer is not None,
    })
def _build_coaching_prompt(data):
    """Format the request fields plus the output template into the prompt."""
    return (
        f"Role: {data['role']}, Project: {data['project_id']}, "
        f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}\n\n"
        "Generate a coaching response in the following format:\n"
        "Checklist:\n- <item1>\n- <item2>\n"
        "Tips:\n* <tip1>\n* <tip2>\n"
        "Quote: <motivational quote>"
    )


def _coerce_model_output(response_text):
    """Turn generated text into a dict with checklist, tips and quote keys."""
    try:
        response_json = json.loads(response_text)
        # json.loads may return a list, number or string; only a dict that
        # carries all three keys is accepted as-is.
        if not isinstance(response_json, dict) or not all(
            key in response_json for key in ["checklist", "tips", "quote"]
        ):
            raise ValueError("Missing required fields in model output")
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so this covers both
        # unparsable text and parsable-but-incomplete JSON.
        logger.warning("Model output is not valid JSON, parsing raw text")
        return parse_raw_text_to_json(response_text)
    return response_json


@app.route("/generate_coaching", methods=["POST"])
def generate_coaching():
    """Generate a coaching checklist, tips and quote for a supervisor.

    Expects a JSON object with the keys "role", "project_id", "milestones"
    and "reflection_log"; responds with 400 and a JSON error otherwise.

    Waits up to 60 seconds for the model to be loaded. If it is not available
    by then, a fixed fallback response is returned with status 200. Otherwise
    the model is prompted and its continuation is returned as JSON with the
    keys "checklist", "tips" and "quote". Any error during generation yields
    a 500 response with an "error" message.
    """
    logger.debug("Generate coaching endpoint accessed")

    # silent=True makes malformed bodies and wrong content types return None
    # instead of aborting, so the client always gets the JSON error below.
    data = request.get_json(silent=True)
    if not data or not isinstance(data, dict):
        logger.error("Invalid request: No JSON data provided")
        return jsonify({"error": "Invalid request: JSON data required"}), 400

    required_fields = ["role", "project_id", "milestones", "reflection_log"]
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        logger.error(f"Missing required fields: {missing_fields}")
        return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400

    # Wait for the model to load (up to 60 seconds); degrade to canned advice
    # rather than failing the request when it is not available.
    if not wait_for_model(timeout=60):
        logger.warning("Model failed to load within timeout")
        return jsonify({
            "checklist": ["Inspect safety equipment", "Review milestone progress"],
            "tips": ["Prioritize team communication", "Check weather updates"],
            "quote": "Every step forward counts!",
        })

    try:
        input_text = _build_coaching_prompt(data)
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        input_ids = inputs["input_ids"]

        generate_kwargs = {
            # max_new_tokens bounds only the continuation. The previous
            # max_length=200 also counted the prompt, which may itself be up
            # to 512 tokens long.
            "max_new_tokens": 200,
            "num_return_sequences": 1,
            "no_repeat_ngram_size": 2,
            "do_sample": True,
            "temperature": 0.7,
        }
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask
        # Some tokenizers (e.g. GPT-2 style) define no pad token; reuse the
        # end-of-sequence token for padding in that case.
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            generate_kwargs["pad_token_id"] = tokenizer.eos_token_id

        outputs = model.generate(input_ids, **generate_kwargs)

        # A causal LM returns prompt + continuation. Drop the prompt tokens so
        # the template placeholders are not parsed as if they were the answer.
        prompt_length = input_ids.shape[-1]
        response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        logger.debug(f"Raw model output: {response_text}")

        return jsonify(_coerce_model_output(response_text))
    except Exception as e:
        # Top-level boundary for this request: log with traceback, reply 500.
        logger.exception(f"Error generating coaching response: {e}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
if __name__ == "__main__":
    # Begin loading the model first so it proceeds while the server starts.
    start_background_tasks()

    # Imported here so that importing this module does not require waitress.
    import waitress

    logger.debug("Starting Flask app with Waitress")
    waitress.serve(app, host="0.0.0.0", port=7860)