File size: 6,656 Bytes
b181476
 
3db7383
875d1bc
 
6c9e43a
 
b181476
fdb1a6b
16dbf0f
0d34466
16dbf0f
 
fdb1a6b
16dbf0f
 
b181476
 
875d1bc
 
 
e7c1a90
 
 
 
 
 
3db7383
16dbf0f
3db7383
875d1bc
 
e7c1a90
 
 
 
 
 
 
 
16dbf0f
e7c1a90
 
 
 
 
 
 
 
875d1bc
 
 
 
 
 
e7c1a90
6c9e43a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
875d1bc
 
0d34466
875d1bc
0d34466
875d1bc
 
0d34466
875d1bc
16dbf0f
 
875d1bc
b181476
875d1bc
 
0d34466
6c9e43a
875d1bc
 
 
 
 
 
 
 
 
 
 
6c9e43a
 
 
 
0d34466
 
 
6c9e43a
e7c1a90
b181476
 
 
875d1bc
 
b181476
6c9e43a
b181476
 
6c9e43a
b181476
 
 
 
 
 
 
 
 
6c9e43a
 
b181476
6c9e43a
 
 
 
 
 
 
 
 
 
 
 
875d1bc
6c9e43a
b181476
 
875d1bc
 
 
 
 
 
 
 
6c9e43a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import json
import logging
import os
import sys
import threading
import time

from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer

# Send every log record to stdout only. StreamHandler() with no argument
# writes to sys.stderr, so the stream must be passed explicitly.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Flask application object; every route below is registered on it.
app = Flask(__name__)

# Model state shared with the background loader thread; both stay None
# until load_model_background() fills them in.
model = tokenizer = None
# Loader state: "not_loaded" until load_model_background() replaces it with
# "local_model_loaded", "fallback_model_loaded" or "failed: <error>".
model_load_status = "not_loaded"

# Fine-tuned checkpoint directory tried first, and the pretrained model name
# loaded instead when that directory is absent.
model_path, fallback_model = "/app/fine-tuned-construction-llm", "distilgpt2"

def load_model_background():
    """Load the causal LM and its tokenizer into the module-level globals.

    Uses the local checkpoint in ``model_path`` when that directory exists
    (local files only); otherwise loads ``fallback_model`` by name. The
    outcome is published through ``model_load_status``:

    * ``"local_model_loaded"`` or ``"fallback_model_loaded"`` on success;
    * ``"failed: <error message>"`` if loading raises.

    Exceptions are logged (with traceback) and not re-raised, so the function
    can serve directly as a thread target.
    """
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info("Loading local model from %s", model_path)
            source = model_path
            load_kwargs = {"local_files_only": True}
            success_status = "local_model_loaded"
        else:
            logger.info(
                "Model directory not found: %s. Using pre-trained model: %s",
                model_path,
                fallback_model,
            )
            source = fallback_model
            load_kwargs = {}
            success_status = "fallback_model_loaded"
        loaded_model = AutoModelForCausalLM.from_pretrained(source, **load_kwargs)
        loaded_tokenizer = AutoTokenizer.from_pretrained(source, **load_kwargs)
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone would lose.
        logger.exception("Failed to load model or tokenizer: %s", e)
        model_load_status = f"failed: {e}"
        return

    # Publish both objects before flipping the status. Anything that observes
    # a "loaded" status then also sees a usable model and tokenizer.
    model, tokenizer = loaded_model, loaded_tokenizer
    model_load_status = success_status
    logger.info("Model and tokenizer loaded successfully")

def start_background_tasks():
    """Launch load_model_background() on a daemon thread and return at once.

    Daemon mode means an unfinished load never keeps the process alive at
    interpreter shutdown.
    """
    logger.debug("Starting background tasks")
    threading.Thread(target=load_model_background, daemon=True).start()

def wait_for_model(timeout=60, poll_interval=1.0):
    """Block until the background model load finishes or *timeout* elapses.

    The status is always checked at least once, so ``timeout=0`` acts as a
    non-blocking "is the model ready?" probe. It is re-checked after the final
    sleep, so a load that completes right at the deadline is not reported as
    a failure.

    Args:
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        True if a local or fallback model is loaded; False if loading failed
        or had not finished within *timeout*.
    """
    # Monotonic clock: immune to wall-clock adjustments while waiting.
    deadline = time.monotonic() + timeout
    while True:
        # Snapshot once per iteration; the loader thread may update the global.
        status = model_load_status
        if status in ("local_model_loaded", "fallback_model_loaded"):
            return True
        if "failed" in status:
            return False
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline, so the total wait stays bounded.
        time.sleep(min(poll_interval, remaining))

def parse_raw_text_to_json(raw_text):
    """Turn free-form model text into the coaching response dict.

    Line conventions (after stripping surrounding whitespace):

    * ``"- item"`` lines become checklist entries;
    * ``"* item"`` lines become tips;
    * any other non-empty line seen before the first bullet becomes the
      quote. Later such lines overwrite earlier ones, so the last one wins.

    Missing sections are filled with fixed defaults.

    Returns:
        dict with keys ``"checklist"`` (list of str), ``"tips"``
        (list of str) and ``"quote"`` (str).
    """
    items, advice = [], []
    headline = "Every step forward counts!"

    for raw_line in raw_text.strip().split("\n"):
        text = raw_line.strip()
        if text.startswith("- "):
            items.append(text[2:].strip())
            continue
        if text.startswith("* "):
            advice.append(text[2:].strip())
            continue
        # Free text only counts as the quote until the first bullet appears.
        if text and not (items or advice):
            headline = text

    return {
        "checklist": items or ["Inspect safety equipment", "Review milestone progress"],
        "tips": advice or ["Prioritize team communication", "Check weather updates"],
        "quote": headline,
    }

@app.route("/")
def root():
    """Report that the web server is up; model loading state is not consulted."""
    logger.debug("Root endpoint accessed")
    payload = {"message": "Supervisor AI Coach is running"}
    return jsonify(payload)

@app.route("/health")
def health_check():
    """Report whether the coaching model is ready.

    Response JSON:
        status: ``"healthy"`` once a local or fallback model is loaded,
            ``"degraded"`` if loading failed (a terminal state, so it is no
            longer reported as ``"starting"``), otherwise ``"starting"``.
        model_load_status: the raw loader state string.
    """
    logger.debug("Health endpoint accessed")
    # Read the global once so both fields describe the same moment; the
    # loader thread may update it concurrently.
    current = model_load_status
    if current in ("local_model_loaded", "fallback_model_loaded"):
        status = "healthy"
    elif "failed" in current:
        status = "degraded"
    else:
        status = "starting"
    return jsonify({
        "status": status,
        "model_load_status": current,
    })

def _generate_model_text(input_text):
    # Run the loaded model on input_text and return only the newly generated text.
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    prompt_length = inputs["input_ids"].shape[-1]

    # GPT-2-family tokenizers (e.g. the distilgpt2 fallback) have no pad
    # token; use EOS so generate() does not have to guess.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        # max_new_tokens bounds only the continuation. max_length would also
        # count the (up to 512-token) prompt and could leave no room to generate.
        max_new_tokens=200,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        temperature=0.7,
        pad_token_id=pad_token_id,
    )
    # Causal LMs return prompt + continuation. Drop the echoed prompt so it is
    # neither fed to json.loads nor picked up as the quote.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def _coerce_coaching_response(response_text):
    # Use the model output as-is when it is a JSON object with all expected
    # keys; otherwise fall back to the heuristic text parser.
    try:
        candidate = json.loads(response_text)
    except json.JSONDecodeError:
        candidate = None
    # json.loads can also yield a list, str or number. Only a dict is usable,
    # and membership tests on those other types are wrong or raise.
    if isinstance(candidate, dict) and all(
        key in candidate for key in ("checklist", "tips", "quote")
    ):
        return candidate
    logger.warning("Model output is not valid JSON, parsing raw text")
    return parse_raw_text_to_json(response_text)


@app.route("/generate_coaching", methods=["POST"])
def generate_coaching():
    """Generate a coaching checklist, tips and quote for a supervisor.

    Expects a JSON object body with ``role``, ``project_id``, ``milestones``
    and ``reflection_log``. Returns JSON with ``checklist``, ``tips`` and
    ``quote`` keys.

    Responses:
        400: body is missing, not JSON, not a JSON object, or lacks fields.
        200 with fixed defaults: the model was not ready within 60 seconds.
        500: tokenization or generation raised.
    """
    logger.debug("Generate coaching endpoint accessed")
    # silent=True: malformed JSON or a non-JSON Content-Type yields None, so
    # the client gets this endpoint's JSON 400 instead of Flask's HTML error.
    data = request.get_json(silent=True)
    if not data:
        logger.error("Invalid request: No JSON data provided")
        return jsonify({"error": "Invalid request: JSON data required"}), 400
    if not isinstance(data, dict):
        # e.g. a JSON array: the field lookups below would raise TypeError.
        logger.error("Invalid request: JSON body is not an object")
        return jsonify({"error": "Invalid request: JSON object required"}), 400

    required_fields = ["role", "project_id", "milestones", "reflection_log"]
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        logger.error("Missing required fields: %s", missing_fields)
        return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400

    # Wait for the model to load (up to 60 seconds)
    if not wait_for_model(timeout=60):
        logger.warning("Model failed to load within timeout")
        return jsonify({
            "checklist": ["Inspect safety equipment", "Review milestone progress"],
            "tips": ["Prioritize team communication", "Check weather updates"],
            "quote": "Every step forward counts!"
        })

    input_text = (
        f"Role: {data['role']}, Project: {data['project_id']}, "
        f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
    )
    try:
        response_text = _generate_model_text(input_text)
        response_json = _coerce_coaching_response(response_text)
    except Exception as e:
        logger.exception("Error generating coaching response: %s", e)
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
    return jsonify(response_json)

if __name__ == "__main__":
    # Start model loading first so it overlaps with server start-up.
    start_background_tasks()
    # Waitress (production-ready WSGI server) is imported here, so it is only
    # required when this file is run as a script.
    from waitress import serve
    # PORT lets the hosting environment choose the listening port; 7860 stays
    # the default. A non-integer value fails fast with ValueError at start-up.
    port = int(os.environ.get("PORT", "7860"))
    logger.debug("Starting Flask app with Waitress")
    serve(app, host="0.0.0.0", port=port)