import argparse
import os
import sys

import cv2
import numpy as np
import tritonclient.grpc as grpcclient
# Class names for the dataset
class_names = [
    'helm',
    'no_helm',
    'person'
]
def get_triton_client(url: str = 'localhost:8001'):
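    """Create a gRPC client for the Triton Inference Server at the given URL."""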
    keepalive_options = grpcclient.KeepAliveOptions(
        keepalive_time_ms=2**31 - 1,
        keepalive_timeout_ms=20000,
        keepalive_permit_without_calls=False,
        http2_max_pings_without_data=2
    )
    return grpcclient.InferenceServerClient(url=url, verbose=False, keepalive_options=keepalive_options)
def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
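    """Draw a labelled bounding box (class name and confidence score) on the image in place."""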
    label = f'{class_names[class_id]}: {confidence:.2f}'
    color = (255, 0, 0)
    cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
    cv2.putText(img, label, (x - 10, y - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
def process_frame(frame, expected_image_shape, model_name, triton_client):
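    """Run inference on a single frame and return it with detections drawn on top."""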
    original_image, input_image, scale = preprocess_frame(frame, expected_image_shape)
    num_detections, detection_boxes, detection_scores, detection_classes = run_inference(
        model_name, input_image, triton_client)
    for index in range(int(num_detections[0])):
        box = detection_boxes[index]
        # Boxes are treated as (x, y, w, h) in the letterboxed input space; rescale to the original frame.
        draw_bounding_box(original_image,
                          int(detection_classes[index]),
                          detection_scores[index],
                          round(box[0] * scale),
                          round(box[1] * scale),
                          round((box[0] + box[2]) * scale),
                          round((box[1] + box[3]) * scale))
    return original_image
def preprocess_frame(frame, expected_image_shape):
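    """Letterbox the frame into a square, resize it to the model input size,
    and return the original frame, the NCHW float32 tensor, and the scale
    factor that maps detections back to original-frame coordinates."""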
    # Model metadata reports the input spatial dims as (height, width) for an NCHW input.
    expected_height, expected_width = expected_image_shape
    height, width, _ = frame.shape
    length = max(height, width)
    # Pad to a square canvas so the aspect ratio is preserved when resizing.
    image = np.zeros((length, length, 3), np.uint8)
    image[0:height, 0:width] = frame
    scale = length / min(expected_image_shape)
    input_image = cv2.resize(image, (expected_width, expected_height))
    input_image = (input_image / 255.0).astype(np.float32)
    input_image = input_image.transpose(2, 0, 1)  # Channel first
    input_image = np.expand_dims(input_image, axis=0)
    return frame, input_image, scale
def run_inference(model_name: str, input_image: np.ndarray, triton_client: grpcclient.InferenceServerClient):
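    """Send one preprocessed image to Triton and return the detection output tensors as NumPy arrays."""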
    inputs = [grpcclient.InferInput('images', input_image.shape, "FP32")]
    inputs[0].set_data_from_numpy(input_image)
    outputs = [
        grpcclient.InferRequestedOutput('num_detections'),
        grpcclient.InferRequestedOutput('detection_boxes'),
        grpcclient.InferRequestedOutput('detection_scores'),
        grpcclient.InferRequestedOutput('detection_classes')
    ]
    results = triton_client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
    num_detections = results.as_numpy('num_detections')
    detection_boxes = results.as_numpy('detection_boxes')
    detection_scores = results.as_numpy('detection_scores')
    detection_classes = results.as_numpy('detection_classes')
    return num_detections, detection_boxes, detection_scores, detection_classes
def main(video_path, model_name, url):
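    """Run detection on every frame of the video, display the results, and save them to an AVI file."""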
    triton_client = get_triton_client(url)
    expected_image_shape = triton_client.get_model_metadata(model_name).inputs[0].shape[-2:]
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        sys.exit()
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    output_path = os.path.splitext(video_path)[0] + "_output.avi"
    # Use the source frame rate and size for the output; fall back to 20 FPS if the rate is unavailable.
    fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Process each frame
        output_frame = process_frame(frame, expected_image_shape, model_name, triton_client)
        # Write processed frame to the output video
        out.write(output_frame)
        # Display the frame with bounding boxes
        cv2.imshow('Video', output_frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    out.release()
    cv2.destroyAllWindows()
    print(f"Output saved as {output_path}")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_path', type=str, default='./assets/helmet.mp4')
    parser.add_argument('--model_name', type=str, default='yolov8_ensemble')
    parser.add_argument('--url', type=str, default='localhost:8001')
    args = parser.parse_args()
    main(args.video_path, args.model_name, args.url)