Spaces:
Sleeping
Sleeping
Stefan
commited on
Commit
·
dd8d615
1
Parent(s):
e513d31
Add application file
Browse files- .DS_Store +0 -0
- Dockerfile +18 -0
- app/.DS_Store +0 -0
- app/__init__.py +0 -0
- app/lstm_model.h5 +3 -0
- app/models.py +0 -0
- app/routes.py +55 -0
- main.py +43 -0
- requirements.txt +51 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Dockerfile
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use the official Python image
|
2 |
+
FROM python:3.11-slim
|
3 |
+
|
4 |
+
# Set the working directory
|
5 |
+
WORKDIR /app
|
6 |
+
|
7 |
+
# Copy requirements and install dependencies
|
8 |
+
COPY requirements.txt .
|
9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
10 |
+
|
11 |
+
# Copy the rest of the application
|
12 |
+
COPY . .
|
13 |
+
|
14 |
+
# Expose the port FastAPI will run on
|
15 |
+
EXPOSE 7860
|
16 |
+
|
17 |
+
# Run the FastAPI app with Uvicorn
|
18 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
app/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
app/__init__.py
ADDED
File without changes
|
app/lstm_model.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bd738f5910dfcddf00835fafdedbbfd66df8fe1be433665ffa067259efe4d8fe
|
3 |
+
size 427312
|
app/models.py
ADDED
File without changes
|
app/routes.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, APIRouter
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from keras.models import load_model
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
# Load the model
|
8 |
+
model = load_model("app/lstm_model.h5")
|
9 |
+
# Define the FastAPI app
|
10 |
+
router = APIRouter()
|
11 |
+
|
12 |
+
class PredictionInput(BaseModel):
|
13 |
+
input_data: dict
|
14 |
+
|
15 |
+
def preprocess_and_predict(input_data):
|
16 |
+
input_data = input_data.drop(columns=['COMPANY', 'PRICE OF LAST TRANSACTION'])
|
17 |
+
# Load the pre-trained model
|
18 |
+
timesteps = model.input_shape[1]
|
19 |
+
features = model.input_shape[2]
|
20 |
+
|
21 |
+
# Ensure input_data has the correct number of features
|
22 |
+
input_data = input_data.iloc[:, :features]
|
23 |
+
|
24 |
+
# Handle missing values and normalize
|
25 |
+
input_data = input_data.fillna(0)
|
26 |
+
max_value = input_data.max().max()
|
27 |
+
input_data_normalized = input_data / max_value
|
28 |
+
|
29 |
+
# Check if there are enough rows for timesteps
|
30 |
+
if len(input_data) < timesteps:
|
31 |
+
raise ValueError(f"Input data must have at least {timesteps} rows for prediction.")
|
32 |
+
|
33 |
+
# Reshape the data
|
34 |
+
input_data_reshaped = np.array([input_data_normalized.values[-timesteps:]])
|
35 |
+
input_data_reshaped = input_data_reshaped.reshape(1, timesteps, features)
|
36 |
+
|
37 |
+
# Predict
|
38 |
+
predictions = model.predict(input_data_reshaped)
|
39 |
+
predictions_denormalized = predictions * max_value
|
40 |
+
|
41 |
+
return predictions_denormalized.round()[0][0]
|
42 |
+
|
43 |
+
# API endpoint
|
44 |
+
@router.post("/predict/")
|
45 |
+
async def predict(payload: PredictionInput):
|
46 |
+
try:
|
47 |
+
input_data = payload.input_data
|
48 |
+
dataframe = pd.DataFrame.from_dict(input_data)
|
49 |
+
|
50 |
+
return {"prediction": preprocess_and_predict(input_data=dataframe)}
|
51 |
+
|
52 |
+
except ValueError as e:
|
53 |
+
raise HTTPException(status_code=422, detail=str(e))
|
54 |
+
except Exception as e:
|
55 |
+
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
main.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from app.routes import router as api_router
|
3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
4 |
+
from fastapi import FastAPI, HTTPException, Request
|
5 |
+
from fastapi.responses import JSONResponse
|
6 |
+
import logging
|
7 |
+
|
8 |
+
|
9 |
+
# Initialize the FastAPI app
|
10 |
+
app = FastAPI(
|
11 |
+
title="API for the DAS Homework",
|
12 |
+
description="This api is is used to serve for the DAS Homework web application",
|
13 |
+
version="1.0.0"
|
14 |
+
)
|
15 |
+
|
16 |
+
app.add_middleware(
|
17 |
+
CORSMiddleware,
|
18 |
+
allow_origins=["https://das-prototype.web.app"],
|
19 |
+
allow_credentials=True,
|
20 |
+
allow_methods=["*"], # Allow all HTTP methods (e.g., GET, POST)
|
21 |
+
allow_headers=["*"], # Allow all headers
|
22 |
+
)
|
23 |
+
|
24 |
+
logging.basicConfig(level=logging.INFO)
|
25 |
+
|
26 |
+
@app.middleware("http")
|
27 |
+
async def log_requests(request: Request, call_next):
|
28 |
+
try:
|
29 |
+
body = await request.body()
|
30 |
+
logging.info(f"Request body: {body.decode()}")
|
31 |
+
response = await call_next(request)
|
32 |
+
logging.info(f"Response status: {response.status_code}")
|
33 |
+
return response
|
34 |
+
except Exception as e:
|
35 |
+
logging.error(f"Error occurred: {e}")
|
36 |
+
raise e
|
37 |
+
|
38 |
+
# Include the API router
|
39 |
+
app.include_router(api_router)
|
40 |
+
|
41 |
+
@app.get("/")
|
42 |
+
async def root():
|
43 |
+
return {"message": "API for the DAS Homework"}
|
requirements.txt
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
annotated-types==0.7.0
|
3 |
+
anyio==4.8.0
|
4 |
+
astunparse==1.6.3
|
5 |
+
certifi==2024.12.14
|
6 |
+
charset-normalizer==3.4.1
|
7 |
+
click==8.1.8
|
8 |
+
fastapi==0.115.6
|
9 |
+
flatbuffers==24.12.23
|
10 |
+
gast==0.6.0
|
11 |
+
google-pasta==0.2.0
|
12 |
+
grpcio==1.69.0
|
13 |
+
h11==0.14.0
|
14 |
+
h5py==3.12.1
|
15 |
+
idna==3.10
|
16 |
+
keras==3.8.0
|
17 |
+
libclang==18.1.1
|
18 |
+
Markdown==3.7
|
19 |
+
markdown-it-py==3.0.0
|
20 |
+
MarkupSafe==3.0.2
|
21 |
+
mdurl==0.1.2
|
22 |
+
ml-dtypes==0.4.1
|
23 |
+
namex==0.0.8
|
24 |
+
numpy==2.0.2
|
25 |
+
opt_einsum==3.4.0
|
26 |
+
optree==0.14.0
|
27 |
+
packaging==24.2
|
28 |
+
pandas==2.2.3
|
29 |
+
protobuf==5.29.3
|
30 |
+
pydantic==2.10.5
|
31 |
+
pydantic_core==2.27.2
|
32 |
+
Pygments==2.19.1
|
33 |
+
python-dateutil==2.9.0.post0
|
34 |
+
python-multipart==0.0.20
|
35 |
+
pytz==2024.2
|
36 |
+
requests==2.32.3
|
37 |
+
rich==13.9.4
|
38 |
+
six==1.17.0
|
39 |
+
sniffio==1.3.1
|
40 |
+
starlette==0.41.3
|
41 |
+
tensorboard==2.18.0
|
42 |
+
tensorboard-data-server==0.7.2
|
43 |
+
tensorflow==2.18.0
|
44 |
+
tensorflow-io-gcs-filesystem==0.37.1
|
45 |
+
termcolor==2.5.0
|
46 |
+
typing_extensions==4.12.2
|
47 |
+
tzdata==2024.2
|
48 |
+
urllib3==2.3.0
|
49 |
+
uvicorn==0.34.0
|
50 |
+
Werkzeug==3.1.3
|
51 |
+
wrapt==1.17.2
|