Spaces:
Configuration error
Configuration error
# import numpy as np | |
# import cv2 | |
# import tritonclient.grpc as grpcclient | |
# import sys | |
# import argparse | |
# import os | |
# # Class names for the dataset | |
# class_names = [ | |
# 'helm', | |
# 'no_helm', | |
# "person" | |
# ] | |
# def get_triton_client(url: str = '104.192.4.139:8001'): | |
# # try: | |
# keepalive_options = grpcclient.KeepAliveOptions( | |
# keepalive_time_ms=2**31 - 1, | |
# keepalive_timeout_ms=20000, | |
# keepalive_permit_without_calls=False, | |
# http2_max_pings_without_data=2 | |
# ) | |
# triton_client = grpcclient.InferenceServerClient( | |
# url=url, | |
# verbose=False, | |
# keepalive_options=keepalive_options) | |
# # except Exception as e: | |
# # print("Channel creation failed: " + str(e)) | |
# # sys.exit() | |
# return triton_client | |
# def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h): | |
# label = f'{class_names[class_id]}: {confidence:.2f}' | |
# color = (255, 0, 0) | |
# cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2) | |
# cv2.putText(img, label, (x - 10, y - 10), | |
# cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) | |
# def process_frame(frame, expected_image_shape, model_name, triton_client): | |
# original_image, input_image, scale = preprocess_frame(frame, expected_image_shape) | |
# num_detections, detection_boxes, detection_scores, detection_classes = run_inference( | |
# model_name, input_image, triton_client) | |
# for index in range(num_detections[0]): | |
# box = detection_boxes[index] | |
# draw_bounding_box(original_image, | |
# detection_classes[index], | |
# detection_scores[index], | |
# round(box[0] * scale), | |
# round(box[1] * scale), | |
# round((box[0] + box[2]) * scale), | |
# round((box[1] + box[3]) * scale)) | |
# return original_image | |
# def preprocess_frame(frame, expected_image_shape): | |
# expected_width = expected_image_shape[0] | |
# expected_height = expected_image_shape[1] | |
# expected_length = min(expected_height, expected_width) | |
# [height, width, _] = frame.shape | |
# length = max(height, width) | |
# image = np.zeros((length, length, 3), np.uint8) | |
# image[0:height, 0:width] = frame | |
# scale = length / expected_length | |
# input_image = cv2.resize(image, (expected_width, expected_height)) | |
# input_image = (input_image / 255.0).astype(np.float32) | |
# input_image = input_image.transpose(2, 0, 1) # Channel first | |
# input_image = np.expand_dims(input_image, axis=0) | |
# return frame, input_image, scale | |
# def run_inference(model_name: str, input_image: np.ndarray, triton_client: grpcclient.InferenceServerClient): | |
# inputs = [grpcclient.InferInput('images', input_image.shape, "FP32")] | |
# inputs[0].set_data_from_numpy(input_image) | |
# outputs = [ | |
# grpcclient.InferRequestedOutput('num_detections'), | |
# grpcclient.InferRequestedOutput('detection_boxes'), | |
# grpcclient.InferRequestedOutput('detection_scores'), | |
# grpcclient.InferRequestedOutput('detection_classes') | |
# ] | |
# results = triton_client.infer(model_name=model_name, inputs=inputs, outputs=outputs) | |
# num_detections = results.as_numpy('num_detections') | |
# detection_boxes = results.as_numpy('detection_boxes') | |
# detection_scores = results.as_numpy('detection_scores') | |
# detection_classes = results.as_numpy('detection_classes') | |
# return num_detections, detection_boxes, detection_scores, detection_classes | |
# def main(video_path, model_name, url): | |
# triton_client = get_triton_client(url) | |
# expected_image_shape = triton_client.get_model_metadata(model_name).inputs[0].shape[-2:] | |
# cap = cv2.VideoCapture(video_path) | |
# if not cap.isOpened(): | |
# print("Error: Could not open video.") | |
# sys.exit() | |
# fourcc = cv2.VideoWriter_fourcc(*'XVID') | |
# output_path = os.path.splitext(video_path)[0] + "_output.avi" | |
# out = cv2.VideoWriter(output_path, fourcc, 20.0, (int(cap.get(3)), int(cap.get(4)))) | |
# while True: | |
# ret, frame = cap.read() | |
# if not ret: | |
# break | |
# # Process each frame | |
# output_frame = process_frame(frame, expected_image_shape, model_name, triton_client) | |
# # Write processed frame to the output video | |
# out.write(output_frame) | |
# # Display the frame with bounding boxes | |
# cv2.imshow('Video', output_frame) | |
# if cv2.waitKey(1) & 0xFF == ord('q'): | |
# break | |
# cap.release() | |
# out.release() | |
# cv2.destroyAllWindows() | |
# print(f"Output saved as {output_path}") | |
# if __name__ == '__main__': | |
# parser = argparse.ArgumentParser() | |
# parser.add_argument('--video_path', type=str, default='./assets/helmet.mp4') | |
# parser.add_argument('--model_name', type=str, default='yolov8_ensemble') | |
# parser.add_argument('--url', type=str, default='104.192.4.139:8001') | |
# args = parser.parse_args() | |
# main(args.video_path, args.model_name, args.url) | |
import numpy as np | |
import cv2 | |
import tritonclient.grpc as grpcclient | |
import sys | |
import argparse | |
import os | |
# Human-readable labels for the detector's classes; position i is the label
# drawn for class id i.
class_names = ['helm', 'no_helm', 'person']
def get_triton_client(url: str = 'localhost:8001'):
    """Create a gRPC client for the Triton Inference Server at ``url``.

    The keepalive ping interval is set to 2**31 - 1 ms (about 24.8 days) and
    pings are not permitted without active calls.

    Args:
        url: ``host:port`` of the Triton gRPC endpoint.

    Returns:
        A ``tritonclient.grpc.InferenceServerClient`` with verbose logging off.
    """
    options = grpcclient.KeepAliveOptions(
        keepalive_time_ms=2**31 - 1,  # largest signed 32-bit millisecond value
        keepalive_timeout_ms=20000,  # 20 s to wait for a ping acknowledgement
        keepalive_permit_without_calls=False,
        http2_max_pings_without_data=2,
    )
    client = grpcclient.InferenceServerClient(
        url=url,
        verbose=False,
        keepalive_options=options,
    )
    return client
def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h, color=(255, 0, 0)):
    """Draw one labelled detection box onto ``img`` in place.

    Args:
        img: Image array modified in place by ``cv2.rectangle``/``cv2.putText``.
        class_id: Index into ``class_names``. Converted with ``int()`` so numpy
            scalars, including float-typed class ids, are accepted.
        confidence: Detection score, rendered with two decimals.
        x, y: Top-left corner of the box in pixels.
        x_plus_w, y_plus_h: Bottom-right corner of the box in pixels.
        color: OpenCV colour tuple for the box and label. For frames read by
            ``cv2.VideoCapture`` the channel order is BGR, so the default
            ``(255, 0, 0)`` (the previously hard-coded value) is blue.
    """
    label = f'{class_names[int(class_id)]}: {confidence:.2f}'
    cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
    # The label is placed 10 px above-left of the box. Clamp it so labels of
    # boxes touching the top/left edge are not drawn at negative coordinates,
    # where they would be invisible.
    label_origin = (max(x - 10, 0), max(y - 10, 10))
    cv2.putText(img, label, label_origin,
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
def process_frame(frame, expected_image_shape, model_name, triton_client):
    """Detect objects in one frame and draw every detection onto it.

    Args:
        frame: Image array to run detection on.
        expected_image_shape: Model input size, forwarded unchanged to
            ``preprocess_frame``.
        model_name: Name of the Triton model to query.
        triton_client: Client handed to ``run_inference``.

    Returns:
        The image returned by ``preprocess_frame`` with the detections drawn on.
    """
    annotated, model_input, scale = preprocess_frame(frame, expected_image_shape)
    counts, boxes, scores, classes = run_inference(model_name, model_input, triton_client)

    # Each box is treated as (left, top, width, height) in model-input pixels;
    # multiplying by ``scale`` maps it back onto the original frame.
    for i in range(counts[0]):
        box = boxes[i]
        left, top, width, height = box[0], box[1], box[2], box[3]
        draw_bounding_box(
            annotated,
            classes[i],
            scores[i],
            round(left * scale),
            round(top * scale),
            round((left + width) * scale),
            round((top + height) * scale),
        )
    return annotated
def preprocess_frame(frame, expected_image_shape):
    """Pad a frame to a square and turn it into a model input tensor.

    The frame is copied into the top-left corner of a black square canvas whose
    side is the frame's longer edge. The canvas is then resized to the model's
    input size, scaled to ``[0, 1]`` float32 and rearranged to ``(1, C, H, W)``.

    Args:
        frame: 3-channel image of shape ``(height, width, 3)``.
        expected_image_shape: The trailing two dimensions of the model's input
            shape. The tensor built here is channel-first, so these are
            ``(height, width)`` and are read in that order. A longer sequence
            is also accepted; only its last two entries are used.

    Returns:
        ``(frame, input_image, scale)``: the unmodified input frame, the
        ``(1, 3, H, W)`` float32 tensor, and the factor that maps model-input
        pixel coordinates back onto ``frame``.
    """
    # Channel-first (NCHW) input: the trailing two dims are (height, width).
    # Unpacking them as (width, height) would transpose non-square inputs.
    expected_height, expected_width = (int(d) for d in expected_image_shape[-2:])

    height, width, _ = frame.shape
    length = max(height, width)
    canvas = np.zeros((length, length, 3), np.uint8)
    canvas[0:height, 0:width] = frame

    # A single scale factor is exact only for square model inputs. For a
    # non-square input the resize below stretches the canvas, and x and y
    # would need separate factors.
    scale = length / min(expected_height, expected_width)

    # cv2.resize takes dsize as (width, height).
    input_image = cv2.resize(canvas, (expected_width, expected_height))
    # NOTE(review): OpenCV frames are BGR; confirm whether the deployed model
    # expects RGB, in which case the channels should be swapped here.
    input_image = (input_image / 255.0).astype(np.float32)
    input_image = input_image.transpose(2, 0, 1)  # HWC -> CHW
    input_image = np.expand_dims(input_image, axis=0)  # add batch dimension
    return frame, input_image, scale
def run_inference(model_name: str, input_image: np.ndarray, triton_client: grpcclient.InferenceServerClient):
    """Send one preprocessed tensor to Triton and fetch the detection outputs.

    Args:
        model_name: Name of the Triton model to run.
        input_image: Tensor sent as the ``'images'`` input, declared as FP32.
        triton_client: Connected Triton gRPC client.

    Returns:
        Tuple ``(num_detections, detection_boxes, detection_scores,
        detection_classes)``, each taken from the response via ``as_numpy``.
    """
    output_names = (
        'num_detections',
        'detection_boxes',
        'detection_scores',
        'detection_classes',
    )
    images = grpcclient.InferInput('images', input_image.shape, "FP32")
    images.set_data_from_numpy(input_image)
    requested = [grpcclient.InferRequestedOutput(name) for name in output_names]
    response = triton_client.infer(model_name=model_name, inputs=[images], outputs=requested)
    return tuple(response.as_numpy(name) for name in output_names)
def main(video_path, model_name, url, *, display=True):
    """Annotate every frame of a video with detections from a Triton model.

    Frames are read from ``video_path`` and passed through ``process_frame``.
    The annotated frames are written as XVID to
    ``<video_path without extension>_output.avi``. When ``display`` is true,
    each frame is also shown in an OpenCV window, and pressing ``q`` stops early.

    Args:
        video_path: Path of the input video.
        model_name: Name of the Triton model to query.
        url: Triton gRPC endpoint, e.g. ``'localhost:8001'``.
        display: Show frames with ``cv2.imshow``. Pass ``False`` when no
            display is available.

    Exits the process with status 1 if the video cannot be opened.
    """
    triton_client = get_triton_client(url)
    # Trailing two dims of the model's first input; consumed by preprocess_frame.
    expected_image_shape = triton_client.get_model_metadata(model_name).inputs[0].shape[-2:]

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.", file=sys.stderr)
        sys.exit(1)  # non-zero status so callers/scripts see the failure

    # Keep the source frame rate so the output plays at the original speed.
    # Fall back to the previous fixed 20 fps when the container reports none
    # (0 or NaN; NaN > 0 is False).
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = 20.0
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    output_path = os.path.splitext(video_path)[0] + "_output.avi"
    out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            output_frame = process_frame(frame, expected_image_shape, model_name, triton_client)
            out.write(output_frame)
            if display:
                cv2.imshow('Video', output_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release the capture and writer even if inference fails mid-video.
        cap.release()
        out.release()
        if display:
            cv2.destroyAllWindows()
    print(f"Output saved as {output_path}")
if __name__ == '__main__':
    # Command-line entry point. Every option has a default, so the script can
    # be run without arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument('--video_path', type=str, default='./assets/helmet.mp4')
    cli.add_argument('--model_name', type=str, default='yolov8_ensemble')
    cli.add_argument('--url', type=str, default='localhost:8001')
    options = cli.parse_args()
    main(video_path=options.video_path, model_name=options.model_name, url=options.url)