szili2011 commited on
Commit
c5b0335
·
verified ·
1 Parent(s): d98dc08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -8,8 +8,15 @@ import cv2
8
  model_path = 'sketch2draw_model.h5' # Update with your model path
9
  model = keras.models.load_model(model_path)
10
 
11
- # Load class names for predictions
12
- class_names = ['grass', 'dirt', 'wood', 'water', 'sky', 'clouds']
 
 
 
 
 
 
 
13
 
14
  def predict(image):
15
  # Decode the image from base64
@@ -20,11 +27,10 @@ def predict(image):
20
  image = cv2.resize(image, (128, 128)) # Resize as per your model's input
21
  image = np.expand_dims(image, axis=0) / 255.0 # Normalize if needed
22
 
23
- # Make prediction
24
- predictions = model.predict(image)
25
- predicted_class = class_names[np.argmax(predictions)]
26
-
27
- return predicted_class
28
 
29
  # Create Gradio interface
30
  with gr.Blocks() as demo:
@@ -32,21 +38,25 @@ with gr.Blocks() as demo:
32
 
33
  with gr.Row():
34
  with gr.Column():
35
- canvas = gr.Sketchpad(label="Draw Here") # Removed tool argument
36
- brush_color = gr.ColorPicker(value="black", label="Brush Color")
37
  clear_btn = gr.Button("Clear")
38
 
39
  with gr.Column():
40
- predict_btn = gr.Button("Predict")
41
- output_label = gr.Textbox(label="Predicted Texture")
42
-
 
43
  # Define the actions for buttons
44
  def clear_canvas():
45
  return np.zeros((400, 400, 3), dtype=np.uint8)
46
 
47
  clear_btn.click(fn=clear_canvas, inputs=None, outputs=canvas)
48
 
49
- predict_btn.click(fn=predict, inputs=canvas, outputs=output_label)
 
 
 
 
50
 
51
  # Launch the Gradio app
52
  demo.launch()
 
8
  model_path = 'sketch2draw_model.h5' # Update with your model path
9
  model = keras.models.load_model(model_path)
10
 
11
+ # Define color mapping for different terrains
12
+ color_mapping = {
13
+ "Water": (0, 0, 255), # Blue
14
+ "Grass": (0, 255, 0), # Green
15
+ "Dirt": (139, 69, 19), # Brown
16
+ "Clouds": (255, 255, 255), # White
17
+ "Wood": (160, 82, 45), # Sienna
18
+ "Sky": (135, 206, 235) # Sky Blue
19
+ }
20
 
21
  def predict(image):
22
  # Decode the image from base64
 
27
  image = cv2.resize(image, (128, 128)) # Resize as per your model's input
28
  image = np.expand_dims(image, axis=0) / 255.0 # Normalize if needed
29
 
30
+ # Generate image based on the model
31
+ generated_image = model.predict(image)[0] # Generate image
32
+ generated_image = (generated_image * 255).astype(np.uint8) # Rescale to 0-255
33
+ return generated_image
 
34
 
35
  # Create Gradio interface
36
  with gr.Blocks() as demo:
 
38
 
39
  with gr.Row():
40
  with gr.Column():
41
+ canvas = gr.Sketchpad(label="Draw Here")
 
42
  clear_btn = gr.Button("Clear")
43
 
44
  with gr.Column():
45
+ # Add color buttons for terrains
46
+ color_btns = {name: gr.Button(name) for name in color_mapping.keys()}
47
+ output_image = gr.Image(label="Generated Image", type="numpy")
48
+
49
  # Define the actions for buttons
50
  def clear_canvas():
51
  return np.zeros((400, 400, 3), dtype=np.uint8)
52
 
53
  clear_btn.click(fn=clear_canvas, inputs=None, outputs=canvas)
54
 
55
+ for color_name, color in color_mapping.items():
56
+ color_btns[color_name].click(fn=lambda color=color: canvas.update(value=color), inputs=None, outputs=None)
57
+
58
+ # Click to generate an image
59
+ canvas.submit(fn=predict, inputs=canvas, outputs=output_image)
60
 
61
  # Launch the Gradio app
62
  demo.launch()