Manh Ho Dinh commited on
Commit
a724d88
·
verified ·
1 Parent(s): 4463d9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -12
app.py CHANGED
@@ -21,7 +21,7 @@ db = firestore.client()
21
  colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
22
  for j in range(10)]
23
 
24
- detection_threshold = 0.7
25
  model = YOLO(MODEL_PATH)
26
 
27
  def addToDatabase(ss_id, obj_ids):
@@ -48,7 +48,7 @@ def addToDatabase(ss_id, obj_ids):
48
 
49
 
50
  def traffic_counting(video):
51
-
52
  obj_ids = {"person": [],
53
  "bicycle": [],
54
  "car": [],
@@ -97,25 +97,72 @@ def traffic_counting(video):
97
  # Count each type of traffic
98
  output_data = {key: len(value) for key, value in obj_ids.items()}
99
  df = pd.DataFrame(list(output_data.items()), columns=['Type', 'Number'])
100
-
101
  yield frame, df
102
  ret, frame = cap.read()
103
-
 
104
  cap.release()
105
  cv2.destroyAllWindows()
106
  video_path = video.replace("\\", "/")
107
- addToDatabase(video_path.split("/")[-1][:-4], obj_ids)
108
 
109
 
 
 
 
110
 
111
- input_video = gr.Video(label="Input Video")
112
- output_video = gr.Image(type="numpy", label="Processing Video")
113
- output_data = gr.Dataframe(interactive=False, label="Traffic's Frequency")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- demo = gr.Interface(traffic_counting,
116
- inputs=input_video,
117
- outputs=[output_video, output_data],
118
- examples=[os.path.join('video', x) for x in os.listdir('video') if x != ".gitkeep"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  allow_flagging='never'
120
  )
121
 
 
21
  colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
22
  for j in range(10)]
23
 
24
+ detection_threshold = 0.1
25
  model = YOLO(MODEL_PATH)
26
 
27
  def addToDatabase(ss_id, obj_ids):
 
48
 
49
 
50
  def traffic_counting(video):
51
+
52
  obj_ids = {"person": [],
53
  "bicycle": [],
54
  "car": [],
 
97
  # Count each type of traffic
98
  output_data = {key: len(value) for key, value in obj_ids.items()}
99
  df = pd.DataFrame(list(output_data.items()), columns=['Type', 'Number'])
100
+
101
  yield frame, df
102
  ret, frame = cap.read()
103
+
104
+
105
  cap.release()
106
  cv2.destroyAllWindows()
107
  video_path = video.replace("\\", "/")
108
+ # addToDatabase(video_path.split("/")[-1][:-4], obj_ids)
109
 
110
 
111
+ # input_video = gr.Video(label="Input Video")
112
+ # output_video = gr.outputs.Video(label="Processing Video")
113
+ # output_data = gr.Dataframe(interactive=False, label="Traffic's Frequency")
114
 
115
+ # demo = gr.Interface(traffic_counting,
116
+ # inputs=input_video,
117
+ # outputs=[output_video, output_data],
118
+ # examples=[os.path.join('video', x) for x in os.listdir('video') if x != ".gitkeep"],
119
+ # allow_flagging='never'
120
+ # )
121
+ def traffic_detection(image):
122
+
123
+ results = model.predict(image)
124
+ detections = []
125
+ obj_ids = {"person": [],
126
+ "bicycle": [],
127
+ "car": [],
128
+ "motocycle": [],
129
+ "bus": [],
130
+ "truck": [],
131
+ "other": []}
132
 
133
+ for result in results:
134
+ for r in result.boxes.data.tolist():
135
+ x1, y1, x2, y2, score, class_id = r
136
+ x1 = int(x1)
137
+ x2 = int(x2)
138
+ y1 = int(y1)
139
+ y2 = int(y2)
140
+ class_id = int(class_id)
141
+ if score > detection_threshold:
142
+ detections.append([x1, y1, x2, y2, class_id, score])
143
+ cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (compute_color_for_labels(class_id)), 1)
144
+ label_name = ID2LABEL[class_id] if class_id in ID2LABEL.keys() else "other"
145
+ cv2.putText(image,f"{label_name}",
146
+ (int(x1) + 5, int(y1) - 5),
147
+ cv2.FONT_HERSHEY_SIMPLEX, 0.3,compute_color_for_labels(class_id), 1, cv2.LINE_AA )
148
+
149
+ # Count each type of traffic
150
+ output_data = {key: len(value) for key, value in obj_ids.items()}
151
+ df = pd.DataFrame(list(output_data.items()), columns=['Type', 'Number'])
152
+ yield image, df
153
+
154
+
155
+
156
+
157
+
158
+ # Input is a image
159
+ input_image = gr.Image(label="Input Image")
160
+ output_image = gr.Image(type="filepath", label="Processing Image")
161
+ output_data = gr.Dataframe(interactive=False, label="Traffic's Frequency")
162
+ demo = gr.Interface(traffic_detection,
163
+ inputs=input_image,
164
+ outputs=[output_image, output_data],
165
+ examples=[os.path.join('image', x) for x in os.listdir('image') if x != ".gitkeep"],
166
  allow_flagging='never'
167
  )
168