# Spaces: Sleeping
import os

import cv2
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse
from tensorflow import keras
# Define the FastAPI app
app = FastAPI()

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; the FastAPI CORS guide advises listing origins explicitly when
# credentials are allowed -- confirm this configuration is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time so request handlers can reuse it.
model_path = 'sketch2draw_model.h5'  # Update with your model path
model = keras.models.load_model(model_path)
# Define the request body
class TextureRequest(BaseModel):
    """Request body carrying the name of the texture to predict."""

    # Texture name supplied by the client.
    texture: str
# Class names for predictions; the model's arg-max output index is looked up
# in this list, so the order must match the order the model was trained with.
class_names = ['grass', 'dirt', 'wood', 'water', 'sky', 'clouds']
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the landing page with a texture-name form.

    Returns:
        The HTML document as a string; it is sent as ``text/html`` because
        the route is registered with ``response_class=HTMLResponse``.

    NOTE(review): the form submits URL-encoded form data to ``/predict``,
    while a handler that takes a pydantic model parameter expects a JSON
    body -- confirm the form submission is actually accepted.
    """
    return """
<html>
<head>
<title>Sketch to Draw</title>
</head>
<body>
<h1>Sketch to Draw Model</h1>
<form action="/predict" method="post">
<input type="text" name="texture" placeholder="Enter texture name (grass, dirt, wood, water, sky, clouds)">
<button type="submit">Predict</button>
</form>
</body>
</html>
"""
@app.post("/predict")
async def predict_texture(request: TextureRequest):
    """Predict the texture class for the image named in the request.

    Args:
        request: Request body whose ``texture`` field names the texture image
            to load.

    Returns:
        A dict ``{"predicted_texture": <class name>}``.

    Raises:
        HTTPException: 400 if the texture name is not one of ``class_names``;
            404 if the corresponding image file cannot be read.
    """
    texture_name = request.texture

    # Only accept known names: the value is interpolated into a file path,
    # so arbitrary input must never reach the filesystem.
    if texture_name not in class_names:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown texture {texture_name!r}; expected one of {class_names}",
        )

    # TODO: update with the directory that holds your texture images.
    image_dir = 'path_to_your_texture_images'
    image_path = os.path.join(image_dir, f'{texture_name}.png')

    # cv2.imread signals a missing/unreadable file by returning None.
    image = cv2.imread(image_path)
    if image is None:
        raise HTTPException(
            status_code=404,
            detail=f"No image found for texture {texture_name!r}",
        )

    # NOTE(review): 128x128 and the /255 scaling come from the original
    # placeholder comments -- confirm they match the model's expected input.
    image = cv2.resize(image, (128, 128))
    batch = np.expand_dims(image, axis=0) / 255.0

    # Make prediction
    # NOTE(review): model.predict is a blocking call inside an async handler.
    predictions = model.predict(batch)
    predicted_class = class_names[int(np.argmax(predictions))]
    return {"predicted_texture": predicted_class}
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # Listen on all interfaces, port 8080.
    uvicorn.run(app, host="0.0.0.0", port=8080)