skyBluezz commited on
Commit
3444920
·
verified ·
1 Parent(s): 2b2b6f2

Update app.py

Browse files

Model now returns [out_img, boxes, labels]

Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -162,8 +162,9 @@ def run_detr(image:Image, detr_processor, detr_model, confidence_threshold: floa
162
  threshold=confidence_threshold)
163
  outputs = postprocessed_outputs[0]
164
  # dict{scores, logits, labels, boxes}
165
- outputs = outputs['boxes']
166
- return outputs
 
167
 
168
 
169
  def resize_dimensions(dimensions, target_size):
@@ -284,7 +285,6 @@ class ControlNetDepthDesignModelMulti:
284
 
285
  return design_image
286
 
287
-
288
  def create_demo(model):
289
  gr.Markdown("### demo")
290
  with gr.Row():
@@ -348,9 +348,9 @@ def create_demo(model):
348
  # -- run detr --
349
  # -----------------
350
  # clear_gpu()
351
- bboxes = run_detr(out_img, detr_processor, detr_model, detr_confidence_threshold)
352
 
353
- return out_img, bboxes.tolist()
354
 
355
  submit.click(on_submit, inputs=[input_image, input_text, num_steps, guidance_scale, seed, strength, detr_confidence_threshold, a_prompt, n_prompt, img_size], outputs=[design_image, bboxes])
356
 
 
162
  threshold=confidence_threshold)
163
  outputs = postprocessed_outputs[0]
164
  # dict{scores, logits, labels, boxes}
165
+ labels = [detr_model.config.id2label[label.item()] for label in outputs["labels"]]
166
+ boxes = outputs['boxes']
167
+ return boxes, labels
168
 
169
 
170
  def resize_dimensions(dimensions, target_size):
 
285
 
286
  return design_image
287
 
 
288
  def create_demo(model):
289
  gr.Markdown("### demo")
290
  with gr.Row():
 
348
  # -- run detr --
349
  # -----------------
350
  # clear_gpu()
351
+ bboxes, labels = run_detr(out_img, detr_processor, detr_model, detr_confidence_threshold)
352
 
353
+ return out_img, bboxes.tolist(), labels
354
 
355
  submit.click(on_submit, inputs=[input_image, input_text, num_steps, guidance_scale, seed, strength, detr_confidence_threshold, a_prompt, n_prompt, img_size], outputs=[design_image, bboxes])
356