"""FastAPI microservice that serves coaching responses from a locally
fine-tuned causal language model.

The model is loaded once at import time from a fixed path inside the
container; the service fails fast at startup if the model is missing.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import logging
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()


# Define input model for validation
class CoachingInput(BaseModel):
    """Request payload for /generate_coaching."""

    role: str
    project_id: str
    milestones: str
    reflection_log: str


# Define model path (absolute path in the container)
model_path = "/app/fine-tuned-construction-llm"

# Fail fast at startup if the model directory is missing — a clear error
# here beats a confusing 500 on the first request.
if not os.path.isdir(model_path):
    logger.error("Model directory not found: %s", model_path)
    raise RuntimeError(f"Model directory not found: {model_path}")

# Load model and tokenizer once at startup (local files only: never hit the hub).
try:
    model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.exception("Failed to load model or tokenizer")
    # Chain the cause so the original traceback survives.
    raise RuntimeError(f"Model loading failed: {e}") from e


@app.post("/generate_coaching")
async def generate_coaching(data: CoachingInput):
    """Generate a structured coaching response for the given project context.

    Returns a dict (JSON) with coaching content. If the model output cannot
    be parsed as JSON, a fixed default response is returned instead of
    failing the request.

    Raises:
        HTTPException: 500 on any unexpected generation/tokenization error.
    """
    try:
        # Prepare input text
        input_text = (
            f"Role: {data.role}, Project: {data.project_id}, "
            f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
        )

        # Tokenize input
        inputs = tokenizer(
            input_text, return_tensors="pt", max_length=512, truncation=True
        )

        # Generate output. Pass the attention mask explicitly — without it,
        # transformers falls back to guessing padding, which can degrade
        # generation quality and emits a warning.
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
        )

        # Decode and parse response
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Simulate structured output (replace with actual parsing logic based
        # on model output). This assumes the model outputs a JSON-like string;
        # adjust based on fine-tuning.
        try:
            response_json = json.loads(response_text)
        except json.JSONDecodeError:
            # Fallback: Construct a default response if parsing fails
            response_json = {
                "checklist": [
                    "Inspect safety equipment",
                    "Review milestone progress",
                ],
                "tips": [
                    "Prioritize team communication",
                    "Check weather updates",
                ],
                "quote": "Every step forward counts!",
            }
            logger.warning("Failed to parse model output as JSON, using default response")

        return response_json
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error generating coaching response")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.get("/health")
async def health_check():
    """Liveness probe: always reports healthy once the module has loaded."""
    return {"status": "healthy"}