import json
import logging
import os
import sys
import threading
import time
from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging to stdout only
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout)  # StreamHandler defaults to stderr; pass stdout explicitly
    ]
)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)

# Global variables for model and tokenizer
model = None
tokenizer = None
model_load_status = "not_loaded"

# Define model path and fallback
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"

# Function to load model in the background
def load_model_background():
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info(f"Loading local model from {model_path}")
            model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
            tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
            model_load_status = "local_model_loaded"
        else:
            logger.info(f"Model directory not found: {model_path}. Using pre-trained model: {fallback_model}")
            model = AutoModelForCausalLM.from_pretrained(fallback_model)
            tokenizer = AutoTokenizer.from_pretrained(fallback_model)
            model_load_status = "fallback_model_loaded"
        logger.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {str(e)}")
        model_load_status = f"failed: {str(e)}"

# Start model loading in a background thread
def start_background_tasks():
    logger.debug("Starting background tasks")
    thread = threading.Thread(target=load_model_background)
    thread.daemon = True
    thread.start()
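
# Note: start_background_tasks() is only invoked from the __main__ block below;
# a WSGI server that imports `app` directly would need to call it itself,
# otherwise the model never loads and every request falls back.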

# Utility function to wait for model loading with a timeout
def wait_for_model(timeout=60):
    start_time = time.time()
    while time.time() - start_time < timeout:
        if model_load_status in ["local_model_loaded", "fallback_model_loaded"]:
            return True
        elif "failed" in model_load_status:
            return False
        time.sleep(1)
    return False
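
# Note: wait_for_model polls once per second, so the first request after a cold
# start can block for up to `timeout` seconds before the fallback response is sent.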

# Utility function to parse raw text into structured JSON response
def parse_raw_text_to_json(raw_text):
    lines = raw_text.strip().split("\n")
    checklist = []
    tips = []
    quote = "Every step forward counts!"

    checklist_section = False
    tips_section = False

    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.lower().startswith("checklist:"):
            checklist_section = True
            tips_section = False
            continue
        elif line.lower().startswith("tips:"):
            checklist_section = False
            tips_section = True
            continue
        elif line.lower().startswith("quote:"):
            checklist_section = False
            tips_section = False
            quote = line[6:].strip() or quote
            continue

        if checklist_section and line.startswith("- "):
            checklist.append(line[2:].strip())
        elif tips_section and line.startswith("* "):
            tips.append(line[2:].strip())
        elif not checklist and not tips and line:
            # If no sections are defined, try to infer structure
            if line.startswith("- "):
                checklist.append(line[2:].strip())
            elif line.startswith("* "):
                tips.append(line[2:].strip())
            else:
                quote = line

    # Fallback if parsing fails
    if not checklist:
        checklist = ["Inspect safety equipment", "Review milestone progress"]
    if not tips:
        tips = ["Prioritize team communication", "Check weather updates"]

    return {
        "checklist": checklist,
        "tips": tips,
        "quote": quote
    }
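
# Illustrative round-trip of the parser, assuming the model follows the
# prompt's requested format (a minimal sketch, not part of the request flow):
#   parse_raw_text_to_json("Checklist:\n- Wear PPE\nTips:\n* Hydrate\nQuote: Stay safe")
#   returns {"checklist": ["Wear PPE"], "tips": ["Hydrate"], "quote": "Stay safe"}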

@app.route("/")
def root():
    logger.debug("Root endpoint accessed")
    return jsonify({"message": "Supervisor AI Coach is running"})

@app.route("/health")
def health_check():
    logger.debug("Health endpoint accessed")
    return jsonify({
        "status": "healthy" if model_load_status in ["local_model_loaded", "fallback_model_loaded"] else "starting",
        "model_load_status": model_load_status
    })

@app.route("/debug", methods=["POST"])
def debug():
    logger.debug("Debug endpoint accessed")
    data = request.get_json()
    if not data:
        return jsonify({"error": "Invalid request: JSON data required"}), 400

    required_fields = ["role", "project_id", "milestones", "reflection_log"]
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400

    input_text = (
        f"Role: {data['role']}, Project: {data['project_id']}, "
        f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
    )

    return jsonify({
        "model_load_status": model_load_status,
        "input_text": input_text,
        "model_ready": model is not None and tokenizer is not None
    })

@app.route("/generate_coaching", methods=["POST"])
def generate_coaching():
    logger.debug("Generate coaching endpoint accessed")
    # Manual validation of request data
    data = request.get_json()
    if not data:
        logger.error("Invalid request: No JSON data provided")
        return jsonify({"error": "Invalid request: JSON data required"}), 400

    required_fields = ["role", "project_id", "milestones", "reflection_log"]
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        logger.error(f"Missing required fields: {missing_fields}")
        return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400

    # Wait for the model to load (up to 60 seconds)
    if not wait_for_model(timeout=60):
        logger.warning("Model failed to load within timeout")
        return jsonify({
            "checklist": ["Inspect safety equipment", "Review milestone progress"],
            "tips": ["Prioritize team communication", "Check weather updates"],
            "quote": "Every step forward counts!"
        })

    try:
        # Prepare input text with a structured prompt to encourage formatted output
        input_text = (
            f"Role: {data['role']}, Project: {data['project_id']}, "
            f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}\n\n"
            "Generate a coaching response in the following format:\n"
            "Checklist:\n- <item1>\n- <item2>\n"
            "Tips:\n* <tip1>\n* <tip2>\n"
            "Quote: <motivational quote>"
        )
       
        # Tokenize input
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
       
        # Generate output; max_new_tokens bounds only the completion, so a long
        # prompt cannot exhaust the budget the way max_length=200 could when the
        # (up to 512-token) input already exceeds it
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=150,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id  # silences the missing-pad-token warning on GPT-2-style models
        )
       
        # Decode only the newly generated tokens; a causal LM's output sequence
        # begins with the prompt, which would otherwise be fed to the parser too
        prompt_length = inputs["input_ids"].shape[1]
        response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        logger.debug(f"Raw model output: {response_text}")
       
        # Try parsing as JSON first
        try:
            response_json = json.loads(response_text)
            # Validate required fields in response
            if not all(key in response_json for key in ["checklist", "tips", "quote"]):
                raise ValueError("Missing required fields in model output")
        except (json.JSONDecodeError, ValueError):
            # If not JSON or invalid JSON, parse raw text
            logger.warning("Model output is not valid JSON, parsing raw text")
            response_json = parse_raw_text_to_json(response_text)
       
        return jsonify(response_json)
   
    except Exception as e:
        logger.error(f"Error generating coaching response: {str(e)}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500

if __name__ == "__main__":
    # Start background tasks before the app runs
    start_background_tasks()
    # Serve the Flask app with Waitress, a production-ready WSGI server
    from waitress import serve
    logger.debug("Starting Flask app with Waitress")
    serve(app, host="0.0.0.0", port=7860)
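
# Example usage (hypothetical payload; field values are illustrative only):
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/generate_coaching \
#     -H "Content-Type: application/json" \
#     -d '{"role": "site supervisor", "project_id": "P-102",
#          "milestones": "foundation poured", "reflection_log": "crew ahead of schedule"}'
# Expected response shape (exact values vary with model output):
#   {"checklist": ["..."], "tips": ["..."], "quote": "..."}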