File size: 4,297 Bytes
b181476
 
 
 
 
3db7383
b181476
 
 
 
 
e7c1a90
b181476
 
 
 
 
 
 
 
 
e7c1a90
 
 
 
 
 
3db7383
e7c1a90
3db7383
e7c1a90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b181476
 
 
e7c1a90
 
 
 
b181476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7c1a90
92b443e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b181476
 
 
 
 
e7c1a90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import asyncio
import json
import logging
import os

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging: configure the root logger at INFO so the informational
# messages emitted while loading the model (logger.info calls) are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# Initialize FastAPI app; the route handlers below register on it via decorators.
app = FastAPI()

# Define input model for validation: the JSON request body of
# POST /generate_coaching. Every field has no default, so pydantic treats all
# four as required. All values are interpolated verbatim into the LLM prompt.
class CoachingInput(BaseModel):
    role: str  # free-text role of the requester (presumably a job role — confirm)
    project_id: str  # project identifier; only used as prompt text in this file
    milestones: str  # free-text milestone description
    reflection_log: str  # free-text reflection entered by the user

# Global variables for model and tokenizer; populated by load_model().
# Both stay None if loading fails, and the endpoint answers 503 in that case.
model = None
tokenizer = None
# "not_loaded" until load_model() runs, then one of "local_model_loaded",
# "fallback_model_loaded" or "failed: <error>". Reported by /health.
model_load_status = "not_loaded"

# Define model path and fallback. Each can be overridden via an environment
# variable so a deployment can point at a different checkpoint without a code
# change; the defaults are the original hard-coded values.
model_path = os.environ.get("MODEL_PATH", "/app/fine-tuned-construction-llm")
fallback_model = os.environ.get("FALLBACK_MODEL", "distilgpt2")  # Smaller model for faster loading

# Load model and tokenizer at startup
def load_model():
    """Load the causal-LM model and tokenizer into the module globals.

    Uses the local checkpoint at ``model_path`` when that directory exists,
    otherwise loads ``fallback_model`` by name. Never raises on load errors:
    any ``Exception`` is logged with its traceback and recorded in
    ``model_load_status`` so the app can still start and report the problem.

    Side effects:
        Sets the globals ``model``, ``tokenizer`` and ``model_load_status``.
        ``model`` and ``tokenizer`` are only assigned after *both* have loaded,
        so a tokenizer failure never leaves a stray model in the globals while
        the status says "failed".
    """
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info(f"Loading local model from {model_path}")
            source = model_path
            load_kwargs = {"local_files_only": True}
            status = "local_model_loaded"
        else:
            logger.warning(f"Model directory not found: {model_path}. Falling back to pre-trained model: {fallback_model}")
            source = fallback_model
            load_kwargs = {}
            status = "fallback_model_loaded"
        # Load into locals first; globals are only published on full success.
        loaded_model = AutoModelForCausalLM.from_pretrained(source, **load_kwargs)
        loaded_tokenizer = AutoTokenizer.from_pretrained(source, **load_kwargs)
    except Exception as e:
        logger.exception(f"Failed to load model or tokenizer: {str(e)}")
        model_load_status = f"failed: {str(e)}"
        # Do not raise an exception; allow the app to start
        return

    model, tokenizer = loaded_model, loaded_tokenizer
    model_load_status = status
    logger.info("Model and tokenizer loaded successfully")

# Load model on startup. This call runs at import time, not inside the FastAPI
# startup hook, so importing this module blocks until loading finishes.
# load_model() catches its own load errors, so a failed load does not abort
# the import; the outcome is recorded in model_load_status.
load_model()

@app.on_event("startup")
async def startup_event():
    """Log that the FastAPI application has finished starting."""
    # NOTE(review): FastAPI marks on_event as deprecated in favour of lifespan
    # handlers; migrating requires changing the FastAPI(...) constructor too.
    startup_message = "FastAPI application started"
    logger.info(startup_message)

@app.get("/health")
async def health_check():
    """Report that the service is up, plus the outcome of model loading.

    The status field is always "healthy"; inspect model_load_status to tell
    whether text generation is actually available.
    """
    payload = dict(status="healthy")
    payload["model_load_status"] = model_load_status
    return payload

def _default_coaching_response():
    """Return a fresh copy of the canned coaching payload used as a fallback."""
    return {
        "checklist": ["Inspect safety equipment", "Review milestone progress"],
        "tips": ["Prioritize team communication", "Check weather updates"],
        "quote": "Every step forward counts!",
    }


def _parse_coaching_response(response_text):
    """Turn the model's generated text into the response dict.

    Accepts text that begins (after leading whitespace) with a JSON object and
    ignores anything generated after that object's closing brace. Any other
    text yields the default payload from ``_default_coaching_response()``.
    """
    text = response_text.strip()
    # Since distilgpt2 may not output JSON, fall back when no object is present.
    if not text.startswith("{"):
        logger.warning("Model output is not JSON, using default response")
        return _default_coaching_response()
    try:
        # raw_decode stops at the end of the first JSON value, so trailing
        # generated tokens don't invalidate an otherwise well-formed object.
        parsed, _ = json.JSONDecoder().raw_decode(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse model output as JSON, using default response")
        return _default_coaching_response()
    return parsed


@app.post("/generate_coaching")
async def generate_coaching(data: CoachingInput):
    """Generate a coaching checklist, tips and quote for the given project context.

    Returns the JSON object produced by the model, or a default coaching
    payload when the model output is not a parseable JSON object.
    Responds 503 if the model is not loaded and 500 on generation errors.
    """
    if model is None or tokenizer is None:
        logger.error("Model or tokenizer not loaded")
        raise HTTPException(status_code=503, detail="Model not loaded. Please check server logs.")

    try:
        # Prepare input text
        input_text = (
            f"Role: {data.role}, Project: {data.project_id}, "
            f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
        )

        # Tokenize on the event-loop thread; only generate() is offloaded below.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        prompt_length = inputs["input_ids"].shape[-1]

        # GPT-2-family tokenizers define no pad token; reuse EOS in that case.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        def _generate():
            # max_new_tokens (not max_length) so long prompts still leave room
            # to generate; assumes the model's context window fits 512 + 200
            # tokens (true for distilgpt2 — confirm for the fine-tuned model).
            # NOTE(review): no_repeat_ngram_size=2 forbids repeating any token
            # pair, which can stop the model from emitting well-formed JSON.
            return model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=200,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                temperature=0.7,
                pad_token_id=pad_token_id,
            )

        # generate() is blocking compute; running it in the default executor
        # keeps the event loop (and other requests such as /health) responsive.
        # NOTE(review): concurrent requests now call generate() in parallel
        # worker threads; serialize with a lock if the model does not tolerate that.
        outputs = await asyncio.get_running_loop().run_in_executor(None, _generate)

        # Decoder-only models echo the prompt; decode only the new tokens,
        # otherwise the text always starts with "Role:" and never parses as JSON.
        response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return _parse_coaching_response(response_text)

    except Exception as e:
        logger.exception(f"Error generating coaching response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")