from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, conlist, ValidationError
from pydantic_settings import BaseSettings, SettingsConfigDict
import pandas as pd
from pycaret.classification import load_model, predict_model
import logging


# Define settings using Pydantic BaseSettings
class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        protected_namespaces=(),  # allow the field name `model_path`
    )

    model_path: str = "./api/model/saved_tuned_model"
    embedding_dimension: int = 1024
    log_level: str = "INFO"


settings = Settings()

# Configure structured logging; the level comes from the settings above
logging.basicConfig(
    level=settings.log_level,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
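
# Example .env file (a sketch; pydantic-settings matches these variable names
# to the Settings fields above, case-insensitively):
#
#   MODEL_PATH=./api/model/saved_tuned_model
#   EMBEDDING_DIMENSION=1024
#   LOG_LEVEL=INFO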


# Load the saved model
def load_tuned_model(model_path: str):
    """Load the pre-trained model from the specified path."""
    try:
        logger.info(f"Loading model from {model_path}...")
        model = load_model(model_path)
        logger.info("Model loaded successfully.")
        return model
    except Exception as e:
        logger.error(f"Failed to load the model: {str(e)}")
        raise RuntimeError(f"Model loading failed: {str(e)}") from e


tuned_model = load_tuned_model(settings.model_path)


# Define the input data model using Pydantic
class EmbeddingRequest(BaseModel):
    # conlist enforces the exact embedding length, so a malformed request
    # fails fast with a 422 instead of a pandas error deeper in the endpoint
    embedding: conlist(
        float,
        min_length=settings.embedding_dimension,
        max_length=settings.embedding_dimension,
    )


# Define the response model
class PredictionResponse(BaseModel):
    predicted_label: int
    predicted_score: float
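
# Illustrative request/response shapes (values are made up; the embedding list
# must contain exactly settings.embedding_dimension floats):
#
#   request:  {"embedding": [0.12, -0.03, ...]}            # 1024 values
#   response: {"predicted_label": 2, "predicted_score": 0.97}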


# Initialize FastAPI app
app = FastAPI(
    title="Embedding Prediction API",
    description="API for predicting labels and scores from embeddings using a pre-trained model.",
    version="1.0.0",
)


# Dependency for model access
def get_model():
    """Dependency to provide the loaded model to endpoints."""
    return tuned_model
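
# Exposing the model through a dependency (rather than using the global
# directly) keeps the endpoint testable: tests can swap in a stub model via
# app.dependency_overrides[get_model].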


# Redirect the bare root URL to the interactive API docs
@app.get("/", include_in_schema=False)
async def root():
    return RedirectResponse(url="/docs")


# Define the prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict(
    request: EmbeddingRequest,
    model=Depends(get_model),
):
""" | |
Predicts the label and score for a given embedding. | |
Args: | |
request (EmbeddingRequest): A request containing the embedding as a list of floats. | |
model: The pre-trained model injected via dependency. | |
Returns: | |
PredictionResponse: A response containing the predicted label and score. | |
""" | |
try: | |
logger.info("Received prediction request.") | |
# Convert the input embedding to a DataFrame | |
input_data = pd.DataFrame( | |
[request.embedding], | |
columns=[f"embedding_{i}" for i in range(settings.embedding_dimension)], | |
) | |
# Make a prediction using the loaded model | |
logger.info("Making prediction...") | |
prediction = predict_model(model, data=input_data) | |
# Validate the prediction output | |
if ( | |
"prediction_label" not in prediction.columns | |
or "prediction_score" not in prediction.columns | |
): | |
raise ValueError("Model prediction output is missing required columns.") | |
# Extract the predicted label and score | |
predicted_label = prediction["prediction_label"].iloc[0] | |
predicted_score = prediction["prediction_score"].iloc[0] | |
        # Remap class 3 back to label 4 (assumption: label 4 was encoded as
        # index 3 during training, so the raw prediction needs mapping back)
        if predicted_label == 3:
            predicted_label = 4
        logger.info(
            f"Prediction successful: label={predicted_label}, score={predicted_score}"
        )
        return PredictionResponse(
            predicted_label=int(predicted_label),
            predicted_score=float(predicted_score),
        )
    except ValidationError as e:
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input data: {str(e)}",
        ) from e
    except ValueError as e:
        logger.error(f"Value error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model output validation failed: {str(e)}",
        ) from e
    except Exception as e:
        logger.error(f"Prediction failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An error occurred during prediction: {str(e)}",
        ) from e


# Health check endpoint
@app.get("/health", status_code=status.HTTP_200_OK)
async def health_check():
    """Health check endpoint to verify the API is running."""
    return {"status": "healthy"}


# Run the FastAPI app when this file is executed directly; in deployment,
# uvicorn is typically launched externally instead
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
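
# Example request against a locally running instance (a sketch, assuming the
# default port above and a 1024-dimensional model):
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"embedding": [0.0, 0.1, ...]}'   # exactly 1024 floats
#
# A successful response looks like:
#   {"predicted_label": 1, "predicted_score": 0.93}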