File size: 2,268 Bytes
704a77e
ce971f1
 
704a77e
00935ea
704a77e
 
 
 
 
c5b0335
 
 
 
 
 
 
 
 
704a77e
ce971f1
 
 
 
704a77e
ce971f1
 
 
704a77e
c5b0335
 
 
 
ce971f1
 
 
 
 
 
 
fa2d0fb
ce971f1
70720ab
ce971f1
 
c5b0335
 
 
 
ce971f1
 
 
 
 
fa2d0fb
 
 
 
 
c5b0335
fa2d0fb
c5b0335
 
70720ab
ce971f1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import base64
import gradio as gr
from tensorflow import keras
import cv2

# Trained sketch-to-image network, loaded once when this module is imported;
# predict() reuses this single instance for every request.
model_path = "sketch2draw_model.h5"  # Update with your model path
model = keras.models.load_model(filepath=model_path)

# Terrain name -> (R, G, B) colour. Key order matters: the UI creates one
# button per key, in this order, labelled with the key itself.
color_mapping = dict(
    Water=(0, 0, 255),        # blue
    Grass=(0, 255, 0),        # green
    Dirt=(139, 69, 19),       # brown
    Clouds=(255, 255, 255),   # white
    Wood=(160, 82, 45),       # sienna
    Sky=(135, 206, 235),      # sky blue
)

def _sketch_to_array(image):
    """Normalise a Sketchpad value to an (H, W, 3) uint8 array, or None if empty.

    Accepts a NumPy array (2-D grayscale, RGB or RGBA), a dict payload that
    wraps the pixels under ``"composite"`` or ``"image"``, or a base64 string
    with or without a ``data:...;base64,`` prefix.

    Raises:
        ValueError: if a base64 string is malformed or is not a decodable image.
    """
    if isinstance(image, dict):
        # Editor-style payloads wrap the pixels; prefer the flattened layer.
        image = image.get("composite", image.get("image"))
    if image is None:
        return None

    if isinstance(image, str):
        # Accept "data:image/png;base64,<payload>" as well as a bare payload.
        payload = image.split(",", 1)[-1]
        buffer = np.frombuffer(base64.b64decode(payload), np.uint8)
        decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        if decoded is None:
            raise ValueError("Could not decode the sketch image data.")
        # NOTE(review): cv2.imdecode yields BGR, while array inputs keep their
        # own channel order - confirm which order the model was trained on.
        return decoded

    array = np.asarray(image)
    if array.ndim == 2:
        # Grayscale canvas: replicate to three channels, matching the
        # 3-channel output of the IMREAD_COLOR path above.
        array = np.repeat(array[:, :, np.newaxis], 3, axis=2)
    elif array.ndim == 3 and array.shape[2] == 4:
        # RGBA canvas: flatten onto white so transparent (unpainted) pixels
        # do not turn black when alpha is dropped.
        # NOTE(review): assumes a white drawing background - confirm.
        rgb = array[:, :, :3].astype(np.float32)
        alpha = array[:, :, 3:].astype(np.float32) / 255.0
        array = rgb * alpha + 255.0 * (1.0 - alpha)
    return np.clip(array, 0, 255).astype(np.uint8)


def predict(image):
    """Run the sketch-to-image model on a Sketchpad drawing.

    Args:
        image: The Sketchpad value. This can be a NumPy array, a dict payload
            holding the array under ``"composite"``/``"image"``, or a base64
            image string (optionally a ``data:`` URL).

    Returns:
        The generated image as a ``uint8`` NumPy array, or ``None`` when the
        canvas is empty.

    Raises:
        ValueError: if a base64 string cannot be decoded into an image.
    """
    sketch = _sketch_to_array(image)
    if sketch is None:
        return None

    # Resize to the model's input resolution and scale pixels to [0, 1].
    sketch = cv2.resize(sketch, (128, 128))  # Resize as per your model's input
    batch = np.expand_dims(sketch, axis=0) / 255.0

    generated = model.predict(batch)[0]
    # Clip before casting: values outside [0, 1] would otherwise wrap around
    # (or be undefined) when converted to uint8.
    return (np.clip(generated, 0.0, 1.0) * 255).astype(np.uint8)

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1>Sketch to Draw Model</h1>")

    with gr.Row():
        with gr.Column():
            canvas = gr.Sketchpad(label="Draw Here", shape=(400, 400))  # Updated: Removed 'tool' argument
            clear_btn = gr.Button("Clear")
            generate_btn = gr.Button("Generate Image")

        with gr.Column():
            # One button per terrain, labelled with its color_mapping key.
            color_btns = {name: gr.Button(name) for name in color_mapping}
            terrain_label = gr.Markdown("Selected terrain: none")
            output_image = gr.Image(label="Generated Image", type="numpy")

    # Per-session record of the last terrain colour picked (RGB tuple or None).
    selected_color = gr.State(None)

    # Define the actions for buttons
    def clear_canvas():
        """Clear the sketchpad.

        A None output empties an image component. An all-zero array would
        instead replace the canvas with a solid black image.
        """
        return None

    clear_btn.click(fn=clear_canvas, inputs=None, outputs=canvas)

    def change_color(color):
        """Return *color* unchanged (the selected terrain's RGB tuple)."""
        return color

    def _make_terrain_handler(name, color):
        """Build a zero-argument click handler bound to one terrain.

        The buttons are wired with ``inputs=None``, so Gradio calls the
        handler with no arguments. Binding *name* and *color* here (rather
        than closing over the loop variables) gives each button its own
        terrain instead of all sharing the last one.
        """
        def select_terrain():
            return change_color(color), f"Selected terrain: {name}"

        return select_terrain

    # TODO(review): this records the chosen colour but does not recolour the
    # Sketchpad brush. That depends on the installed Gradio version's brush
    # API, and predict() does not consume the selection yet.
    for color_name, color in color_mapping.items():
        color_btns[color_name].click(
            fn=_make_terrain_handler(color_name, color),
            inputs=None,
            outputs=[selected_color, terrain_label],
        )

    # Click to generate an image
    generate_btn.click(fn=predict, inputs=canvas, outputs=output_image)

# Launch the Gradio app
demo.launch()