szili2011 commited on
Commit
704a77e
·
verified ·
1 Parent(s): 5e22aa1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ from fastapi import FastAPI
5
+ from fastapi.responses import HTMLResponse
6
+ from pydantic import BaseModel
7
+ from tensorflow import keras
8
+ from starlette.responses import FileResponse
9
+ from starlette.middleware.cors import CORSMiddleware
10
+
11
+ # Define the FastAPI app
12
+ app = FastAPI()
13
+
14
+ # Add CORS middleware
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_credentials=True,
19
+ allow_methods=["*"],
20
+ allow_headers=["*"],
21
+ )
22
+
23
+ # Load the model
24
+ model_path = 'sketch2draw_model.h5' # Update with your model path
25
+ model = keras.models.load_model(model_path)
26
+
27
+ # Define the request body
28
+ class TextureRequest(BaseModel):
29
+ texture: str
30
+
31
+ # Load class names for predictions
32
+ class_names = ['grass', 'dirt', 'wood', 'water', 'sky', 'clouds']
33
+
34
+ @app.get("/", response_class=HTMLResponse)
35
+ async def read_root():
36
+ return """
37
+ <html>
38
+ <head>
39
+ <title>Sketch to Draw</title>
40
+ </head>
41
+ <body>
42
+ <h1>Sketch to Draw Model</h1>
43
+ <form action="/predict" method="post">
44
+ <input type="text" name="texture" placeholder="Enter texture name (grass, dirt, wood, water, sky, clouds)">
45
+ <button type="submit">Predict</button>
46
+ </form>
47
+ </body>
48
+ </html>
49
+ """
50
+
51
+ @app.post("/predict")
52
+ async def predict_texture(request: TextureRequest):
53
+ texture_name = request.texture
54
+
55
+ # Process the input texture (you can modify this part)
56
+ # Example: Load image and preprocess it
57
+ # image = cv2.imread(f'path_to_your_texture_images/{texture_name}.png')
58
+ # image = cv2.resize(image, (128, 128)) # Resize as per your model's input
59
+ # image = np.expand_dims(image, axis=0) / 255.0 # Normalize if needed
60
+
61
+ # Make prediction
62
+ predictions = model.predict(image) # Add your processed image here
63
+ predicted_class = class_names[np.argmax(predictions)]
64
+
65
+ return {"predicted_texture": predicted_class}
66
+
67
+ if __name__ == "__main__":
68
+ import uvicorn
69
+ uvicorn.run(app, host="0.0.0.0", port=8080)