PhyscalX commited on
Commit
805ce53
Β·
1 Parent(s): 6da7f6f

Use image array for all prompts

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -173,8 +173,8 @@ def build_gradio_app(queues, command):
173
 
174
  def on_submit_btn(click_img, mask_img, prompt, multipoint):
175
  if prompt == 0:
176
- img = cv2.imread(click_img["image"])
177
- points = np.array(click_img["points"]).reshape((-1, 2, 3))
178
  if multipoint == 1:
179
  points = points.reshape((-1, 3))
180
  lt = points[np.where(points[:, 2] == 2)[0]][None, :, :]
@@ -183,7 +183,7 @@ def build_gradio_app(queues, command):
183
  points = [lt, rb, poly] if len(lt) > 0 else [poly, np.array([[[0, 0, 4]]])]
184
  points = np.concatenate(points, axis=1)
185
  elif prompt == 1:
186
- img, points = mask_img["background"][:, :, (2, 1, 0)], []
187
  for layer in mask_img["layers"]:
188
  ys, xs = np.nonzero(layer[:, :, 0])
189
  keep = np.linspace(0, ys.shape[0], 11, dtype="int64")[1:-1]
@@ -192,6 +192,7 @@ def build_gradio_app(queues, command):
192
  points = np.pad(points, [(0, 0), (0, 0), (0, 1)], constant_values=1)
193
  pad_points = np.array([[[0, 0, 4]]], "float32").repeat(points.shape[0], 0)
194
  points = np.concatenate([points, pad_points], axis=1)
 
195
  img = np.zeros((480, 640, 3), dtype="uint8") if img is None else img
196
  points = (np.array([[[0, 0, 4]]]) if len(points) == 0 else points).astype("float32")
197
  inputs = {"img": img, "points": points}
@@ -208,8 +209,8 @@ def build_gradio_app(queues, command):
208
  app = gr.Blocks(title=title, theme=theme, css=css).__enter__()
209
  gr.Markdown(header)
210
  container, column = gr.Row().__enter__(), gr.Column().__enter__()
211
- click_img = gr_ext.ImagePrompter(type="filepath", show_label=False)
212
- draw_img = gr.ImageEditor(type="numpy", show_label=False, visible=False)
213
  interactions = "LeftClick (FG) | MiddleClick (BG) | PressMove (Box) | Draw (Sketch)"
214
  gr.Markdown("<h3 style='text-align: center'>[πŸ–±οΈ | πŸ–οΈ]: 🌟🌟 {} 🌟🌟 </h3>".format(interactions))
215
  row = gr.Row().__enter__()
@@ -252,4 +253,4 @@ if __name__ == "__main__":
252
  actor.start()
253
  app = build_gradio_app(queues[:-1], commands[-1])
254
  app.queue()
255
- app.launch()
 
173
 
174
  def on_submit_btn(click_img, mask_img, prompt, multipoint):
175
  if prompt == 0:
176
+ img, points = click_img["image"], click_img["points"]
177
+ points = np.array(points).reshape((-1, 2, 3))
178
  if multipoint == 1:
179
  points = points.reshape((-1, 3))
180
  lt = points[np.where(points[:, 2] == 2)[0]][None, :, :]
 
183
  points = [lt, rb, poly] if len(lt) > 0 else [poly, np.array([[[0, 0, 4]]])]
184
  points = np.concatenate(points, axis=1)
185
  elif prompt == 1:
186
+ img, points = mask_img["background"], []
187
  for layer in mask_img["layers"]:
188
  ys, xs = np.nonzero(layer[:, :, 0])
189
  keep = np.linspace(0, ys.shape[0], 11, dtype="int64")[1:-1]
 
192
  points = np.pad(points, [(0, 0), (0, 0), (0, 1)], constant_values=1)
193
  pad_points = np.array([[[0, 0, 4]]], "float32").repeat(points.shape[0], 0)
194
  points = np.concatenate([points, pad_points], axis=1)
195
+ img = img[:, :, (2, 1, 0)] if img is not None else img
196
  img = np.zeros((480, 640, 3), dtype="uint8") if img is None else img
197
  points = (np.array([[[0, 0, 4]]]) if len(points) == 0 else points).astype("float32")
198
  inputs = {"img": img, "points": points}
 
209
  app = gr.Blocks(title=title, theme=theme, css=css).__enter__()
210
  gr.Markdown(header)
211
  container, column = gr.Row().__enter__(), gr.Column().__enter__()
212
+ click_img = gr_ext.ImagePrompter(show_label=False)
213
+ draw_img = gr.ImageEditor(show_label=False, visible=False)
214
  interactions = "LeftClick (FG) | MiddleClick (BG) | PressMove (Box) | Draw (Sketch)"
215
  gr.Markdown("<h3 style='text-align: center'>[πŸ–±οΈ | πŸ–οΈ]: 🌟🌟 {} 🌟🌟 </h3>".format(interactions))
216
  row = gr.Row().__enter__()
 
253
  actor.start()
254
  app = build_gradio_app(queues[:-1], commands[-1])
255
  app.queue()
256
+ app.launch(show_api=False)