sshi commited on
Commit
224b9f6
1 Parent(s): e209888

App bug fix.

Browse files
Files changed (2) hide show
  1. app.py +17 -14
  2. requirements.txt +2 -1
app.py CHANGED
@@ -11,7 +11,8 @@ import pytorch_lightning as pl
11
 
12
  from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
13
 
14
- from PIL import Image
 
15
  import matplotlib.pyplot as plt
16
 
17
  id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}
@@ -100,21 +101,23 @@ def rescale_bboxes(out_bbox, size):
100
  return b
101
 
102
  def plot_results(pil_img, prob, boxes):
103
- fig = plt.figure(figsize=(16,10), dpi=120)
104
- plt.imshow(pil_img)
105
- ax = plt.gca()
106
  colors = COLORS * 100
 
107
  for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
108
  cl = p.argmax()
109
  c = colors[cl]
110
- ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
111
- fill=False, color=c, linewidth=2))
112
- text = f'{id2label[cl.item()]}: {p[cl]:0.2f}'
113
- ax.text(xmin, ymin, text, fontsize=10,
114
- bbox=dict(facecolor=c, alpha=0.5))
115
- plt.axis('off')
116
- # return Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
117
- return fig
 
 
118
 
119
 
120
  def generate_preds(processor, model, image):
@@ -145,8 +148,8 @@ def detect(img):
145
  interface = gr.Interface(
146
  fn=detect,
147
  inputs=[gr.Image(type="pil")],
148
- # outputs=gr.Image(type="pil"),
149
- outputs = ['plot'],
150
  examples=[["./imgs/example1.jpg"], ["./imgs/example2.jpg"]],
151
  title="YOLOS for traffic object detection",
152
  description="A downstream application for <a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a> which can performe traffic object detection. ")
 
11
 
12
  from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
13
 
14
+ from PIL import Image, ImageDraw
15
+ import cv2
16
  import matplotlib.pyplot as plt
17
 
18
  id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}
 
101
  return b
102
 
103
  def plot_results(pil_img, prob, boxes):
104
+
105
+ draw = ImageDraw.Draw(pil_img)
 
106
  colors = COLORS * 100
107
+
108
  for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
109
  cl = p.argmax()
110
  c = colors[cl]
111
+
112
+ draw.rectangle([xmin, ymin, xmax - xmin, ymax - ymin], outline=c, width=2)
113
+ draw.text(
114
+ [xmin + 5, ymin + 5],
115
+ f'{id2label[cl.item()]}: {p[cl]:0.2f}',
116
+ fill=c)
117
+ # ax.text(xmin, ymin, text, fontsize=10,
118
+ # bbox=dict(facecolor=c, alpha=0.5))
119
+ return Image.fromarray(pil_img[:,:,::-1])
120
+ # return fig
121
 
122
 
123
  def generate_preds(processor, model, image):
 
148
  interface = gr.Interface(
149
  fn=detect,
150
  inputs=[gr.Image(type="pil")],
151
+ outputs=gr.Image(type="pil"),
152
+ # outputs = ['plot'],
153
  examples=[["./imgs/example1.jpg"], ["./imgs/example2.jpg"]],
154
  title="YOLOS for traffic object detection",
155
  description="A downstream application for <a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a> which can performe traffic object detection. ")
requirements.txt CHANGED
@@ -5,4 +5,5 @@ git+https://github.com/huggingface/transformers.git@main
5
  matplotlib>=3.2.2
6
  Pillow>=7.1.2
7
  torch==1.10.0
8
- pytorch-lightning==1.9.3
 
 
5
  matplotlib>=3.2.2
6
  Pillow>=7.1.2
7
  torch==1.10.0
8
+ pytorch-lightning==1.9.3
9
+ opencv-python>=4.1.1