Spaces:
Sleeping
Sleeping
Update fastapi_app/main.py
Browse files- fastapi_app/main.py +91 -91
fastapi_app/main.py
CHANGED
@@ -1,91 +1,91 @@
|
|
1 |
-
from fastapi import FastAPI, Request
|
2 |
-
from fastapi.responses import HTMLResponse
|
3 |
-
from fastapi.staticfiles import StaticFiles
|
4 |
-
from pydantic import BaseModel
|
5 |
-
import uvicorn
|
6 |
-
import os, sys
|
7 |
-
|
8 |
-
# Add the root directory to sys.path
|
9 |
-
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
10 |
-
from model_pipeline.model_predict import load_model, predict as initial_predict
|
11 |
-
from llama_pipeline.llama_predict import predict as llama_predict
|
12 |
-
from db_connection import insert_db
|
13 |
-
from logging_config.logger_config import get_logger
|
14 |
-
|
15 |
-
# Initialize the FastAPI app
|
16 |
-
app = FastAPI()
|
17 |
-
|
18 |
-
# Initialize the logger
|
19 |
-
logger = get_logger(__name__)
|
20 |
-
|
21 |
-
# Load the latest model at startup
|
22 |
-
model = load_model()
|
23 |
-
|
24 |
-
# Mount the static files directory
|
25 |
-
app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
|
26 |
-
|
27 |
-
@app.get("/", response_class=HTMLResponse)
|
28 |
-
def read_root():
|
29 |
-
with open("fastapi_app/static/index.html") as f:
|
30 |
-
html_content = f.read()
|
31 |
-
return HTMLResponse(content=html_content, status_code=200)
|
32 |
-
|
33 |
-
@app.get("/health")
|
34 |
-
def health_check():
|
35 |
-
logger.info("Health check endpoint accessed.")
|
36 |
-
return {"status": "ok"}
|
37 |
-
|
38 |
-
class TextInput(BaseModel):
|
39 |
-
text: str
|
40 |
-
|
41 |
-
class PredictionInput(BaseModel):
|
42 |
-
text: str
|
43 |
-
initial_prediction: str
|
44 |
-
llama_category: str
|
45 |
-
llama_explanation: str
|
46 |
-
user_rating: int
|
47 |
-
|
48 |
-
@app.post("/predict_sentiment")
|
49 |
-
def predict_sentiment(input_data: TextInput):
|
50 |
-
logger.info(f"Prediction request received with text: {input_data.text}")
|
51 |
-
|
52 |
-
# Initial model prediction
|
53 |
-
initial_prediction = initial_predict(input_data.text, model = model)
|
54 |
-
|
55 |
-
# LLaMA 3 prediction
|
56 |
-
llama_prediction = llama_predict(input_data.text)
|
57 |
-
|
58 |
-
# Prepare response
|
59 |
-
response = {
|
60 |
-
"text": input_data.text,
|
61 |
-
"initial_prediction": initial_prediction,
|
62 |
-
"llama_category": llama_prediction['Category'],
|
63 |
-
"llama_explanation": llama_prediction['Explanation']
|
64 |
-
}
|
65 |
-
|
66 |
-
logger.info(f"Prediction response: {response}")
|
67 |
-
return response
|
68 |
-
|
69 |
-
@app.post("/submit_interaction")
|
70 |
-
def submit_interaction(data: PredictionInput):
|
71 |
-
logger.info(f"Received interaction data: {data}")
|
72 |
-
logger.info(f"Received text: {data.text}")
|
73 |
-
logger.info(f"Received initial_prediction: {data.initial_prediction}")
|
74 |
-
logger.info(f"Received llama_category: {data.llama_category}")
|
75 |
-
logger.info(f"Received llama_explanation: {data.llama_explanation}")
|
76 |
-
logger.info(f"Received user_rating: {data.user_rating}")
|
77 |
-
|
78 |
-
interaction_data = {
|
79 |
-
"Input_text": data.text,
|
80 |
-
"Model_prediction": data.initial_prediction,
|
81 |
-
"Llama_3_Prediction": data.llama_category,
|
82 |
-
"Llama_3_Explanation": data.llama_explanation,
|
83 |
-
"User Rating": data.user_rating,
|
84 |
-
}
|
85 |
-
|
86 |
-
response = insert_db(interaction_data)
|
87 |
-
logger.info(f"Database response: {response}")
|
88 |
-
return {"status": "success", "response": response}
|
89 |
-
|
90 |
-
if __name__ == "__main__":
|
91 |
-
uvicorn.run(app, host="0.0.0.0", port=
|
|
|
1 |
+
from fastapi import FastAPI, Request
|
2 |
+
from fastapi.responses import HTMLResponse
|
3 |
+
from fastapi.staticfiles import StaticFiles
|
4 |
+
from pydantic import BaseModel
|
5 |
+
import uvicorn
|
6 |
+
import os, sys
|
7 |
+
|
8 |
+
# Add the root directory to sys.path
|
9 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
10 |
+
from model_pipeline.model_predict import load_model, predict as initial_predict
|
11 |
+
from llama_pipeline.llama_predict import predict as llama_predict
|
12 |
+
from db_connection import insert_db
|
13 |
+
from logging_config.logger_config import get_logger
|
14 |
+
|
15 |
+
# Initialize the FastAPI app
|
16 |
+
app = FastAPI()
|
17 |
+
|
18 |
+
# Initialize the logger
|
19 |
+
logger = get_logger(__name__)
|
20 |
+
|
21 |
+
# Load the latest model at startup
|
22 |
+
model = load_model()
|
23 |
+
|
24 |
+
# Mount the static files directory
|
25 |
+
app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
|
26 |
+
|
27 |
+
@app.get("/", response_class=HTMLResponse)
|
28 |
+
def read_root():
|
29 |
+
with open("fastapi_app/static/index.html") as f:
|
30 |
+
html_content = f.read()
|
31 |
+
return HTMLResponse(content=html_content, status_code=200)
|
32 |
+
|
33 |
+
@app.get("/health")
|
34 |
+
def health_check():
|
35 |
+
logger.info("Health check endpoint accessed.")
|
36 |
+
return {"status": "ok"}
|
37 |
+
|
38 |
+
class TextInput(BaseModel):
|
39 |
+
text: str
|
40 |
+
|
41 |
+
class PredictionInput(BaseModel):
|
42 |
+
text: str
|
43 |
+
initial_prediction: str
|
44 |
+
llama_category: str
|
45 |
+
llama_explanation: str
|
46 |
+
user_rating: int
|
47 |
+
|
48 |
+
@app.post("/predict_sentiment")
|
49 |
+
def predict_sentiment(input_data: TextInput):
|
50 |
+
logger.info(f"Prediction request received with text: {input_data.text}")
|
51 |
+
|
52 |
+
# Initial model prediction
|
53 |
+
initial_prediction = initial_predict(input_data.text, model = model)
|
54 |
+
|
55 |
+
# LLaMA 3 prediction
|
56 |
+
llama_prediction = llama_predict(input_data.text)
|
57 |
+
|
58 |
+
# Prepare response
|
59 |
+
response = {
|
60 |
+
"text": input_data.text,
|
61 |
+
"initial_prediction": initial_prediction,
|
62 |
+
"llama_category": llama_prediction['Category'],
|
63 |
+
"llama_explanation": llama_prediction['Explanation']
|
64 |
+
}
|
65 |
+
|
66 |
+
logger.info(f"Prediction response: {response}")
|
67 |
+
return response
|
68 |
+
|
69 |
+
@app.post("/submit_interaction")
|
70 |
+
def submit_interaction(data: PredictionInput):
|
71 |
+
logger.info(f"Received interaction data: {data}")
|
72 |
+
logger.info(f"Received text: {data.text}")
|
73 |
+
logger.info(f"Received initial_prediction: {data.initial_prediction}")
|
74 |
+
logger.info(f"Received llama_category: {data.llama_category}")
|
75 |
+
logger.info(f"Received llama_explanation: {data.llama_explanation}")
|
76 |
+
logger.info(f"Received user_rating: {data.user_rating}")
|
77 |
+
|
78 |
+
interaction_data = {
|
79 |
+
"Input_text": data.text,
|
80 |
+
"Model_prediction": data.initial_prediction,
|
81 |
+
"Llama_3_Prediction": data.llama_category,
|
82 |
+
"Llama_3_Explanation": data.llama_explanation,
|
83 |
+
"User Rating": data.user_rating,
|
84 |
+
}
|
85 |
+
|
86 |
+
response = insert_db(interaction_data)
|
87 |
+
logger.info(f"Database response: {response}")
|
88 |
+
return {"status": "success", "response": response}
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|