Spaces:
Build error
Build error
File size: 8,696 Bytes
1865436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import argparse
import os
# Set the TensorFlow C++ log-level variable before `import tensorflow` below,
# so the value is already in the environment when the library loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import time
import tensorflow as tf
# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front.  The setting has to be the same on every visible GPU, so apply it to
# each device rather than only the first one; with no GPU this is a no-op.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
# deep sort imports
from deep_sort import nn_matching
from application_util import preprocessing
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from _tools_ import generate_detections as gdet
# deepsort
from mrcnn.mrcnn_color import MRCNN
# ocr
# from sts.demo.sts import handle_sts
def _parse_args():
parser = argparse.ArgumentParser(description="")
parser.add_argument("--model",
help="detection model",
type=str,
default="./checkpoint/maskrcnn_signboard_ss.ckpt")
parser.add_argument("--input_size",
help="input size",
type=int,
default=1024)
parser.add_argument("--score",
help="score threshold",
type=float,
default=0.50)
parser.add_argument("--size",
help="resize images to",
type=int,
default=1024)
parser.add_argument("--video",
help="path to input video or set to 0 for webcam",
type=str,
default="./samples/demo.mp4")
parser.add_argument("--output",
help="path to output video",
type=str,
default="./outputs/demo.mp4")
parser.add_argument("--output_format",
help="codec used in VideoWriter when saving video to file",
type=str,
default='mp4v')
parser.add_argument("--dont_show",
help="dont show video output",
type=bool,
default=True)
parser.add_argument("--info",
help="show detailed info of tracked objects",
type=bool,
default=True)
parser.add_argument("--count",
help="count objects being tracked on screen",
type=bool,
default=True)
args = parser.parse_args()
return args
def _open_capture(source):
    """Open *source* as a camera index if it parses as an int, else as a path."""
    try:
        return cv2.VideoCapture(int(source))
    except (TypeError, ValueError):
        return cv2.VideoCapture(source)


def _crop_box(frame, bbox):
    """Return the region of *frame* inside *bbox* (x1, y1, x2, y2), clamped to the image."""
    height, width = frame.shape[:2]
    x1 = min(max(int(bbox[0]), 0), width)
    y1 = min(max(int(bbox[1]), 0), height)
    x2 = min(max(int(bbox[2]), 0), width)
    y2 = min(max(int(bbox[3]), 0), height)
    return frame[y1:y2, x1:x2]


def handle(args):
    """Run detection and tracking over a video and write the results to disk.

    For every frame read from ``args.video`` the function runs the Mask R-CNN
    detector, feeds the detections to the DeepSORT tracker and, for each
    confirmed track updated within the last frame:

    * saves a crop of the track to ``./ids/<track_id>/``,
    * draws its box and label on the frame,
    * appends a line to ``./outputs/<frame_num>.txt``.

    When ``args.output`` is set, the annotated frame is also written to
    ``./outputs/<frame_num>.jpg`` and to the video file ``args.output``.

    Args:
        args: namespace with ``model``, ``input_size``, ``score``, ``video``,
            ``output``, ``output_format``, ``dont_show``, ``info`` and
            ``count`` attributes (see ``_parse_args``).
    """
    # Definition of the parameters
    max_cosine_distance = 0.4
    nn_budget = None
    nms_max_overlap = 1.0

    # initialize deep sort
    model_filename = 'checkpoint/signboard_2793.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    # calculate cosine distance metric
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    # initialize tracker
    tracker = Tracker(metric)
    # initialize maskrcnn
    mrcnn = MRCNN(args.model, args.input_size, args.score)

    # The colour palette does not depend on the frame, so build it once.
    cmap = plt.get_cmap('tab20b')
    colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

    # Per-frame text/image results always go to ./outputs; make sure it exists.
    os.makedirs("./outputs", exist_ok=True)

    # begin video capture
    vid = _open_capture(args.video)
    out = None
    try:
        # get video ready to save locally if flag is set
        if args.output:
            output_dir = os.path.dirname(args.output)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            # by default VideoCapture returns float instead of int
            width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(vid.get(cv2.CAP_PROP_FPS))
            codec = cv2.VideoWriter_fourcc(*args.output_format)
            out = cv2.VideoWriter(args.output, codec, fps, (width, height))

        frame_num = 0
        # while video is running
        while True:
            return_value, frame = vid.read()
            if not return_value:
                print('Video has ended or failed, try a different video format!')
                break
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            frame_num += 1
            print('Frame #: ', frame_num)
            start_time = time.time()

            # NOTE(review): min_score is fixed at 0.5 here although args.score
            # was handed to the MRCNN constructor -- confirm this is intended.
            boxes, scores, class_names, _, _ = mrcnn.detect_result_(image, min_score=0.5)
            count = len(class_names)
            if args.count:
                cv2.putText(frame, "Objects being tracked: {0}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
                print("Objects being tracked: {0}".format(count))

            # encode the detections and feed them to the tracker
            features = encoder(frame, boxes)
            detections = [
                Detection(box, score, class_name, feature)
                for box, score, class_name, feature
                in zip(boxes, scores, class_names, features)
            ]

            # run non-maxima suppression
            boxs = np.array([d.tlwh for d in detections])
            scores = np.array([d.confidence for d in detections])
            classes = np.array([d.class_name for d in detections])
            indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
            detections = [detections[i] for i in indices]

            # Call the tracker
            tracker.predict()
            tracker.update(detections)

            # Only confirmed tracks that were updated recently are reported.
            active_tracks = [
                track for track in tracker.tracks
                if track.is_confirmed() and track.time_since_update <= 1
            ]

            # First pass: save a crop per track *before* any box is drawn on
            # the frame, so the crops do not contain overlay graphics of the
            # tracks.
            for track in active_tracks:
                crop_ids = _crop_box(frame, track.to_tlbr())
                if crop_ids.size == 0:
                    # Box lies completely outside the frame; cv2.imwrite
                    # would raise on an empty image.
                    continue
                ids_path = "./ids/" + str(track.track_id)
                os.makedirs(ids_path, exist_ok=True)
                # Pick the first free "<id>_<frame>_<n>.png" name.
                num_ids = 0
                while True:
                    final_ids_path = os.path.join(
                        ids_path,
                        str(track.track_id) + "_" + str(frame_num) + "_" + str(num_ids) + ".png")
                    if not os.path.isfile(final_ids_path):
                        break
                    num_ids += 1
                cv2.imwrite(final_ids_path, crop_ids)

            # Second pass: draw the tracks and log them to the per-frame file.
            with open("./outputs/{}.txt".format(frame_num), "a+", encoding="utf-8") as ff:
                for track in active_tracks:
                    bbox = track.to_tlbr()
                    class_name = track.get_class()
                    # OCR is currently disabled (handle_sts); the recognised
                    # text list therefore stays empty.
                    dict_rec_sign_out = []
                    # draw bbox on screen
                    color = colors[int(track.track_id) % len(colors)]
                    color = [i * 255 for i in color]
                    cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
                    cv2.rectangle(frame, (int(bbox[0]), int(bbox[1]-30)), (int(bbox[0])+(len(class_name)+len(str(track.track_id)))*17, int(bbox[1])), color, -1)
                    cv2.putText(frame, class_name + "-" + str(track.track_id), (int(bbox[0]), int(bbox[1]-10)), 0, 0.75, (255, 255, 255), 2)
                    dict_rec_sign_out_join = "_".join(dict_rec_sign_out)
                    cv2.putText(frame, dict_rec_sign_out_join, (int(bbox[0]), int(bbox[1]+20)), 0, 0.75, (255, 255, 255), 2)
                    # if enable info flag then print details about each track
                    if args.info:
                        print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))
                    ff.write("{}, {}, {}, {}, {}, {}\n".format(str(track.track_id), int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]), dict_rec_sign_out_join))

            # calculate frames per second of running detections
            elapsed = time.time() - start_time
            processing_fps = 1.0 / elapsed if elapsed > 0 else float("inf")
            print("FPS: %.2f" % processing_fps)

            result = frame
            if not args.dont_show:
                cv2.imshow("Output Video", result)
            # if output flag is set, save video file
            if args.output:
                cv2.imwrite("./outputs/{0}.jpg".format(frame_num), result)
                out.write(result)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture and finalise the output video even when the
        # loop exits through an exception.
        vid.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
def main():
    """Script entry point: parse the command line and run the tracking pipeline."""
    handle(_parse_args())
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|