Timmyafolami committed on
Commit
4c604a2
·
verified ·
1 Parent(s): 58b6524

Update fastapi_app/main.py

Browse files
Files changed (1) hide show
  1. fastapi_app/main.py +91 -91
fastapi_app/main.py CHANGED
@@ -1,91 +1,91 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.responses import HTMLResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from pydantic import BaseModel
5
- import uvicorn
6
- import os, sys
7
-
8
- # Add the root directory to sys.path
9
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
10
- from model_pipeline.model_predict import load_model, predict as initial_predict
11
- from llama_pipeline.llama_predict import predict as llama_predict
12
- from db_connection import insert_db
13
- from logging_config.logger_config import get_logger
14
-
15
- # Initialize the FastAPI app
16
- app = FastAPI()
17
-
18
- # Initialize the logger
19
- logger = get_logger(__name__)
20
-
21
- # Load the latest model at startup
22
- model = load_model()
23
-
24
- # Mount the static files directory
25
- app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
26
-
27
- @app.get("/", response_class=HTMLResponse)
28
- def read_root():
29
- with open("fastapi_app/static/index.html") as f:
30
- html_content = f.read()
31
- return HTMLResponse(content=html_content, status_code=200)
32
-
33
- @app.get("/health")
34
- def health_check():
35
- logger.info("Health check endpoint accessed.")
36
- return {"status": "ok"}
37
-
38
- class TextInput(BaseModel):
39
- text: str
40
-
41
- class PredictionInput(BaseModel):
42
- text: str
43
- initial_prediction: str
44
- llama_category: str
45
- llama_explanation: str
46
- user_rating: int
47
-
48
- @app.post("/predict_sentiment")
49
- def predict_sentiment(input_data: TextInput):
50
- logger.info(f"Prediction request received with text: {input_data.text}")
51
-
52
- # Initial model prediction
53
- initial_prediction = initial_predict(input_data.text, model = model)
54
-
55
- # LLaMA 3 prediction
56
- llama_prediction = llama_predict(input_data.text)
57
-
58
- # Prepare response
59
- response = {
60
- "text": input_data.text,
61
- "initial_prediction": initial_prediction,
62
- "llama_category": llama_prediction['Category'],
63
- "llama_explanation": llama_prediction['Explanation']
64
- }
65
-
66
- logger.info(f"Prediction response: {response}")
67
- return response
68
-
69
- @app.post("/submit_interaction")
70
- def submit_interaction(data: PredictionInput):
71
- logger.info(f"Received interaction data: {data}")
72
- logger.info(f"Received text: {data.text}")
73
- logger.info(f"Received initial_prediction: {data.initial_prediction}")
74
- logger.info(f"Received llama_category: {data.llama_category}")
75
- logger.info(f"Received llama_explanation: {data.llama_explanation}")
76
- logger.info(f"Received user_rating: {data.user_rating}")
77
-
78
- interaction_data = {
79
- "Input_text": data.text,
80
- "Model_prediction": data.initial_prediction,
81
- "Llama_3_Prediction": data.llama_category,
82
- "Llama_3_Explanation": data.llama_explanation,
83
- "User Rating": data.user_rating,
84
- }
85
-
86
- response = insert_db(interaction_data)
87
- logger.info(f"Database response: {response}")
88
- return {"status": "success", "response": response}
89
-
90
- if __name__ == "__main__":
91
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.responses import HTMLResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from pydantic import BaseModel
5
+ import uvicorn
6
+ import os, sys
7
+
8
# Make the repository root importable so the sibling packages below resolve
# when this file is run from inside fastapi_app/.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from model_pipeline.model_predict import load_model, predict as initial_predict
from llama_pipeline.llama_predict import predict as llama_predict
from db_connection import insert_db
from logging_config.logger_config import get_logger

# Initialize the FastAPI app
app = FastAPI()

# Initialize the logger
logger = get_logger(__name__)

# Load the latest model at startup.
# NOTE(review): this runs at import time — if load_model() fails, the app
# never starts. Confirm that fail-fast behavior is intended.
model = load_model()

# Mount the static files directory.
# The path is relative to the process working directory, so the server must
# be started from the repository root for this (and read_root) to work.
app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
26
+
27
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the single-page UI.

    Returns:
        HTMLResponse: the raw contents of fastapi_app/static/index.html
        with a 200 status code.
    """
    # Read explicitly as UTF-8: the original call relied on the platform's
    # locale encoding, which is not guaranteed to be UTF-8 (e.g. Windows),
    # and would mangle any non-ASCII characters in the page.
    # Path is relative to the process CWD — start the server from the repo root.
    with open("fastapi_app/static/index.html", encoding="utf-8") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content, status_code=200)
32
+
33
@app.get("/health")
def health_check():
    """Liveness probe: log the hit and report the service as up."""
    logger.info("Health check endpoint accessed.")
    status_payload = {"status": "ok"}
    return status_payload
37
+
38
class TextInput(BaseModel):
    """Request body for /predict_sentiment."""
    # Free-form user text to classify; no length limit is enforced here.
    text: str
40
+
41
class PredictionInput(BaseModel):
    """Request body for /submit_interaction.

    Carries one full prediction round-trip (as returned by
    /predict_sentiment) plus the user's feedback rating.
    """
    text: str                # the original input text
    initial_prediction: str  # baseline model output
    llama_category: str      # LLaMA 3 category label
    llama_explanation: str   # LLaMA 3 natural-language explanation
    user_rating: int         # user feedback score; range not validated here
47
+
48
@app.post("/predict_sentiment")
def predict_sentiment(input_data: TextInput):
    """Run both sentiment models on the submitted text.

    Args:
        input_data: request body holding the text to classify.

    Returns:
        dict: the echoed text, the baseline model's prediction, and the
        LLaMA 3 category plus its natural-language explanation.
    """
    # Lazy %-style logging args: the message is only formatted when the
    # INFO level is actually enabled. NOTE(review): this logs raw user
    # input — confirm that is acceptable for your privacy requirements.
    logger.info("Prediction request received with text: %s", input_data.text)

    # Baseline classifier, using the model loaded once at startup.
    initial_prediction = initial_predict(input_data.text, model=model)

    # LLaMA 3 prediction — presumably returns a dict with 'Category' and
    # 'Explanation' keys (TODO confirm against llama_predict); a malformed
    # response raises KeyError below.
    llama_prediction = llama_predict(input_data.text)

    response = {
        "text": input_data.text,
        "initial_prediction": initial_prediction,
        "llama_category": llama_prediction['Category'],
        "llama_explanation": llama_prediction['Explanation'],
    }

    logger.info("Prediction response: %s", response)
    return response
68
+
69
@app.post("/submit_interaction")
def submit_interaction(data: PredictionInput):
    """Persist one prediction round-trip plus the user's rating.

    Args:
        data: the original text, both model outputs, and the user rating.

    Returns:
        dict: {"status": "success", "response": <insert_db result>}.
    """
    # Lazy %-style logging args so formatting only happens when INFO is
    # enabled. The per-field lines are redundant with the first line but
    # kept to preserve the existing log output exactly.
    logger.info("Received interaction data: %s", data)
    logger.info("Received text: %s", data.text)
    logger.info("Received initial_prediction: %s", data.initial_prediction)
    logger.info("Received llama_category: %s", data.llama_category)
    logger.info("Received llama_explanation: %s", data.llama_explanation)
    logger.info("Received user_rating: %s", data.user_rating)

    # Keys mirror the storage schema. NOTE(review): "User Rating" contains a
    # space while the other keys use underscores — it is presumably bound to
    # the existing database schema, so confirm before "fixing" it.
    interaction_data = {
        "Input_text": data.text,
        "Model_prediction": data.initial_prediction,
        "Llama_3_Prediction": data.llama_category,
        "Llama_3_Explanation": data.llama_explanation,
        "User Rating": data.user_rating,
    }

    response = insert_db(interaction_data)
    logger.info("Database response: %s", response)
    return {"status": "success", "response": response}
89
+
90
# Allow running this module directly for local development.
if __name__ == "__main__":
    # Port 7860 — presumably chosen because it is the conventional
    # Hugging Face Spaces port; confirm against the hosting config.
    uvicorn.run(app, host="0.0.0.0", port=7860)