Stefan commited on
Commit
201c67c
·
1 Parent(s): 86d42d1

revert to main.py

Browse files
Files changed (1) hide show
  1. main.py +52 -3
main.py CHANGED
@@ -4,7 +4,11 @@ from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi import FastAPI, HTTPException, Request
5
  from fastapi.responses import JSONResponse
6
  import logging
7
-
 
 
 
 
8
 
9
  # Initialize the FastAPI app
10
  app = FastAPI(
@@ -35,9 +39,54 @@ async def log_requests(request: Request, call_next):
35
  logging.error(f"Error occurred: {e}")
36
  raise e
37
 
38
- # Include the API router
39
- app.include_router(api_router)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  @app.get("/")
42
  async def root():
43
  return {"message": "API for the DAS Homework"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from fastapi import FastAPI, HTTPException, Request
5
  from fastapi.responses import JSONResponse
6
  import logging
7
+ from fastapi import FastAPI, HTTPException, APIRouter
8
+ import pandas as pd
9
+ import numpy as np
10
+ from keras.models import load_model
11
+ from pydantic import BaseModel
12
 
13
  # Initialize the FastAPI app
14
  app = FastAPI(
 
39
  logging.error(f"Error occurred: {e}")
40
  raise e
41
 
42
+ model = load_model("app/lstm_model.h5")
43
+
44
+ class PredictionInput(BaseModel):
45
+ input_data: dict
46
+
47
+ def preprocess_and_predict(input_data):
48
+ input_data = input_data.drop(columns=['COMPANY', 'PRICE OF LAST TRANSACTION'])
49
+ # Load the pre-trained model
50
+ timesteps = model.input_shape[1]
51
+ features = model.input_shape[2]
52
+
53
+ # Ensure input_data has the correct number of features
54
+ input_data = input_data.iloc[:, :features]
55
+
56
+ # Handle missing values and normalize
57
+ input_data = input_data.fillna(0)
58
+ max_value = input_data.max().max()
59
+ input_data_normalized = input_data / max_value
60
+
61
+ # Check if there are enough rows for timesteps
62
+ if len(input_data) < timesteps:
63
+ raise ValueError(f"Input data must have at least {timesteps} rows for prediction.")
64
+
65
+ # Reshape the data
66
+ input_data_reshaped = np.array([input_data_normalized.values[-timesteps:]])
67
+ input_data_reshaped = input_data_reshaped.reshape(1, timesteps, features)
68
+
69
+ # Predict
70
+ predictions = model.predict(input_data_reshaped)
71
+ predictions_denormalized = predictions * max_value
72
+
73
+ return predictions_denormalized.round()[0][0]
74
+
75
 
76
  @app.get("/")
77
  async def root():
78
  return {"message": "API for the DAS Homework"}
79
+
80
+ # API endpoint
81
+ @app.post("/predict/")
82
+ async def predict(payload: PredictionInput):
83
+ try:
84
+ input_data = payload.input_data
85
+ dataframe = pd.DataFrame.from_dict(input_data)
86
+
87
+ return {"prediction": preprocess_and_predict(input_data=dataframe)}
88
+
89
+ except ValueError as e:
90
+ raise HTTPException(status_code=422, detail=str(e))
91
+ except Exception as e:
92
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")