from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import logging
import os
import asyncio

# Set up logging to both stdout and a file
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),              # Log to stdout
        logging.FileHandler("/app/app.log")   # Log to a file
    ]
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Define input model for validation
class CoachingInput(BaseModel):
    role: str
    project_id: str
    milestones: str
    reflection_log: str

# Global variables for model and tokenizer
model = None
tokenizer = None
model_load_status = "not_loaded"

# Define model path and fallback
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"

# Load the model in the background without blocking the event loop.
# from_pretrained is a blocking call, so it runs in a worker thread via
# asyncio.to_thread; otherwise /health could not respond while loading.
async def load_model_background():
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info(f"Loading local model from {model_path}")
            model = await asyncio.to_thread(
                AutoModelForCausalLM.from_pretrained, model_path, local_files_only=True
            )
            tokenizer = await asyncio.to_thread(
                AutoTokenizer.from_pretrained, model_path, local_files_only=True
            )
            model_load_status = "local_model_loaded"
        else:
            logger.info(f"Model directory not found: {model_path}. Using pre-trained model: {fallback_model}")
            model = await asyncio.to_thread(AutoModelForCausalLM.from_pretrained, fallback_model)
            tokenizer = await asyncio.to_thread(AutoTokenizer.from_pretrained, fallback_model)
            model_load_status = "fallback_model_loaded"
        logger.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {str(e)}")
        model_load_status = f"failed: {str(e)}"

# Startup event to initiate model loading in the background.
# BackgroundTasks is a request-scoped dependency and is not available in
# startup handlers, so schedule the coroutine with asyncio.create_task.
@app.on_event("startup")
async def startup_event():
    logger.info("FastAPI application started")
    asyncio.create_task(load_model_background())

@app.get("/health")
async def health_check():
    return {
        "status": "healthy" if model_load_status in ["local_model_loaded", "fallback_model_loaded"] else "starting",
        "model_load_status": model_load_status
    }

@app.post("/generate_coaching")
async def generate_coaching(data: CoachingInput):
    if model is None or tokenizer is None:
        logger.error("Model or tokenizer not loaded")
        raise HTTPException(status_code=503, detail="Model not loaded yet. Please try again later.")

    try:
        # Prepare input text
        input_text = (
            f"Role: {data.role}, Project: {data.project_id}, "
            f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
        )

        # Tokenize input
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

        # Generate output. max_new_tokens bounds only the generated text;
        # max_length counts the prompt too and can be shorter than the
        # (truncated) input. GPT-2 has no pad token, so reuse EOS.
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode and parse response
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Since distilgpt2 may not output JSON, parse the response manually or use fallback
        if not response_text.startswith("{"):
            checklist = ["Inspect safety equipment", "Review milestone progress"]
            tips = ["Prioritize team communication", "Check weather updates"]
            quote = "Every step forward counts!"
response_json = {"checklist": checklist, "tips": tips, "quote": quote} logger.warning("Model output is not JSON, using default response") else: try: response_json = json.loads(response_text) except json.JSONDecodeError: response_json = { "checklist": ["Inspect safety equipment", "Review milestone progress"], "tips": ["Prioritize team communication", "Check weather updates"], "quote": "Every step forward counts!" } logger.warning("Failed to parse model output as JSON, using default response") return response_json except Exception as e: logger.error(f"Error generating coaching response: {str(e)}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")