|
import logging
from typing import List

import numpy as np
from fastapi import FastAPI, HTTPException
from keras.api.layers import InputLayer, Dense
from keras.api.models import Sequential
from pydantic import BaseModel
|
|
|
class InputData(BaseModel):
    """Request body accepted by the ``POST /predict/`` endpoint.

    Holds one flat vector of numeric features; the endpoint turns it into a
    single-row batch before running the model.
    """

    # Feature values for one sample (presumably in the same order the model
    # was trained with — confirm against the training pipeline).
    data: List[float]
|
|
|
# The ASGI application object; the prediction route is registered on it.
app: FastAPI = FastAPI()
|
|
|
|
|
def build_model(weights_path="model.h5", n_features=2):
    """Build the small binary classifier and load its trained weights.

    The architecture (layer names, widths, activations) must match the one
    the weights were saved from, otherwise ``load_weights`` fails.

    Args:
        weights_path: Path of the saved weights file. Relative paths are
            resolved against the current working directory. Defaults to
            ``"model.h5"``, the previously hard-coded location.
        n_features: Width of one input vector. Defaults to 2 and must match
            the saved weights.

    Returns:
        The compiled ``Sequential`` model with weights loaded.
    """
    model = Sequential(
        [
            # `shape=` is the Keras 3 argument; `input_shape=` is deprecated
            # and only accepted with a warning.
            InputLayer(shape=(n_features,), name="dense_2_input"),
            Dense(16, activation="relu", name="dense_2"),
            Dense(1, activation="sigmoid", name="dense_3"),
        ]
    )
    model.load_weights(weights_path)
    # NOTE(review): compile() is not required for predict(); it is kept so
    # evaluate()/fit() keep working for any other caller. MSE on a sigmoid
    # output is unusual (binary_crossentropy is the conventional choice), but
    # changing it would alter evaluate() results, so it is left unchanged.
    model.compile(
        loss="mean_squared_error", optimizer="adam", metrics=["binary_accuracy"]
    )
    return model
|
|
|
|
|
# Built once, at import time, and shared by every request. Importing this
# module therefore fails if the weights file cannot be loaded.
model = build_model()
|
|
|
|
|
|
|
@app.post("/predict/")
async def predict(data: InputData):
    """Classify one feature vector with the loaded model.

    Args:
        data: Request body; ``data.data`` is the flat list of input features.

    Returns:
        ``{"prediction": [[p]]}``, where ``p`` is the model's single output
        rounded to the nearest integer (0.0 or 1.0 for a sigmoid output).

    Raises:
        HTTPException: 422 if the number of features does not match the
            model's input width; 500 if inference itself fails.
    """
    logger = logging.getLogger(__name__)
    # Debug level: request payloads should not be written to stdout per call.
    logger.debug("Data: %s", data)

    # `model` is the module-level instance and is only read here, so no
    # `global` statement is needed.
    # Validate the width up front: a mismatch is a client error (422), not a
    # server failure. This must stay outside the `try` below, whose broad
    # handler would otherwise convert the 422 into a 500.
    expected = model.input_shape[-1]
    if expected is not None and len(data.data) != expected:
        raise HTTPException(
            status_code=422,
            detail=f"Expected {expected} features, got {len(data.data)}.",
        )

    # One request carries one sample, so the batch has a single row.
    input_data = np.array(data.data).reshape(1, -1)

    # NOTE(review): model.predict is synchronous and runs on the event loop,
    # blocking other requests while it executes. Offloading it (a plain `def`
    # endpoint or asyncio.to_thread) requires first confirming that
    # concurrent predict() calls on this shared model are safe.
    try:
        # verbose=0 suppresses Keras's progress bar, which would otherwise be
        # printed to stdout on every request.
        prediction = model.predict(input_data, verbose=0).round()
    except Exception as e:
        # Keep the full traceback in server logs, but do not echo internal
        # error text back to the (untrusted) client.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed.") from e

    return {"prediction": prediction.tolist()}
|
|