# NOTE(review): the three lines that preceded this comment ("Spaces:" /
# "Sleeping" / "Sleeping") were Hugging Face Spaces UI text captured in a
# copy-paste, not source code; converted to a comment so the module parses.
# fastapi_app — sentiment prediction service (Hugging Face Space entrypoint).
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
import os, sys

# Add the repository root to sys.path so the sibling packages
# (model_pipeline, llama_pipeline, ...) resolve when run from fastapi_app/.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from model_pipeline.model_predict import load_model, predict as initial_predict
from llama_pipeline.llama_predict import predict as llama_predict
from db_connection import insert_db
from logging_config.logger_config import get_logger

# Initialize the FastAPI app
app = FastAPI()

# Module-level logger for this app
logger = get_logger(__name__)

# Load the latest model once at startup; reused by every /predict request
model = load_model()

# Serve static assets (index.html, CSS, JS) under /static
app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
# NOTE(review): route decorators appear to have been lost in the same paste
# mangle that flattened indentation; without one this handler is never routed.
# "/" inferred from the handler name and the index.html it serves — confirm.
@app.get("/")
def read_root():
    """Serve the single-page UI (static index.html) at the site root."""
    with open("fastapi_app/static/index.html") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content, status_code=200)
# NOTE(review): decorator lost in the paste mangle; "/health" path inferred
# from the handler name — confirm against the original source.
@app.get("/health")
def health_check():
    """Liveness probe: log the access and return a static OK payload."""
    logger.info("Health check endpoint accessed.")
    return {"status": "ok"}
class TextInput(BaseModel):
    """Request body for the prediction endpoint.

    Attributes:
        text: raw input text to classify.
    """
    text: str
class PredictionInput(BaseModel):
    """Request body for the feedback endpoint.

    Echoes the text and both model outputs back from the client, together
    with the user's rating, so the full interaction can be persisted.

    Attributes:
        text: original input text.
        initial_prediction: label from the initial (classical) model.
        llama_category: category assigned by the LLaMA 3 pipeline.
        llama_explanation: LLaMA 3's free-text explanation.
        user_rating: user's rating of the prediction (integer).
    """
    text: str
    initial_prediction: str
    llama_category: str
    llama_explanation: str
    user_rating: int
# NOTE(review): decorator lost in the paste mangle; POST "/predict" inferred
# from the handler's role and request model — confirm.
@app.post("/predict")
def predict_sentiment(input_data: TextInput):
    """Run both pipelines over the input text and return their results.

    Args:
        input_data: request body carrying the text to classify.

    Returns:
        dict with the original text, the initial model's prediction, and the
        LLaMA 3 category and explanation.
    """
    logger.info(f"Prediction request received with text: {input_data.text}")
    # Initial model prediction (model loaded once at module import)
    initial_prediction = initial_predict(input_data.text, model=model)
    # LLaMA 3 prediction — presumably returns a dict with 'Category' and
    # 'Explanation' keys, per the subscripts below; verify llama_predict.
    llama_prediction = llama_predict(input_data.text)
    response = {
        "text": input_data.text,
        "initial_prediction": initial_prediction,
        "llama_category": llama_prediction['Category'],
        "llama_explanation": llama_prediction['Explanation'],
    }
    logger.info(f"Prediction response: {response}")
    return response
# NOTE(review): decorator lost in the paste mangle; POST path inferred from
# the handler name — confirm the exact route the frontend calls.
@app.post("/submit_interaction")
def submit_interaction(data: PredictionInput):
    """Persist a prediction + user-rating interaction to the database.

    Args:
        data: the full interaction echoed back by the client.

    Returns:
        dict with a "success" status and the raw database response.
    """
    # One structured log of the whole payload; the original's six per-field
    # logger.info calls duplicated this and were consolidated.
    logger.info(f"Received interaction data: {data}")
    interaction_data = {
        "Input_text": data.text,
        "Model_prediction": data.initial_prediction,
        "Llama_3_Prediction": data.llama_category,
        "Llama_3_Explanation": data.llama_explanation,
        "User Rating": data.user_rating,
    }
    response = insert_db(interaction_data)
    logger.info(f"Database response: {response}")
    return {"status": "success", "response": response}
if __name__ == "__main__":
    # Run directly with uvicorn; port 7860 is the Hugging Face Spaces default.
    uvicorn.run(app, host="0.0.0.0", port=7860)