Stefan committed on
Commit
dd8d615
·
1 Parent(s): e513d31

Add application file

Browse files
Files changed (9) hide show
  1. .DS_Store +0 -0
  2. Dockerfile +18 -0
  3. app/.DS_Store +0 -0
  4. app/__init__.py +0 -0
  5. app/lstm_model.h5 +3 -0
  6. app/models.py +0 -0
  7. app/routes.py +55 -0
  8. main.py +43 -0
  9. requirements.txt +51 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the DAS Homework FastAPI service.

# Use the official Python image
FROM python:3.11-slim

# Set the working directory
WORKDIR /app

# Copy requirements and install dependencies first, so the dependency layer
# is cached independently of application-code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application
COPY . .

# Expose the port FastAPI will run on
# NOTE(review): 7860 looks like the Hugging Face Spaces convention — confirm
EXPOSE 7860

# Run the FastAPI app with Uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app/.DS_Store ADDED
Binary file (6.15 kB). View file
 
app/__init__.py ADDED
File without changes
app/lstm_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd738f5910dfcddf00835fafdedbbfd66df8fe1be433665ffa067259efe4d8fe
3
+ size 427312
app/models.py ADDED
File without changes
app/routes.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, APIRouter
2
+ import pandas as pd
3
+ import numpy as np
4
+ from keras.models import load_model
5
+ from pydantic import BaseModel
6
+
7
# Load the trained LSTM once at import time so every request reuses it.
model = load_model("app/lstm_model.h5")
# Router carrying this module's endpoints; mounted on the app in main.py.
router = APIRouter()
11
+
12
class PredictionInput(BaseModel):
    """Request payload for /predict/: a column-oriented dict of input rows.

    # NOTE(review): keys are presumably column names mapping to equal-length
    # value sequences (pandas.DataFrame.from_dict layout) — confirm with caller.
    """

    input_data: dict
14
+
15
def preprocess_and_predict(input_data):
    """Normalize a feature DataFrame and return the model's rounded prediction.

    Parameters
    ----------
    input_data : pandas.DataFrame
        Historical rows, most recent last. The non-feature columns
        'COMPANY' and 'PRICE OF LAST TRANSACTION' are dropped if present.

    Returns
    -------
    The rounded, denormalized scalar prediction for the next step.

    Raises
    ------
    ValueError
        If the frame has fewer columns than the model's feature count,
        fewer rows than the model's timestep window, or contains no
        non-zero values (so normalization would divide by zero).
    """
    # errors='ignore' keeps this safe when a caller already removed the
    # meta columns (plain drop() would raise KeyError -> opaque 500).
    input_data = input_data.drop(
        columns=['COMPANY', 'PRICE OF LAST TRANSACTION'], errors='ignore'
    )

    # Window length and feature count come from the loaded model itself,
    # so this code survives retraining with a different architecture.
    timesteps = model.input_shape[1]
    features = model.input_shape[2]

    # Validate the column count up front; otherwise the reshape below
    # fails with a cryptic error instead of a clean 422.
    if input_data.shape[1] < features:
        raise ValueError(
            f"Input data must have at least {features} feature columns."
        )
    # Keep only the first `features` columns, in order.
    input_data = input_data.iloc[:, :features]

    # Handle missing values and normalize by the global maximum.
    input_data = input_data.fillna(0)
    max_value = input_data.max().max()
    if max_value == 0:
        # Dividing by zero would silently yield NaNs and NaN predictions.
        raise ValueError("Input data must contain at least one non-zero value.")
    input_data_normalized = input_data / max_value

    # Check if there are enough rows for timesteps
    if len(input_data) < timesteps:
        raise ValueError(f"Input data must have at least {timesteps} rows for prediction.")

    # Use the most recent `timesteps` rows, shaped (1, timesteps, features).
    input_data_reshaped = np.array([input_data_normalized.values[-timesteps:]])
    input_data_reshaped = input_data_reshaped.reshape(1, timesteps, features)

    # Predict, then undo the normalization.
    predictions = model.predict(input_data_reshaped)
    predictions_denormalized = predictions * max_value

    return predictions_denormalized.round()[0][0]
42
+
43
# API endpoint
@router.post("/predict/")
async def predict(payload: PredictionInput):
    """Run the LSTM on the posted data and return its prediction.

    Maps validation problems (bad shape, too few rows) to HTTP 422 and
    any unexpected failure to HTTP 500.
    """
    try:
        frame = pd.DataFrame.from_dict(payload.input_data)
        return {"prediction": preprocess_and_predict(input_data=frame)}
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
main.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from app.routes import router as api_router
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi import FastAPI, HTTPException, Request
5
+ from fastapi.responses import JSONResponse
6
+ import logging
7
+
8
+
9
# Initialize the FastAPI app
app = FastAPI(
    title="API for the DAS Homework",
    # Fixed typo in the user-facing description ("is is used to serve for").
    description="This api is used to serve the DAS Homework web application",
    version="1.0.0"
)

# Only the deployed frontend origin may call this API from a browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://das-prototype.web.app"],
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods (e.g., GET, POST)
    allow_headers=["*"],  # Allow all headers
)

logging.basicConfig(level=logging.INFO)
25
+
26
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request body and the resulting response status code.

    NOTE(review): this buffers the full body in memory and logs potentially
    sensitive payloads — acceptable for a homework API, not for production.
    """
    try:
        body = await request.body()
        # errors="replace" keeps logging from crashing on non-UTF-8 bodies.
        logging.info(f"Request body: {body.decode(errors='replace')}")
        response = await call_next(request)
        logging.info(f"Response status: {response.status_code}")
        return response
    except Exception as e:
        logging.error(f"Error occurred: {e}")
        # Bare raise preserves the original traceback ("raise e" rewrites it).
        raise
37
+
38
# Mount the prediction routes on the application.
app.include_router(api_router)


@app.get("/")
async def root():
    """Landing endpoint identifying the service."""
    return {"message": "API for the DAS Homework"}
requirements.txt ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.8.0
4
+ astunparse==1.6.3
5
+ certifi==2024.12.14
6
+ charset-normalizer==3.4.1
7
+ click==8.1.8
8
+ fastapi==0.115.6
9
+ flatbuffers==24.12.23
10
+ gast==0.6.0
11
+ google-pasta==0.2.0
12
+ grpcio==1.69.0
13
+ h11==0.14.0
14
+ h5py==3.12.1
15
+ idna==3.10
16
+ keras==3.8.0
17
+ libclang==18.1.1
18
+ Markdown==3.7
19
+ markdown-it-py==3.0.0
20
+ MarkupSafe==3.0.2
21
+ mdurl==0.1.2
22
+ ml-dtypes==0.4.1
23
+ namex==0.0.8
24
+ numpy==2.0.2
25
+ opt_einsum==3.4.0
26
+ optree==0.14.0
27
+ packaging==24.2
28
+ pandas==2.2.3
29
+ protobuf==5.29.3
30
+ pydantic==2.10.5
31
+ pydantic_core==2.27.2
32
+ Pygments==2.19.1
33
+ python-dateutil==2.9.0.post0
34
+ python-multipart==0.0.20
35
+ pytz==2024.2
36
+ requests==2.32.3
37
+ rich==13.9.4
38
+ six==1.17.0
39
+ sniffio==1.3.1
40
+ starlette==0.41.3
41
+ tensorboard==2.18.0
42
+ tensorboard-data-server==0.7.2
43
+ tensorflow==2.18.0
44
+ tensorflow-io-gcs-filesystem==0.37.1
45
+ termcolor==2.5.0
46
+ typing_extensions==4.12.2
47
+ tzdata==2024.2
48
+ urllib3==2.3.0
49
+ uvicorn==0.34.0
50
+ Werkzeug==3.1.3
51
+ wrapt==1.17.2