Sketch2DrawApp / app.py
szili2011's picture
Create app.py
704a77e verified
raw
history blame
2.04 kB
import os
import numpy as np
import cv2
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from tensorflow import keras
from starlette.responses import FileResponse
from starlette.middleware.cors import CORSMiddleware
# Application instance that every route in this module is registered on.
app = FastAPI()

# Cross-origin policy: any origin, HTTP method and header is accepted, and
# credentialed requests are allowed.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The Keras model is loaded once, when the module is imported, and shared by
# all requests. The path is relative, so it is resolved against the process
# working directory.
model_path = 'sketch2draw_model.h5'  # Update with your model path
model = keras.models.load_model(model_path)
class TextureRequest(BaseModel):
    """JSON request body accepted by the ``/predict`` endpoint."""

    # Name of the texture the caller is asking about.
    texture: str
# Texture labels. The position of each entry matters: the prediction handler
# indexes this list with the arg-max of the model output.
class_names = [
    'grass',
    'dirt',
    'wood',
    'water',
    'sky',
    'clouds',
]
@app.get("/", response_class=HTMLResponse)
async def read_root() -> str:
    """Serve the landing page containing the texture prediction form.

    The ``/predict`` endpoint takes a JSON body, whereas a plain HTML form
    submits URL-encoded data. The page therefore intercepts the submit event
    and posts the field as JSON, then shows the server's reply in the page.

    Returns:
        The HTML document as a string (sent with an HTML content type).
    """
    return """
<html>
<head>
<title>Sketch to Draw</title>
</head>
<body>
<h1>Sketch to Draw Model</h1>
<form id="predict-form" action="/predict" method="post">
<input type="text" name="texture" placeholder="Enter texture name (grass, dirt, wood, water, sky, clouds)">
<button type="submit">Predict</button>
</form>
<pre id="result"></pre>
<script>
document.getElementById("predict-form").addEventListener("submit", async function (event) {
    event.preventDefault();
    var output = document.getElementById("result");
    try {
        var response = await fetch("/predict", {
            method: "POST",
            headers: {"Content-Type": "application/json"},
            body: JSON.stringify({texture: event.target.elements.texture.value})
        });
        output.textContent = JSON.stringify(await response.json());
    } catch (error) {
        output.textContent = "Request failed: " + error;
    }
});
</script>
</body>
</html>
"""
# Directory expected to hold one reference image per texture, named
# "<texture>.png".
# NOTE(review): placeholder carried over from the original commented-out
# example -- update it for your deployment.
TEXTURE_IMAGE_DIR = 'path_to_your_texture_images'

# Target size handed to cv2.resize.
# NOTE(review): 128x128 comes from the original example; confirm it matches
# the model's expected input.
MODEL_INPUT_SIZE = (128, 128)


@app.post("/predict")
async def predict_texture(request: TextureRequest):
    """Classify the reference image of the requested texture.

    The texture name is checked against ``class_names`` before it is used to
    build a file path, so arbitrary user input never reaches the filesystem.
    The matching image is read, resized, scaled by 1/255, given a leading
    batch axis and passed to the model.

    Args:
        request: Request body carrying the texture name.

    Returns:
        A dict ``{"predicted_texture": <label>}`` where the label is the
        ``class_names`` entry at the arg-max of the model output.

    Raises:
        HTTPException: 400 when the texture name is not in ``class_names``;
            404 when no readable image exists for that texture.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of the module.
    from fastapi import HTTPException

    texture_name = request.texture.strip().lower()
    if texture_name not in class_names:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown texture '{request.texture}'. "
                   f"Expected one of: {', '.join(class_names)}",
        )

    image_path = os.path.join(TEXTURE_IMAGE_DIR, f"{texture_name}.png")
    image = cv2.imread(image_path)
    # cv2.imread signals a missing or unreadable file by returning None
    # instead of raising.
    if image is None:
        raise HTTPException(
            status_code=404,
            detail=f"No image available for texture '{texture_name}'",
        )

    image = cv2.resize(image, MODEL_INPUT_SIZE)
    # Add the batch axis and scale pixel values.
    # NOTE(review): assumes the model was trained on inputs scaled to [0, 1]
    # in OpenCV's default channel order -- confirm.
    image = np.expand_dims(image, axis=0) / 255.0

    predictions = model.predict(image)
    predicted_class = class_names[int(np.argmax(predictions))]
    return {"predicted_texture": predicted_class}
# Start a server only when this file is executed directly, not when it is
# imported (for example by an external ASGI runner).
if __name__ == "__main__":
    # Imported here so that merely importing the module does not need uvicorn.
    import uvicorn

    # Listen on every network interface, port 8080.
    uvicorn.run(app, host="0.0.0.0", port=8080)