# Sketch2DrawApp / app.py
# Source: Hugging Face Space by szili2011 — commit 57aac39 (verified), "Update app.py", 2.25 kB.
import numpy as np
import base64
import gradio as gr
from tensorflow import keras
import cv2
# Load the trained Keras model once, at import time, so every request reuses it.
model_path = 'sketch2draw_model.h5'  # Update with your model path
model = keras.models.load_model(filepath=model_path)
# Terrain name -> (R, G, B) colour, 0-255 per channel.
# Insertion order matters: it is the order in which terrain buttons are built.
color_mapping = dict(
    Water=(0, 0, 255),        # Blue
    Grass=(0, 255, 0),        # Green
    Dirt=(139, 69, 19),       # Brown
    Clouds=(255, 255, 255),   # White
    Wood=(160, 82, 45),       # Sienna
    Sky=(135, 206, 235),      # Sky Blue
)
def _decode_base64_image(data):
    """Decode a base64 image string (optionally a ``data:`` URL) into a BGR uint8 array.

    Raises:
        ValueError: if the payload is not valid base64 or not a decodable image.
    """
    # A data URL looks like "data:image/png;base64,<payload>". Base64 itself
    # never contains a comma, so a string without one is treated as bare base64.
    payload = data.split(",", 1)[1] if "," in data else data
    buffer = np.frombuffer(base64.b64decode(payload), np.uint8)
    decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if decoded is None:
        # cv2.imdecode reports failure by returning None instead of raising.
        raise ValueError("Could not decode the sketch image.")
    return decoded


def _array_to_bgr(array):
    """Convert a grayscale, RGB or RGBA image array into a 3-channel BGR uint8 array.

    BGR matches the channel order ``cv2.imdecode(..., cv2.IMREAD_COLOR)`` yields
    on the base64 path, so the model sees the same layout for either input form.
    NOTE(review): assumes array input is RGB(A) on a 0-255 scale (the usual
    Gradio numpy image convention) — confirm for the installed Gradio version.

    Raises:
        ValueError: if the array does not have an image-like shape.
    """
    array = np.asarray(array)
    if array.dtype != np.uint8:
        array = np.clip(array, 0, 255).astype(np.uint8)
    if array.ndim == 3 and array.shape[2] == 1:
        array = array[..., 0]
    if array.ndim == 2:
        return cv2.cvtColor(array, cv2.COLOR_GRAY2BGR)
    if array.ndim == 3 and array.shape[2] == 3:
        return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
    if array.ndim == 3 and array.shape[2] == 4:
        # Alpha is dropped, mirroring IMREAD_COLOR on a PNG with transparency.
        return cv2.cvtColor(array, cv2.COLOR_RGBA2BGR)
    raise ValueError(f"Unsupported sketch array shape: {array.shape}")


def predict(image):
    """Generate an image from a sketch using the module-level Keras ``model``.

    Args:
        image: The sketch, in any of these forms:
            * a base64 ``data:`` URL or a bare base64 string (decoded with OpenCV);
            * a numpy-compatible grayscale / RGB / RGBA array;
            * a dict whose ``"composite"`` entry is one of the above
              (NOTE(review): this appears to be what Gradio 4's Sketchpad
              sends — confirm against the installed Gradio version);
            * ``None`` (or a dict without a composite), meaning nothing drawn.

    Returns:
        The model output rescaled to a uint8 numpy array, or ``None`` when there
        is no sketch to process.

    Raises:
        ValueError: if the input cannot be turned into an image.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        # Nothing to process, e.g. "Generate" pressed on an empty canvas.
        return None

    if isinstance(image, str):
        image = _decode_base64_image(image)
    else:
        image = _array_to_bgr(image)

    # Resize and normalize the image
    image = cv2.resize(image, (128, 128))  # Resize as per your model's input
    batch = np.expand_dims(image, axis=0).astype(np.float32) / 255.0

    # Generate image based on the model; [0] drops the batch dimension.
    generated_image = model.predict(batch)[0]
    # Clip before casting: converting out-of-range floats to uint8 does not
    # saturate (values wrap or are undefined), which would corrupt pixels.
    generated_image = np.clip(generated_image, 0.0, 1.0)
    return (generated_image * 255).astype(np.uint8)  # Rescale to 0-255
# Create Gradio interface


def _rgb_to_hex(rgb):
    """Format an (R, G, B) tuple of 0-255 ints as a "#rrggbb" string."""
    red, green, blue = rgb
    return f"#{red:02x}{green:02x}{blue:02x}"


# Brush palette for the sketchpad: one swatch per terrain, in color_mapping order.
_terrain_palette = [_rgb_to_hex(rgb) for rgb in color_mapping.values()]


def clear_canvas():
    """Reset the sketchpad.

    Returning ``None`` empties the component. Returning a zero-filled array
    would paint the canvas black rather than clear it.
    """
    return None


def change_color(color):
    """Build a sketchpad update whose brush defaults to ``color``.

    Args:
        color: An (R, G, B) tuple of 0-255 ints, as stored in ``color_mapping``.

    Returns:
        A Gradio update that keeps the terrain palette and selects ``color`` as
        the default brush colour.
    """
    # NOTE(review): assumes Gradio 4+, where Sketchpad takes a ``brush``
    # (gr.Brush) prop; confirm that updating it does not reset drawn strokes.
    return gr.update(
        brush=gr.Brush(
            colors=_terrain_palette,
            default_color=_rgb_to_hex(color),
            color_mode="fixed",
        )
    )


def _color_handler(color):
    """Return a zero-argument click handler bound to ``color``.

    Going through a factory gives each button its own colour; a lambda written
    directly in the loop would late-bind to the last colour. The handler takes
    no arguments because the buttons are wired with ``inputs=None``.
    """
    return lambda: change_color(color)


with gr.Blocks() as demo:
    gr.Markdown("<h1>Sketch to Draw Model</h1>")
    with gr.Row():
        with gr.Column():
            # The brush palette is what lets users paint in terrain colours.
            canvas = gr.Sketchpad(
                label="Draw Here",
                brush=gr.Brush(colors=_terrain_palette, color_mode="fixed"),
            )
            clear_btn = gr.Button("Clear")
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            # Add color buttons for terrains; clicking one selects that brush colour.
            color_btns = {name: gr.Button(name) for name in color_mapping}
            output_image = gr.Image(label="Generated Image", type="numpy")

    # Define the actions for buttons
    clear_btn.click(fn=clear_canvas, inputs=None, outputs=canvas)
    for color_name, color in color_mapping.items():
        color_btns[color_name].click(
            fn=_color_handler(color), inputs=None, outputs=canvas
        )
    # Click to generate an image
    generate_btn.click(fn=predict, inputs=canvas, outputs=output_image)

# Launch the Gradio app
demo.launch()