from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import logging
import os
import asyncio

# Set up logging to both stdout and a file
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),  # Log to stdout
        logging.FileHandler("/app/app.log")  # Log to a file
    ]
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()


# Define input model for request validation
class CoachingInput(BaseModel):
    role: str
    project_id: str
    milestones: str
    reflection_log: str
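
# Illustrative request body this schema accepts; the field values below are
# invented for demonstration only:
# {
#     "role": "site_supervisor",
#     "project_id": "PRJ-001",
#     "milestones": "Foundation poured; framing started",
#     "reflection_log": "Crew morale high; concrete delivery was delayed"
# }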

# Global variables for model and tokenizer
model = None
tokenizer = None
model_load_status = "not_loaded"

# Define model path and fallback
model_path = "/app/fine-tuned-construction-llm"
fallback_model = "distilgpt2"


# Asynchronous function to load the model in the background
async def load_model_background():
    global model, tokenizer, model_load_status
    try:
        if os.path.isdir(model_path):
            logger.info(f"Loading local model from {model_path}")
            model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
            tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
            model_load_status = "local_model_loaded"
        else:
            logger.info(f"Model directory not found: {model_path}. Using pre-trained model: {fallback_model}")
            model = AutoModelForCausalLM.from_pretrained(fallback_model)
            tokenizer = AutoTokenizer.from_pretrained(fallback_model)
            model_load_status = "fallback_model_loaded"
        logger.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {str(e)}")
        model_load_status = f"failed: {str(e)}"


# Startup event to initiate model loading in the background
@app.on_event("startup")
async def startup_event():
    logger.info("FastAPI application started")
    # BackgroundTasks is request-scoped and is not injected into startup
    # handlers, so schedule the loader on the running event loop instead
    asyncio.create_task(load_model_background())
@app.get("/health")
async def health_check():
return {
"status": "healthy" if model_load_status in ["local_model_loaded", "fallback_model_loaded"] else "starting",
"model_load_status": model_load_status
}
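
# Example health probe (hypothetical host and port, assuming a local run):
#   curl http://localhost:7860/health
# Before the background loader finishes, this returns
#   {"status": "starting", "model_load_status": "not_loaded"}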


@app.post("/generate_coaching")
async def generate_coaching(data: CoachingInput):
    if model is None or tokenizer is None:
        logger.error("Model or tokenizer not loaded")
        raise HTTPException(status_code=503, detail="Model not loaded yet. Please try again later.")
    try:
        # Prepare input text
        input_text = (
            f"Role: {data.role}, Project: {data.project_id}, "
            f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
        )
        # Tokenize input
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        # Generate output; max_new_tokens bounds the generated text regardless
        # of prompt length, and pad_token_id is set explicitly because
        # GPT-2-style tokenizers define no pad token
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
        # Decode and parse response
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Since distilgpt2 may not output JSON, parse the response manually or use fallback
        if not response_text.startswith("{"):
            checklist = ["Inspect safety equipment", "Review milestone progress"]
            tips = ["Prioritize team communication", "Check weather updates"]
            quote = "Every step forward counts!"
            response_json = {"checklist": checklist, "tips": tips, "quote": quote}
            logger.warning("Model output is not JSON, using default response")
        else:
            try:
                response_json = json.loads(response_text)
            except json.JSONDecodeError:
                response_json = {
                    "checklist": ["Inspect safety equipment", "Review milestone progress"],
                    "tips": ["Prioritize team communication", "Check weather updates"],
                    "quote": "Every step forward counts!"
                }
                logger.warning("Failed to parse model output as JSON, using default response")
        return response_json
    except Exception as e:
        logger.error(f"Error generating coaching response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
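

# Local entry-point sketch; this is an assumption, not part of the original
# deployment. Port 7860 follows the Hugging Face Spaces convention; adjust
# host and port as needed.
if __name__ == "__main__":
    import uvicorn  # assumed available alongside FastAPI

    uvicorn.run(app, host="0.0.0.0", port=7860)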