sshi commited on
Commit
88c51e4
1 Parent(s): 94e2de2

App bug fix.

Browse files
Files changed (2) hide show
  1. app.py +9 -7
  2. requirements.txt +2 -1
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
13
 
14
  from PIL import Image, ImageDraw
15
  import cv2
 
16
  import matplotlib.pyplot as plt
17
 
18
  id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}
@@ -100,20 +101,21 @@ def rescale_bboxes(out_bbox, size):
100
 
101
  def plot_results(pil_img, prob, boxes):
102
 
103
- draw = ImageDraw.Draw(pil_img)
104
 
105
  for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
106
  cl = p.argmax()
107
  c = colors[cl]
 
108
 
109
- draw.rectangle([xmin, ymin, xmax - xmin, ymax - ymin], outline=c, width=2)
110
- draw.text(
111
- [xmin + 5, ymin + 5],
112
- f'{id2label[cl.item()]}: {p[cl]:0.2f}',
113
- fill=c)
114
  # ax.text(xmin, ymin, text, fontsize=10,
115
  # bbox=dict(facecolor=c, alpha=0.5))
116
- return pil_img
117
  # return fig
118
 
119
 
 
13
 
14
  from PIL import Image, ImageDraw
15
  import cv2
16
+ import numpy
17
  import matplotlib.pyplot as plt
18
 
19
  id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}
 
101
 
102
  def plot_results(pil_img, prob, boxes):
103
 
104
+ img = numpy.asarray(pil_img)
105
 
106
  for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
107
  cl = p.argmax()
108
  c = colors[cl]
109
+ c1, c2 = (xmin, ymin), (xmax, ymax)
110
 
111
+ cv2.rectangle(img, c1, c2, c, thickness=2, lineType=cv2.LINE_AA)
112
+ # cv2.text(
113
+ # [xmin + 5, ymin + 5],
114
+ # f'{id2label[cl.item()]}: {p[cl]:0.2f}',
115
+ # fill=c)
116
  # ax.text(xmin, ymin, text, fontsize=10,
117
  # bbox=dict(facecolor=c, alpha=0.5))
118
+ return return Image.fromarray(img[:,:,::-1])
119
  # return fig
120
 
121
 
requirements.txt CHANGED
@@ -6,4 +6,5 @@ matplotlib>=3.2.2
6
  Pillow>=7.1.2
7
  torch==1.10.0
8
  pytorch-lightning==1.9.3
9
- opencv-python>=4.1.1
 
 
6
  Pillow>=7.1.2
7
  torch==1.10.0
8
  pytorch-lightning==1.9.3
9
+ opencv-python>=4.1.1
10
+ numpy>=1.18.5