File size: 5,799 Bytes
875d1bc
b181476
 
 
3db7383
875d1bc
 
b181476
fdb1a6b
16dbf0f
0d34466
16dbf0f
 
fdb1a6b
16dbf0f
 
b181476
 
875d1bc
 
 
e7c1a90
 
 
 
 
 
3db7383
16dbf0f
3db7383
875d1bc
 
e7c1a90
 
 
 
 
 
 
 
16dbf0f
e7c1a90
 
 
 
 
 
 
 
875d1bc
 
 
 
 
 
e7c1a90
875d1bc
 
0d34466
875d1bc
0d34466
875d1bc
 
0d34466
875d1bc
16dbf0f
 
875d1bc
b181476
875d1bc
 
0d34466
875d1bc
 
 
 
 
 
 
 
 
 
 
 
e7c1a90
0d34466
 
 
 
 
 
 
875d1bc
e7c1a90
b181476
 
 
875d1bc
 
b181476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7c1a90
92b443e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b181476
875d1bc
b181476
 
 
875d1bc
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import json
import logging
import os
import sys
import threading
import time

from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer

# Log everything (DEBUG and up) to stdout only. StreamHandler() with no
# argument writes to sys.stderr, so sys.stdout is passed explicitly.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Flask application; the HTTP routes are registered on it with @app.route.
app = Flask(__name__)

# Model state, filled in asynchronously by load_model_background().
# Request handlers treat model/tokenizer being None as "not loaded yet".
model = None
tokenizer = None
# One of "not_loaded", "local_model_loaded", "fallback_model_loaded",
# or "failed: <error message>".
model_load_status = "not_loaded"

# Local fine-tuned checkpoint directory, and the hub model used when it is absent.
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"

# Function to load model in the background
def load_model_background():
    """Load the causal-LM model and tokenizer into the module globals.

    Prefers the local checkpoint directory ``model_path`` (loaded with
    ``local_files_only=True``). If that directory does not exist, it loads the
    pre-trained model named by ``fallback_model`` instead.

    Side effects:
        On success, sets the globals ``tokenizer`` and ``model`` and sets
        ``model_load_status`` to ``"local_model_loaded"`` or
        ``"fallback_model_loaded"``. On failure, leaves ``model``/``tokenizer``
        untouched and sets ``model_load_status`` to ``"failed: <error>"``.
        Exceptions are logged, never raised.
    """
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info("Loading local model from %s", model_path)
            new_model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
            new_tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
            new_status = "local_model_loaded"
        else:
            logger.info(
                "Model directory not found: %s. Using pre-trained model: %s",
                model_path,
                fallback_model,
            )
            new_model = AutoModelForCausalLM.from_pretrained(fallback_model)
            new_tokenizer = AutoTokenizer.from_pretrained(fallback_model)
            new_status = "fallback_model_loaded"
    except Exception as e:
        # This runs as a thread target, so a raised exception would be lost;
        # record it (with traceback) for the /health endpoint instead.
        logger.exception("Failed to load model or tokenizer: %s", e)
        model_load_status = f"failed: {str(e)}"
        return

    # Publish only after BOTH loaded, so a tokenizer failure can never leave
    # a half-initialised model in the globals.
    tokenizer = new_tokenizer
    model = new_model
    model_load_status = new_status
    logger.info("Model and tokenizer loaded successfully")

# Start model loading in a background thread
def start_background_tasks():
    """Run ``load_model_background`` on a daemon thread and return immediately.

    The thread is marked as a daemon, so it does not keep the process alive
    at interpreter shutdown.
    """
    logger.debug("Starting background tasks")
    loader = threading.Thread(target=load_model_background, daemon=True)
    loader.start()

@app.route("/")
def root():
    """Return a static JSON message confirming the server is running."""
    logger.debug("Root endpoint accessed")
    payload = {"message": "Supervisor AI Coach is running"}
    return jsonify(payload)

@app.route("/health")
def health_check():
    """Report service health and model-loading progress.

    Returns JSON with:
        status: "healthy" once a model is loaded, "unhealthy" if loading
            failed, otherwise "starting".
        model_load_status: the raw status string set by load_model_background.

    The HTTP code stays 200 in every case, because the service still answers
    /generate_coaching (with default content) while no model is loaded.
    """
    logger.debug("Health endpoint accessed")
    # Read the global once so both fields in the reply describe the same state.
    load_status = model_load_status
    if load_status in ("local_model_loaded", "fallback_model_loaded"):
        status = "healthy"
    elif load_status.startswith("failed"):
        # Loading is never retried, so "starting" would hide a permanent failure.
        status = "unhealthy"
    else:
        status = "starting"
    return jsonify({
        "status": status,
        "model_load_status": load_status,
    })

def _default_coaching_response():
    """Return a fresh copy of the canned coaching payload.

    Used whenever the model is unavailable or its output is not usable JSON.
    A new dict is built on each call so callers never share mutable state.
    """
    return {
        "checklist": ["Inspect safety equipment", "Review milestone progress"],
        "tips": ["Prioritize team communication", "Check weather updates"],
        "quote": "Every step forward counts!",
    }


def _parse_model_output(response_text):
    """Parse generated text as a leading JSON object, else return the defaults.

    Args:
        response_text: Decoded model continuation, already whitespace-stripped.

    Returns:
        The JSON object at the start of ``response_text`` if there is one,
        otherwise the default coaching payload.
    """
    # distilgpt2 (the fallback) is not trained to emit JSON, so this path is common.
    if not response_text.startswith("{"):
        logger.warning("Model output is not JSON, using default response")
        return _default_coaching_response()
    try:
        # raw_decode tolerates trailing text after the object: sampling keeps
        # producing tokens after the closing brace until the budget runs out.
        parsed, _ = json.JSONDecoder().raw_decode(response_text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse model output as JSON, using default response")
        return _default_coaching_response()
    return parsed


@app.route("/generate_coaching", methods=["POST"])
def generate_coaching():
    """Generate a coaching checklist, tips and quote for a supervisor.

    Expects a JSON object body containing ``role``, ``project_id``,
    ``milestones`` and ``reflection_log``.

    Returns:
        200 with the model's JSON object if its generated continuation starts
        with one; otherwise (or while the model is not loaded) 200 with the
        default payload.
        400 with ``{"error": ...}`` if the body is not a non-empty JSON object
        or a required field is missing.
        500 with ``{"error": ...}`` if tokenization or generation fails.
    """
    logger.debug("Generate coaching endpoint accessed")
    # silent=True returns None instead of raising Flask's HTML 400/415 error
    # on a missing/malformed body or non-JSON content type, so clients always
    # get the JSON error below.
    data = request.get_json(silent=True)
    # A JSON list/string body would pass the "in" checks below and then crash
    # on data['role'], so require an object explicitly.
    if not isinstance(data, dict) or not data:
        logger.error("Invalid request: No JSON data provided")
        return jsonify({"error": "Invalid request: JSON data required"}), 400

    required_fields = ["role", "project_id", "milestones", "reflection_log"]
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        logger.error(f"Missing required fields: {missing_fields}")
        return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400

    # Snapshot the globals, which the loader thread fills in.
    current_model, current_tokenizer = model, tokenizer
    if current_model is None or current_tokenizer is None:
        logger.warning("Model or tokenizer not loaded")
        return jsonify(_default_coaching_response())

    try:
        input_text = (
            f"Role: {data['role']}, Project: {data['project_id']}, "
            f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
        )
        inputs = current_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        prompt_length = inputs["input_ids"].shape[-1]

        pad_token_id = current_tokenizer.pad_token_id
        if pad_token_id is None:
            # GPT-2 style tokenizers define no pad token; reuse EOS explicitly.
            pad_token_id = current_tokenizer.eos_token_id

        outputs = current_model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # max_length would count the (up to 512-token) prompt as well and
            # leave no room to generate; bound only the new tokens instead.
            max_new_tokens=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=pad_token_id,
        )

        # A causal LM returns prompt + continuation. Decode only the
        # continuation: otherwise the text always starts with "Role: ..."
        # and can never be parsed as JSON.
        response_text = current_tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
    except Exception as e:
        # Log the full traceback server-side; don't echo internals to clients.
        logger.exception("Error generating coaching response: %s", e)
        return jsonify({"error": "Internal server error"}), 500

    return jsonify(_parse_model_output(response_text))

if __name__ == "__main__":
    # Begin loading the model first so the server can answer requests
    # (with default content) while the weights are still being loaded.
    start_background_tasks()

    # Serve through Waitress rather than Flask's built-in development server.
    from waitress import serve as waitress_serve

    logger.debug("Starting Flask app with Waitress")
    waitress_serve(app, port=7860, host="0.0.0.0")