sanbatte commited on
Commit
2939a15
·
1 Parent(s): 395b563

refactor app

Browse files
__init__.py ADDED
File without changes
app/data/data_finance.csv ADDED
The diff for this file is too large to render. See raw diff
 
{data → app/data}/lightgbm_deuda.pkl RENAMED
File without changes
app/models/__init__.py ADDED
File without changes
app/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
app/models/__pycache__/prediction_models.cpython-310.pyc ADDED
Binary file (684 Bytes). View file
 
app/models/prediction_models.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List
3
+
4
+
5
+ class PredictionRequest(BaseModel):
6
+ invoiceId: List[int]
7
+ country: str
8
+
9
+
10
+ class PredictionResponse(BaseModel):
11
+ invoiceId: int
12
+ prediction: float
app/routes/__init__.py ADDED
File without changes
app/routes/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
app/routes/__pycache__/prediction.cpython-310.pyc ADDED
Binary file (1.04 kB). View file
 
app/routes/prediction.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from app.models.prediction_models import PredictionRequest, PredictionResponse
3
+ from typing import List
4
+ from app.utils.data_preparation import load_model_and_data
5
+ from app.utils.validations import check_country_code, check_valid_ids
6
+
7
+ router = APIRouter()
8
+
9
+ df, model = load_model_and_data()
10
+
11
+
12
+ @router.post("/", response_model=List[PredictionResponse])
13
+ def predict(request: PredictionRequest):
14
+ check_country_code(request)
15
+ check_valid_ids(request, df)
16
+
17
+ prediction_data = df.loc[request.invoiceId]
18
+ predictions = model.predict(prediction_data)
19
+
20
+ response_data = [
21
+ {"invoiceId": invoice_id, "prediction": float(prediction)}
22
+ for invoice_id, prediction in zip(request.invoiceId, predictions)
23
+ ]
24
+
25
+ return response_data
app/utils/__init__.py ADDED
File without changes
app/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes). View file
 
app/utils/__pycache__/data_preparation.cpython-310.pyc ADDED
Binary file (965 Bytes). View file
 
app/utils/__pycache__/validations.cpython-310.pyc ADDED
Binary file (734 Bytes). View file
 
utils.py → app/utils/data_preparation.py RENAMED
@@ -1,14 +1,23 @@
1
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  def prepare_data(df: pd.DataFrame = None) -> pd.DataFrame:
5
  """
6
  Prepare data.
7
  """
8
- # Assuming no additional preprocessing is required for this example
9
  df = df.drop(["Unnamed: 0", "overdueDays"], axis=1)
10
  df = df.drop(["businessId", "payerId"], axis=1)
11
- df = df.set_index("invoiceId")
12
  df = df[
13
  [
14
  "receiptAmount",
 
1
  import pandas as pd
2
+ import joblib
3
+
4
+
5
+ def load_model_and_data():
6
+ with open("app/data/lightgbm_deuda.pkl", "rb") as file:
7
+ model = joblib.load(file)
8
+
9
+ df = pd.read_csv("app/data/data_finance.csv")
10
+ df = df.set_index("invoiceId")
11
+ return df, model
12
 
13
 
14
  def prepare_data(df: pd.DataFrame = None) -> pd.DataFrame:
15
  """
16
  Prepare data.
17
  """
 
18
  df = df.drop(["Unnamed: 0", "overdueDays"], axis=1)
19
  df = df.drop(["businessId", "payerId"], axis=1)
20
+
21
  df = df[
22
  [
23
  "receiptAmount",
app/utils/validations.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException
2
+
3
+
4
+ def check_country_code(request):
5
+ if request.country not in ["CL", "MX"]:
6
+ raise HTTPException(
7
+ status_code=400, detail=f"Invalid country code: {request.country}"
8
+ )
9
+ else:
10
+ return print("correct country code")
11
+
12
+
13
+ def check_valid_ids(request, df):
14
+ invalid_ids = set(request.invoiceId) - set(df.index)
15
+ if invalid_ids:
16
+ raise HTTPException(
17
+ status_code=400, detail=f"Invalid invoiceId(s): {invalid_ids}"
18
+ )
19
+ else:
20
+ return print("invoice ids are valid")
data/dataTest.csv DELETED
The diff for this file is too large to render. See raw diff
 
main.py CHANGED
@@ -1,71 +1,13 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from os import getenv
4
- import pandas as pd
5
- import joblib
6
 
7
- # import pickle
8
- import numpy as np
9
- from utils import prepare_data
10
 
11
  app = FastAPI()
12
 
13
- # Cargar el modelo XGBoost desde el archivo .pkl
14
- # with open("data/lightgbm_deuda.pkl", "rb") as file:
15
- # model = pickle.load(file)
16
- with open("data/lightgbm_deuda.pkl", "rb") as file:
17
- # model = joblib.load(file)
18
- model = joblib.load(file)
19
-
20
- # Cargar el DataFrame desde el archivo CSV
21
- df = pd.read_csv("data/dataTest.csv")
22
- # df = df.set_index("invoiceId")
23
-
24
-
25
- df = prepare_data(df)
26
-
27
-
28
- class PredictionRequest(BaseModel):
29
- invoiceId: list[int]
30
- country: str
31
-
32
-
33
- class PredictionResponse(BaseModel):
34
- invoiceId: int
35
- prediction: float
36
-
37
-
38
- @app.post("/predict")
39
- def predict(request: PredictionRequest):
40
- # Verificar que los invoiceId enviados estén en el DataFrame
41
-
42
- invalid_ids = set(request.invoiceId) - set(df.index)
43
- if invalid_ids:
44
- raise HTTPException(
45
- status_code=400, detail=f"Invalid invoiceId(s): {invalid_ids}"
46
- )
47
- if request.country not in ["CL", "MX"]:
48
- raise HTTPException(
49
- status_code=400, detail=f"Invalid country code: {request.country}"
50
- )
51
-
52
- # Filtrar el DataFrame para obtener solo las filas correspondientes a los invoiceId enviados
53
- prediction_data = df.loc[request.invoiceId]
54
-
55
- # Realizar la predicción con el modelo
56
- predictions = model.predict(prediction_data)
57
-
58
- # Crear la respuesta
59
- response_data = [
60
- {"invoiceId": invoice_id, "prediction": float(prediction)}
61
- for invoice_id, prediction in zip(request.invoiceId, predictions)
62
- ]
63
-
64
- return response_data
65
-
66
-
67
- # if __name__ == "__main__":
68
- # import uvicorn
69
-
70
- # print("building")
71
- # uvicorn.run(app, host="0.0.0.0", reload=True)
 
1
+ from fastapi import FastAPI, Depends
2
+ from app.routes import prediction
3
+ from app.models.prediction_models import PredictionRequest, PredictionResponse
 
 
4
 
 
 
 
5
 
6
  app = FastAPI()
7
 
8
+ # Include three routers
9
+ app.include_router(
10
+ prediction.router,
11
+ prefix="/predict",
12
+ tags=["prediction"],
13
+ )