Spaces:
Sleeping
Sleeping
File size: 4,297 Bytes
b181476 3db7383 b181476 e7c1a90 b181476 e7c1a90 3db7383 e7c1a90 3db7383 e7c1a90 b181476 e7c1a90 b181476 e7c1a90 92b443e b181476 e7c1a90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import logging
import os
# Set up logging
logging.basicConfig(level=logging.INFO)
# Module-level logger; reused by the route handlers below.
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI()
# Define input model for validation
class CoachingInput(BaseModel):
    """Request body for POST /generate_coaching; pydantic validates all four
    fields are present and are strings (422 returned otherwise)."""
    role: str            # e.g. job title of the person being coached
    project_id: str      # opaque project identifier, echoed into the prompt
    milestones: str      # free-text milestone summary
    reflection_log: str  # free-text reflection notes
# Global variables for model and tokenizer.
# Both stay None until load_model() succeeds; handlers must check for None.
model = None
tokenizer = None
# Mirrors the outcome of load_model(); surfaced via GET /health.
model_load_status = "not_loaded"
# Define model path and fallback
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"  # Smaller model for faster loading
# Load model and tokenizer at startup
# Load model and tokenizer at startup
def load_model():
    """Populate the module-level `model`/`tokenizer` globals.

    Prefers the fine-tuned local checkpoint at `model_path`; falls back to
    the public `fallback_model` when the directory is missing OR when loading
    it fails (e.g. a corrupt/partial checkpoint — previously that case left
    the service with no model at all). Never raises: on total failure the app
    still starts and reports the error through `model_load_status` /health.
    """
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            try:
                logger.info(f"Loading local model from {model_path}")
                model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
                tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
                model_load_status = "local_model_loaded"
            except Exception as local_err:
                # Robustness fix: a present-but-broken local checkpoint now
                # falls back instead of leaving model/tokenizer as None.
                logger.warning(
                    f"Local model at {model_path} failed to load ({local_err}); "
                    f"falling back to pre-trained model: {fallback_model}"
                )
                model = AutoModelForCausalLM.from_pretrained(fallback_model)
                tokenizer = AutoTokenizer.from_pretrained(fallback_model)
                model_load_status = "fallback_model_loaded"
        else:
            logger.warning(f"Model directory not found: {model_path}. Falling back to pre-trained model: {fallback_model}")
            model = AutoModelForCausalLM.from_pretrained(fallback_model)
            tokenizer = AutoTokenizer.from_pretrained(fallback_model)
            model_load_status = "fallback_model_loaded"
        # GPT-2-family tokenizers ship without a pad token; define one so
        # downstream generate() calls can pass a valid pad_token_id.
        if tokenizer is not None and tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        logger.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {str(e)}")
        model_load_status = f"failed: {str(e)}"
        # Do not raise an exception; allow the app to start so /health can
        # report the failure instead of the container crash-looping.
# Load model on startup
load_model()
@app.on_event("startup")
async def startup_event():
    """Log once when the ASGI application finishes startup."""
    # NOTE(review): on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers; kept as-is to preserve behavior.
    logger.info("FastAPI application started")
@app.get("/health")
async def health_check():
    """Liveness probe: always 200, exposing the model-load outcome so
    operators can distinguish 'up' from 'up but model failed to load'."""
    payload = dict(status="healthy", model_load_status=model_load_status)
    return payload
@app.post("/generate_coaching")
async def generate_coaching(data: CoachingInput):
    """Generate a coaching payload (checklist / tips / quote) for a role.

    Builds a prompt from the validated request, samples the model, and tries
    to parse a JSON object out of the generated text. Falls back to a fixed
    default payload when the model emits no parseable JSON.

    Raises:
        HTTPException 503: model/tokenizer never loaded (see /health).
        HTTPException 500: unexpected error during generation.
    """
    if model is None or tokenizer is None:
        logger.error("Model or tokenizer not loaded")
        raise HTTPException(status_code=503, detail="Model not loaded. Please check server logs.")
    try:
        # Prepare input text
        input_text = (
            f"Role: {data.role}, Project: {data.project_id}, "
            f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
        )
        # Tokenize input (truncated to the model's usable context)
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        # Generate output. Fixes vs. previous version:
        #  - pass attention_mask and an explicit pad_token_id (GPT-2-family
        #    tokenizers have no pad token), avoiding warnings/wrong masking;
        #  - use max_new_tokens instead of max_length=200: a prompt longer
        #    than 200 tokens previously made generate() fail outright.
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=(tokenizer.pad_token_id
                          if tokenizer.pad_token_id is not None
                          else tokenizer.eos_token_id),
        )
        # Decode and parse response
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Single source of truth for the fallback payload (was duplicated).
        default_response = {
            "checklist": ["Inspect safety equipment", "Review milestone progress"],
            "tips": ["Prioritize team communication", "Check weather updates"],
            "quote": "Every step forward counts!",
        }
        # Decoder-only models echo the prompt, so the decoded text virtually
        # never *starts* with "{"; look for the first JSON object anywhere.
        json_start = response_text.find("{")
        if json_start == -1:
            logger.warning("Model output is not JSON, using default response")
            response_json = default_response
        else:
            try:
                response_json = json.loads(response_text[json_start:])
            except json.JSONDecodeError:
                logger.warning("Failed to parse model output as JSON, using default response")
                response_json = default_response
        return response_json
    except Exception as e:
        logger.error(f"Error generating coaching response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")