gabcares committed on
Commit
2b6b101
·
verified ·
1 Parent(s): 2354717

Update main.py

Browse files

Update to use XGBRegressor, which outperformed RandomForestRegressor in evaluation

Files changed (1) hide show
  1. main.py +177 -177
main.py CHANGED
@@ -1,177 +1,177 @@
1
- import os
2
- from dotenv import load_dotenv
3
-
4
- from collections.abc import AsyncIterator
5
- from contextlib import asynccontextmanager
6
-
7
- from fastapi import FastAPI, Query
8
- from fastapi.responses import FileResponse
9
- from fastapi.staticfiles import StaticFiles
10
- from fastapi_cache import FastAPICache
11
- from fastapi_cache.backends.inmemory import InMemoryBackend
12
- from fastapi_cache.coder import PickleCoder
13
- from fastapi_cache.decorator import cache
14
- import logging
15
-
16
- from pydantic import BaseModel, Field
17
- from typing import List, Union, Optional
18
- from datetime import datetime
19
-
20
- from sklearn.pipeline import Pipeline
21
- import joblib
22
-
23
- import pandas as pd
24
-
25
- import httpx
26
- from io import BytesIO
27
-
28
-
29
- from utils.config import (
30
- ONE_DAY_SEC,
31
- ONE_WEEK_SEC,
32
- ENV_PATH,
33
- DESCRIPTION,
34
- ALL_MODELS
35
- )
36
-
37
- load_dotenv(ENV_PATH)
38
-
39
-
40
- @asynccontextmanager
41
- async def lifespan(_: FastAPI) -> AsyncIterator[None]:
42
- FastAPICache.init(InMemoryBackend())
43
- yield
44
-
45
-
46
- # FastAPI Object
47
- app = FastAPI(
48
- title='Yassir Eta Prediction',
49
- version='1.0.0',
50
- description=DESCRIPTION,
51
- lifespan=lifespan,
52
- )
53
-
54
- app.mount("/assets", StaticFiles(directory="assets"), name="assets")
55
-
56
-
57
- @app.get('/favicon.ico', include_in_schema=False)
58
- @cache(expire=ONE_WEEK_SEC, namespace='eta_favicon') # Cache for 1 week
59
- async def favicon():
60
- file_name = "favicon.ico"
61
- file_path = os.path.join(app.root_path, "assets", file_name)
62
- return FileResponse(path=file_path, headers={"Content-Disposition": "attachment; filename=" + file_name})
63
-
64
-
65
- # API input features
66
-
67
-
68
- class EtaFeatures(BaseModel):
69
- timestamp: List[datetime] = Field(
70
- description="Timestamp: Time that the trip was started")
71
- origin_lat: List[float] = Field(
72
- description="Origin_lat: Origin latitude (in degrees)")
73
- origin_lon: List[float] = Field(
74
- description="Origin_lon: Origin longitude (in degrees)")
75
- destination_lat: List[float] = Field(
76
- description="Destination_lat: Destination latitude (in degrees)")
77
- destination_lon: List[float] = Field(
78
- description="Destination_lon: Destination longitude (in degrees)")
79
- trip_distance: List[float] = Field(
80
- description="Trip_distance: Distance in meters on a driving route")
81
-
82
-
83
- class Url(BaseModel):
84
- url: str
85
- pipeline_url: str
86
-
87
-
88
- class ResultData(BaseModel):
89
- prediction: List[float]
90
-
91
-
92
- class PredictionResponse(BaseModel):
93
- execution_msg: str
94
- execution_code: int
95
- result: ResultData
96
-
97
-
98
- class ErrorResponse(BaseModel):
99
- execution_msg: str
100
- execution_code: int
101
- error: Optional[str]
102
-
103
-
104
- logging.basicConfig(level=logging.ERROR,
105
- format='%(asctime)s - %(levelname)s - %(message)s')
106
-
107
-
108
- # Load the model pipelines and encoder
109
- # Cache for 1 day
110
- @cache(expire=ONE_DAY_SEC, namespace='pipeline_resource', coder=PickleCoder)
111
- async def load_pipeline(pipeline_url: Url) -> Pipeline:
112
- async def url_to_data(url: Url):
113
- async with httpx.AsyncClient() as client:
114
- response = await client.get(url)
115
- response.raise_for_status() # Ensure we catch any HTTP errors
116
- # Convert response content to BytesIO object
117
- data = BytesIO(response.content)
118
- return data
119
-
120
- pipeline = None
121
- try:
122
- pipeline: Pipeline = joblib.load(await url_to_data(pipeline_url))
123
- except Exception as e:
124
- logging.error(
125
- "Omg, an error occurred in loading the pipeline resources: %s", e)
126
- finally:
127
- return pipeline
128
-
129
-
130
- # Endpoints
131
-
132
- # Status endpoint: check if api is online
133
- @app.get('/')
134
- @cache(expire=ONE_WEEK_SEC, namespace='eta_status_check') # Cache for 1 week
135
- async def status_check():
136
- return {"Status": "API is online..."}
137
-
138
-
139
- @cache(expire=ONE_DAY_SEC, namespace='pipeline_regressor') # Cache for 1 day
140
- async def pipeline_regressor(pipeline: Pipeline, data: EtaFeatures) -> Union[ErrorResponse, PredictionResponse]:
141
- msg = 'Execution failed'
142
- code = 0
143
- output = ErrorResponse(**{'execution_msg': msg,
144
- 'execution_code': code, 'error': None})
145
-
146
- try:
147
- # Create dataframe
148
- df = pd.DataFrame.from_dict(data.__dict__)
149
-
150
- # Make prediction
151
- preds = pipeline.predict(df)
152
- predictions = [float(pred) for pred in preds]
153
-
154
- result = ResultData(**{"prediction": predictions})
155
-
156
- msg = 'Execution was successful'
157
- code = 1
158
- output = PredictionResponse(
159
- **{'execution_msg': msg,
160
- 'execution_code': code, 'result': result}
161
- )
162
-
163
- except Exception as e:
164
- error = f"Omg, pipeline regressor failure. {e}"
165
- output = ErrorResponse(**{'execution_msg': msg,
166
- 'execution_code': code, 'error': error})
167
-
168
- finally:
169
- return output
170
-
171
-
172
- @app.post('/api/v1/eta/prediction', tags=['All Models'])
173
- async def query_eta_prediction(data: EtaFeatures, model: str = Query('RandomForestRegressor', enum=list(ALL_MODELS.keys()))) -> Union[ErrorResponse, PredictionResponse]:
174
- pipeline_url: Url = ALL_MODELS[model]
175
- pipeline = await load_pipeline(pipeline_url)
176
- output = await pipeline_regressor(pipeline, data)
177
- return output
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ from collections.abc import AsyncIterator
5
+ from contextlib import asynccontextmanager
6
+
7
+ from fastapi import FastAPI, Query
8
+ from fastapi.responses import FileResponse
9
+ from fastapi.staticfiles import StaticFiles
10
+ from fastapi_cache import FastAPICache
11
+ from fastapi_cache.backends.inmemory import InMemoryBackend
12
+ from fastapi_cache.coder import PickleCoder
13
+ from fastapi_cache.decorator import cache
14
+ import logging
15
+
16
+ from pydantic import BaseModel, Field
17
+ from typing import List, Union, Optional
18
+ from datetime import datetime
19
+
20
+ from sklearn.pipeline import Pipeline
21
+ import joblib
22
+
23
+ import pandas as pd
24
+
25
+ import httpx
26
+ from io import BytesIO
27
+
28
+
29
+ from utils.config import (
30
+ ONE_DAY_SEC,
31
+ ONE_WEEK_SEC,
32
+ ENV_PATH,
33
+ DESCRIPTION,
34
+ ALL_MODELS
35
+ )
36
+
37
+ load_dotenv(ENV_PATH)
38
+
39
+
40
+ @asynccontextmanager
41
+ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
42
+ FastAPICache.init(InMemoryBackend())
43
+ yield
44
+
45
+
46
+ # FastAPI Object
47
+ app = FastAPI(
48
+ title='Yassir Eta Prediction',
49
+ version='1.0.0',
50
+ description=DESCRIPTION,
51
+ lifespan=lifespan,
52
+ )
53
+
54
+ app.mount("/assets", StaticFiles(directory="assets"), name="assets")
55
+
56
+
57
+ @app.get('/favicon.ico', include_in_schema=False)
58
+ @cache(expire=ONE_WEEK_SEC, namespace='eta_favicon') # Cache for 1 week
59
+ async def favicon():
60
+ file_name = "favicon.ico"
61
+ file_path = os.path.join(app.root_path, "assets", file_name)
62
+ return FileResponse(path=file_path, headers={"Content-Disposition": "attachment; filename=" + file_name})
63
+
64
+
65
+ # API input features
66
+
67
+
68
+ class EtaFeatures(BaseModel):
69
+ timestamp: List[datetime] = Field(
70
+ description="Timestamp: Time that the trip was started")
71
+ origin_lat: List[float] = Field(
72
+ description="Origin_lat: Origin latitude (in degrees)")
73
+ origin_lon: List[float] = Field(
74
+ description="Origin_lon: Origin longitude (in degrees)")
75
+ destination_lat: List[float] = Field(
76
+ description="Destination_lat: Destination latitude (in degrees)")
77
+ destination_lon: List[float] = Field(
78
+ description="Destination_lon: Destination longitude (in degrees)")
79
+ trip_distance: List[float] = Field(
80
+ description="Trip_distance: Distance in meters on a driving route")
81
+
82
+
83
+ class Url(BaseModel):
84
+ url: str
85
+ pipeline_url: str
86
+
87
+
88
+ class ResultData(BaseModel):
89
+ prediction: List[float]
90
+
91
+
92
+ class PredictionResponse(BaseModel):
93
+ execution_msg: str
94
+ execution_code: int
95
+ result: ResultData
96
+
97
+
98
+ class ErrorResponse(BaseModel):
99
+ execution_msg: str
100
+ execution_code: int
101
+ error: Optional[str]
102
+
103
+
104
+ logging.basicConfig(level=logging.ERROR,
105
+ format='%(asctime)s - %(levelname)s - %(message)s')
106
+
107
+
108
+ # Load the model pipelines and encoder
109
+ # Cache for 1 day
110
+ @cache(expire=ONE_DAY_SEC, namespace='pipeline_resource', coder=PickleCoder)
111
+ async def load_pipeline(pipeline_url: Url) -> Pipeline:
112
+ async def url_to_data(url: Url):
113
+ async with httpx.AsyncClient() as client:
114
+ response = await client.get(url)
115
+ response.raise_for_status() # Ensure we catch any HTTP errors
116
+ # Convert response content to BytesIO object
117
+ data = BytesIO(response.content)
118
+ return data
119
+
120
+ pipeline = None
121
+ try:
122
+ pipeline: Pipeline = joblib.load(await url_to_data(pipeline_url))
123
+ except Exception as e:
124
+ logging.error(
125
+ "Omg, an error occurred in loading the pipeline resources: %s", e)
126
+ finally:
127
+ return pipeline
128
+
129
+
130
+ # Endpoints
131
+
132
+ # Status endpoint: check if api is online
133
+ @app.get('/')
134
+ @cache(expire=ONE_WEEK_SEC, namespace='eta_status_check') # Cache for 1 week
135
+ async def status_check():
136
+ return {"Status": "API is online..."}
137
+
138
+
139
+ @cache(expire=ONE_DAY_SEC, namespace='pipeline_regressor') # Cache for 1 day
140
+ async def pipeline_regressor(pipeline: Pipeline, data: EtaFeatures) -> Union[ErrorResponse, PredictionResponse]:
141
+ msg = 'Execution failed'
142
+ code = 0
143
+ output = ErrorResponse(**{'execution_msg': msg,
144
+ 'execution_code': code, 'error': None})
145
+
146
+ try:
147
+ # Create dataframe
148
+ df = pd.DataFrame.from_dict(data.__dict__)
149
+
150
+ # Make prediction
151
+ preds = pipeline.predict(df)
152
+ predictions = [float(pred) for pred in preds]
153
+
154
+ result = ResultData(**{"prediction": predictions})
155
+
156
+ msg = 'Execution was successful'
157
+ code = 1
158
+ output = PredictionResponse(
159
+ **{'execution_msg': msg,
160
+ 'execution_code': code, 'result': result}
161
+ )
162
+
163
+ except Exception as e:
164
+ error = f"Omg, pipeline regressor failure. {e}"
165
+ output = ErrorResponse(**{'execution_msg': msg,
166
+ 'execution_code': code, 'error': error})
167
+
168
+ finally:
169
+ return output
170
+
171
+
172
+ @app.post('/api/v1/eta/prediction', tags=['All Models'])
173
+ async def query_eta_prediction(data: EtaFeatures, model: str = Query('XGBRegressor', enum=list(ALL_MODELS.keys()))) -> Union[ErrorResponse, PredictionResponse]:
174
+ pipeline_url: Url = ALL_MODELS[model]
175
+ pipeline = await load_pipeline(pipeline_url)
176
+ output = await pipeline_regressor(pipeline, data)
177
+ return output