# Sketch2DrawApp / app.py
# szili2011's picture
# Update app.py
# 09a050f verified
import numpy as np
import base64
import gradio as gr
from tensorflow import keras
import cv2
# Location of the serialized Keras model; adjust when the file lives elsewhere.
model_path = "sketch2draw_model.h5"
# Deserialize the model once at import time; predict() reuses this instance.
model = keras.models.load_model(filepath=model_path)
# Terrain name -> brush color triple; the keys populate the brush dropdown.
color_mapping = dict(
    Water=(0, 0, 255),       # Blue
    Grass=(0, 255, 0),       # Green
    Dirt=(139, 69, 19),      # Brown
    Clouds=(255, 255, 255),  # White
    Wood=(160, 82, 45),      # Sienna
    Sky=(135, 206, 235),     # Sky Blue
)
def predict(image):
    """Run a sketch through the model and return the generated image.

    The sketch may arrive in any of these forms:

    * a base64 string, optionally prefixed with a ``data:...;base64,``
      header (the only form the function originally handled);
    * an array of pixels, either 2-D (grayscale) or 3-D with 3 or 4
      channels;
    * a dict that wraps one of the above under ``"composite"`` or
      ``"image"``.

    Args:
        image: The sketch payload, or ``None`` when nothing was drawn.

    Returns:
        The first item of ``model.predict`` scaled by 255, clipped to
        0-255 and cast to ``uint8``; ``None`` when no sketch was supplied.

    Raises:
        ValueError: If a string payload cannot be decoded into an image or
            an array payload has an unsupported shape.
    """
    # Dict payloads wrap the pixels; unwrap them before dispatching on type.
    if isinstance(image, dict):
        wrapped = image.get("composite")
        if wrapped is None:
            wrapped = image.get("image")
        image = wrapped

    # Nothing drawn: return None instead of crashing further down.
    if image is None:
        return None

    if isinstance(image, str):
        # Strip the "data:...;base64," header when there is one; a bare
        # base64 string (no comma) is decoded as-is.
        _, comma, payload = image.partition(",")
        encoded = payload if comma else image
        image_data = np.frombuffer(base64.b64decode(encoded), np.uint8)
        image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
        # imdecode signals failure by returning None rather than raising.
        if image is None:
            raise ValueError("Could not decode the sketch into an image")
    else:
        image = np.asarray(image)
        # The string path always yields 3 channels (IMREAD_COLOR), so bring
        # array payloads to 3 channels as well.
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        elif image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(f"Unsupported sketch shape: {image.shape}")
        # NOTE(review): decoded strings are BGR (OpenCV convention) while
        # array payloads keep the caller's channel order - confirm which
        # order the model was trained on.

    # Resize to the 128x128 input used here, add a batch axis, scale to 0-1.
    image = cv2.resize(image, (128, 128))
    batch = np.expand_dims(image, axis=0) / 255.0

    generated_image = model.predict(batch)[0]
    # Clip before the cast so out-of-range outputs saturate, not wrap around.
    return np.clip(generated_image * 255, 0, 255).astype(np.uint8)
def clear_canvas():
    """Return a blank canvas: a 400x400, 3-channel uint8 array of zeros."""
    height, width, channels = 400, 400, 3
    blank = np.zeros(shape=(height, width, channels), dtype=np.uint8)
    return blank
# Build the UI: drawing column on the left, brush selector and result on the right.
with gr.Blocks() as demo:
    gr.Markdown("<h1>Sketch to Draw Model</h1>")

    with gr.Row():
        with gr.Column():
            # Drawing surface plus the two action buttons.
            canvas = gr.Sketchpad(label="Draw Here")
            clear_btn = gr.Button("Clear")
            generate_btn = gr.Button("Generate Image")

        with gr.Column():
            # Terrain selector; "Water" is preselected. It is not connected
            # to any event handler in this file.
            color_dropdown = gr.Dropdown(
                choices=list(color_mapping),
                value="Water",
                label="Select Brush Color",
            )
            output_image = gr.Image(type="numpy", label="Generated Image")

    # "Clear" replaces the canvas content with the blank image.
    clear_btn.click(clear_canvas, None, canvas)
    # "Generate Image" sends the sketch to the model and shows the result.
    generate_btn.click(predict, canvas, output_image)

# Start the Gradio server.
demo.launch()