"""Supervisor AI Coach: a Flask service that turns project context into coaching advice.

A causal language model is loaded in a background thread. It uses the local
fine-tuned checkpoint at ``model_path`` when that directory exists, and
otherwise ``fallback_model``. ``POST /generate_coaching`` prompts the model and
normalises its free-form output into
``{"checklist": [...], "tips": [...], "quote": "..."}``. It falls back to
canned advice when the model is unavailable or its output yields no items.
"""

import json
import logging
import os
import sys
import threading
import time

from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging to stdout only. StreamHandler() with no argument writes to
# stderr, so the stream is passed explicitly.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)

# Model state: written by the loader thread, read by request handlers.
model = None
tokenizer = None
model_load_status = "not_loaded"

# Define model path and fallback
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"

# Statuses meaning `model` and `tokenizer` are ready for use.
_LOADED_STATUSES = ("local_model_loaded", "fallback_model_loaded")
# Fields every coaching/debug request body must contain.
_REQUIRED_FIELDS = ("role", "project_id", "milestones", "reflection_log")

# Canned advice used when the model is unavailable or its output has no items.
_DEFAULT_CHECKLIST = ("Inspect safety equipment", "Review milestone progress")
_DEFAULT_TIPS = ("Prioritize team communication", "Check weather updates")
_DEFAULT_QUOTE = "Every step forward counts!"

# Token budgets: the prompt fed to the model, and the continuation it may add.
_MAX_PROMPT_TOKENS = 512
_MAX_NEW_TOKENS = 200

# Format instructions appended to every prompt. The prompt ends with "Quote: ",
# so the model continues from that cue.
_PROMPT_INSTRUCTIONS = (
    "Generate a coaching response in the following format:\n"
    "Checklist:\n- \n- \n"
    "Tips:\n* \n* \n"
    "Quote: "
)

# Set when load_model_background finishes, whether it succeeded or failed.
_model_load_done = threading.Event()
# Guards _loader_thread so the loader thread is started at most once.
_loader_lock = threading.Lock()
_loader_thread = None


def load_model_background():
    """Load the model and tokenizer, publishing them and a status string.

    Prefers the local checkpoint at ``model_path`` (local files only); when
    that directory does not exist, loads ``fallback_model`` instead. On any
    exception, ``model_load_status`` becomes ``"failed: <message>"``.
    ``_model_load_done`` is always set on exit so waiters wake up promptly.
    """
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info("Loading local model from %s", model_path)
            loaded_model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
            loaded_tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
            status = "local_model_loaded"
        else:
            logger.info(
                "Model directory not found: %s. Using pre-trained model: %s",
                model_path,
                fallback_model,
            )
            loaded_model = AutoModelForCausalLM.from_pretrained(fallback_model)
            loaded_tokenizer = AutoTokenizer.from_pretrained(fallback_model)
            status = "fallback_model_loaded"
        # Publish model and tokenizer before the status, so a handler that
        # observes a "loaded" status never finds either of them still None.
        model = loaded_model
        tokenizer = loaded_tokenizer
        model_load_status = status
        logger.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logger.exception("Failed to load model or tokenizer: %s", e)
        model_load_status = f"failed: {str(e)}"
    finally:
        _model_load_done.set()


def start_background_tasks():
    """Start the model-loading daemon thread; later calls are no-ops.

    Being idempotent lets request handlers call this lazily. That covers the
    case where the module is imported by a WSGI server rather than run as
    ``__main__``.
    """
    global _loader_thread
    with _loader_lock:
        if _loader_thread is not None:
            return
        logger.debug("Starting background tasks")
        _loader_thread = threading.Thread(
            target=load_model_background, name="model-loader", daemon=True
        )
        _loader_thread.start()


def wait_for_model(timeout=60):
    """Block up to ``timeout`` seconds for loading to finish.

    Returns True only if the model and tokenizer loaded successfully. Returns
    False on load failure or timeout.
    """
    _model_load_done.wait(timeout)
    return model_load_status in _LOADED_STATUSES


def parse_raw_text_to_json(raw_text):
    """Parse free-form model text into ``{"checklist", "tips", "quote"}``.

    Recognised structure:
      - A "Checklist:" header is followed by "- item" lines.
      - A "Tips:" header is followed by "* item" lines.
      - A "Quote: text" line sets the quote.
    Headers are matched case-insensitively.

    Until any item has been collected, bullets outside a section are still
    picked up. In that phase, any other non-empty line becomes the quote.

    Empty checklist/tips fall back to the canned defaults. ``None`` or empty
    text yields the defaults.
    """
    checklist = []
    tips = []
    quote = _DEFAULT_QUOTE
    section = None  # "checklist", "tips", or None when outside any section

    for raw_line in (raw_text or "").splitlines():
        line = raw_line.strip()
        if not line:
            continue
        lowered = line.lower()
        if lowered.startswith("checklist:"):
            section = "checklist"
            continue
        if lowered.startswith("tips:"):
            section = "tips"
            continue
        if lowered.startswith("quote:"):
            section = None
            quote = line[len("quote:"):].strip() or quote
            continue

        if section == "checklist" and line.startswith("- "):
            checklist.append(line[2:].strip())
        elif section == "tips" and line.startswith("* "):
            tips.append(line[2:].strip())
        elif not checklist and not tips:
            # No items collected yet: infer structure from the bullet marker.
            if line.startswith("- "):
                checklist.append(line[2:].strip())
            elif line.startswith("* "):
                tips.append(line[2:].strip())
            else:
                quote = line

    return {
        "checklist": checklist or list(_DEFAULT_CHECKLIST),
        "tips": tips or list(_DEFAULT_TIPS),
        "quote": quote,
    }


def _default_coaching():
    """Return a fresh copy of the canned coaching response."""
    return {
        "checklist": list(_DEFAULT_CHECKLIST),
        "tips": list(_DEFAULT_TIPS),
        "quote": _DEFAULT_QUOTE,
    }


def _validate_payload(data):
    """Return an error message for an invalid request body, or None if valid.

    ``data`` is whatever ``request.get_json(silent=True)`` produced. That may
    be None, or a non-object JSON value such as a list or a string.
    """
    if not data or not isinstance(data, dict):
        return "Invalid request: JSON data required"
    missing_fields = [field for field in _REQUIRED_FIELDS if field not in data]
    if missing_fields:
        return f"Missing required fields: {missing_fields}"
    return None


def _format_context(data):
    """Render the request fields into the one-line context used in the prompt."""
    return (
        f"Role: {data['role']}, Project: {data['project_id']}, "
        f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
    )


def _encode_prompt(data):
    """Tokenize the coaching prompt for ``data`` within ``_MAX_PROMPT_TOKENS``.

    When the request context is too long, the context is trimmed instead of
    the trailing format instructions. Plain right-truncation of the whole
    prompt would silently drop the instructions and the "Quote: " cue for
    long reflection logs.
    """
    context = _format_context(data)
    suffix = f"\n\n{_PROMPT_INSTRUCTIONS}"
    suffix_tokens = len(tokenizer(suffix, add_special_tokens=False)["input_ids"])
    budget = max(
        _MAX_PROMPT_TOKENS - suffix_tokens - tokenizer.num_special_tokens_to_add(), 0
    )
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) > budget:
        logger.warning(
            "Request context is %d tokens; trimming to %d to keep the format instructions",
            len(context_ids),
            budget,
        )
        context = tokenizer.decode(context_ids[:budget], skip_special_tokens=True)
    # The final truncation is only a safety net for small re-tokenization drift.
    return tokenizer(
        context + suffix,
        return_tensors="pt",
        max_length=_MAX_PROMPT_TOKENS,
        truncation=True,
    )


@app.route("/")
def root():
    """Liveness endpoint: confirms the web server is up."""
    logger.debug("Root endpoint accessed")
    return jsonify({"message": "Supervisor AI Coach is running"})


@app.route("/health")
def health_check():
    """Report "healthy" once a model is loaded, otherwise "starting", plus the raw status."""
    logger.debug("Health endpoint accessed")
    return jsonify({
        "status": "healthy" if model_load_status in _LOADED_STATUSES else "starting",
        "model_load_status": model_load_status,
    })


@app.route("/debug", methods=["POST"])
def debug():
    """Validate a coaching request and echo the context text the prompt would start with."""
    logger.debug("Debug endpoint accessed")
    data = request.get_json(silent=True)
    error = _validate_payload(data)
    if error:
        return jsonify({"error": error}), 400
    return jsonify({
        "model_load_status": model_load_status,
        "input_text": _format_context(data),
        "model_ready": model is not None and tokenizer is not None,
    })


@app.route("/generate_coaching", methods=["POST"])
def generate_coaching():
    """Generate structured coaching advice for the posted project context.

    Responses:
      - 400 for a missing or non-object JSON body, or missing required fields.
      - Canned advice if the model is not ready within 60 s.
      - 500 if generation itself raises.
    """
    logger.debug("Generate coaching endpoint accessed")
    data = request.get_json(silent=True)
    error = _validate_payload(data)
    if error:
        logger.error(error)
        return jsonify({"error": error}), 400

    # Ensure loading has begun even when __main__ did not run (no-op otherwise).
    start_background_tasks()
    if not wait_for_model(timeout=60):
        logger.warning("Model failed to load within timeout")
        return jsonify(_default_coaching())

    # NOTE(review): waitress serves requests on several threads and they share
    # one model/tokenizer. HF fast tokenizers are reported to raise
    # "Already borrowed" under concurrent use. Consider a lock if that is seen.
    try:
        inputs = _encode_prompt(data)
        prompt_length = inputs["input_ids"].shape[-1]
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            # GPT-2 style tokenizers define no pad token; EOS is the usual stand-in.
            pad_token_id = tokenizer.eos_token_id

        started = time.monotonic()
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # max_new_tokens, unlike max_length, does not count the prompt.
            max_new_tokens=_MAX_NEW_TOKENS,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=pad_token_id,
        )
        # Causal LMs return prompt + continuation. Decode only the new tokens,
        # so the prompt (including the user's own text) is never parsed as advice.
        response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        logger.debug(
            "Raw model output (%.2fs): %s", time.monotonic() - started, response_text
        )

        # Try parsing as JSON first; a usable answer must be an object with all keys.
        try:
            response_json = json.loads(response_text)
            if not isinstance(response_json, dict) or not all(
                key in response_json for key in ("checklist", "tips", "quote")
            ):
                raise ValueError("Missing required fields in model output")
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            logger.warning("Model output is not valid JSON, parsing raw text")
            response_json = parse_raw_text_to_json(response_text)

        return jsonify(response_json)
    except Exception as e:
        logger.exception("Error generating coaching response: %s", e)
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500


if __name__ == "__main__":
    # Start background tasks before the app runs
    start_background_tasks()

    # Run Flask app with waitress for production-ready WSGI server
    from waitress import serve

    logger.debug("Starting Flask app with Waitress")
    serve(app, host="0.0.0.0", port=7860)