# NOTE(review): removed non-code residue captured with this paste
# (HuggingFace Spaces page header: "Spaces: / Sleeping / Sleeping").
# Standard library
import json
import logging

# Third-party
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Module-level logging: INFO and above to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Define input model for validation
class CoachingInput(BaseModel):
    """Validated request payload for the coaching endpoint.

    All fields are free-form strings supplied by the client; pydantic
    rejects requests where any field is missing or not a string.
    """

    role: str
    project_id: str
    milestones: str
    reflection_log: str
# Load model and tokenizer once at import time so every request reuses them.
try:
    model_path = "./fine_tuned_construction_llm"
    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    # Fail fast: the service cannot do anything useful without the model.
    # RuntimeError (chained with `from e`) preserves the original traceback
    # and is still caught by callers handling Exception.
    logger.error(f"Failed to load model or tokenizer: {str(e)}")
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
# NOTE(review): the original defined this coroutine with no route decorator,
# so FastAPI never exposed it; "/coaching" is an assumed path — confirm the
# intended path/method with the API's consumers.
@app.post("/coaching")
async def generate_coaching(data: CoachingInput):
    """Generate a structured coaching response from the fine-tuned model.

    Args:
        data: Validated payload (role, project_id, milestones, reflection_log).

    Returns:
        dict: JSON-serializable coaching payload. If the model output is not
        valid JSON, a fixed default checklist/tips/quote structure is returned.

    Raises:
        HTTPException: 500 on any unexpected failure during generation.
    """
    try:
        # Flatten the structured payload into a single prompt string
        # (presumably the format used during fine-tuning — verify).
        input_text = (
            f"Role: {data.role}, Project: {data.project_id}, "
            f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
        )
        # Tokenize, truncating to the 512-token prompt budget.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        # **inputs forwards attention_mask along with input_ids, and
        # max_new_tokens (rather than max_length, which counts prompt tokens
        # too) keeps generation valid even when the prompt approaches 512
        # tokens — with max_length=200 a long prompt would already exceed it.
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
        )
        # Decode the single generated sequence.
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The fine-tuned model is expected to emit a JSON object; fall back to
        # a safe default payload when it does not.
        try:
            response_json = json.loads(response_text)
        except json.JSONDecodeError:
            response_json = {
                "checklist": ["Inspect safety equipment", "Review milestone progress"],
                "tips": ["Prioritize team communication", "Check weather updates"],
                "quote": "Every step forward counts!"
            }
            logger.warning("Failed to parse model output as JSON, using default response")
        return response_json
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error generating coaching response")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# NOTE(review): decorator was missing in the original; "/health" is an
# assumed path — confirm against deployment probes.
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service process is up."""
    return {"status": "healthy"}