import cv2
import datetime
from matplotlib.colors import hsv_to_rgb
import torch
import numpy as np
from super_gradients.training import models
from deep_sort_torch.deep_sort.deep_sort import DeepSort
import os


def get_color(number):
    """Convert an integer into a repeatable RGB colour via HSV."""
    hue = number * 30 % 180
    saturation = number * 103 % 256
    value = number * 50 % 256
    hsv_array = [hue / 179, saturation / 255, value / 255]
    rgb = hsv_to_rgb(hsv_array)
    return [int(c * 255) for c in rgb]
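
# Example: get_color(track_id * 20) yields a repeatable colour per track id,
# so boxes for the same id keep the same colour from frame to frame.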


def img_predict(media, model, out_path, filename):
    save_to = os.path.join(out_path, filename)
    images_predictions = model.predict(media, conf=0.70, fuse_model=False)
    images_predictions.save(output_folder=out_path, box_thickness=2, show_confidence=True)
    return None
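
# Example usage (paths are placeholders):
#   img_predict("/path/to/fall.jpg", best_model, "./output", "fall.jpg")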


def vid_predict(media, model, tracker, out_path, filename):
    print("Running Predict")
    save_to = os.path.join(out_path, filename)
    cap = cv2.VideoCapture(media)
    if not cap.isOpened():
        print(f"Could not open {media}")
        return None
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    print('width', width)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    print('height', height)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print('fps:', fps)
    frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    print('frames count:', frame_count)
    out = cv2.VideoWriter(save_to, cv2.VideoWriter_fourcc(*'vp80'), fps, (640, 640))
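    # 'vp80' is the FourCC for VP8, so `filename` should normally end in .webm.
    # The frame size declared here (640x640) must match the frames written below,
    # which is why every frame is resized to 640x640 before out.write().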
    fall_records = {}
    frame_id = 0
    while True:
        frame_id += 1
        if frame_id > frame_count:
            break
        print('frame_id', frame_id)
        ret, img = cap.read()
        if ret:
            img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_AREA)
            width, height = img.shape[1], img.shape[0]
            ### convert colour channels to RGB for model prediction
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            overlay = img.copy()
            ### list objects needed for tracking
            detects = []
            conffs = []
            model_predictions = model.predict(img_rgb, conf=0.70, fuse_model=False)
            classnames = model_predictions[0].class_names
            results = model_predictions[0].prediction
            bboxes = results.bboxes_xyxy
            if len(bboxes) >= 1:
                confs = results.confidence
                labels = results.labels
                for bbox, conf, label in zip(bboxes, confs, labels):
                    label = int(label)
                    conf = np.round(conf, decimals=2)
                    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                    ### convert xyxy to centre-x, centre-y, width, height for the tracker
                    bw = abs(x1 - x2)
                    bh = abs(y1 - y2)
                    cx, cy = x1 + bw // 2, y1 + bh // 2
                    coords = [cx, cy, bw, bh]
                    detects.append(coords)
                    conffs.append([float(conf)])
                ### Tracker
                xywhs = torch.tensor(detects)
                conffs = torch.tensor(conffs)
                tracker_results = tracker.update(xywhs, conffs, img_rgb)
                ### conduct check on track records
                now = datetime.datetime.now()
                if len(fall_records.keys()) >= 1:
                    ### reset the immobility timer if more than N seconds have passed
                    ### since this id was last seen in a fallen state
                    fall_records = {id: item if (now - item['present']).total_seconds() <= 3.0
                                    else {'start': now, 'present': now}
                                    for id, item in fall_records.items()}
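                # A 3-second grace window: if a tracked fall disappears briefly
                # (occlusion, missed detection) the timer keeps running; only a
                # longer gap restarts the immobility clock for that id.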
                if len(tracker_results) >= 1:
                    # Note: zipping tracker output with conffs/labels assumes the
                    # tracker returns boxes in the same order as the detections.
                    for track, conf, label in zip(tracker_results, conffs, labels):
                        conf = conf.numpy()[0]
                        duration = 0
                        minute = 0
                        sec = 0
                        x1, y1, x2, y2, id = track
                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                        if id in fall_records.keys():
                            ### record present time
                            present = datetime.datetime.now()
                            fall_records[id].update({'present': present})
                            ### calculate duration
                            duration = fall_records[id]['present'] - fall_records[id]['start']
                            duration = int(duration.total_seconds())
                            ### record status
                            if duration >= 5:
                                fall_records[id].update({'status': 'IMMOBILE'})
                            else:
                                fall_records[id].update({'status': None})
                            print(f"Frame:{frame_id} ID: {id} Conf: {conf} Duration:{duration} Status: {fall_records[id]['status']}")
                            print(fall_records[id])
                            minute, sec = divmod(duration, 60)
                        else:
                            start = datetime.datetime.now()
                            fall_records[id] = {'start': start}
                            fall_records[id].update({'present': start})
                        classname = classnames[int(label)]
                        color = get_color(id * 20)
                        if duration < 5:
                            display_text = f"{str(classname)} ({str(id)}) {str(conf)} Elapsed: {round(minute)}min{round(sec)}s"
                            (w, h), _ = cv2.getTextSize(
                                display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 1)
                            cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
                            cv2.rectangle(overlay, (x1, y1), (x2, y2), color, 1)
                            cv2.rectangle(overlay, (min(x1, int(width) - w), max(1, y1 - 20)), (min(x1 + w, int(width)), max(21, y1)), color, cv2.FILLED)
                        else:
                            display_text = f"{str(classname)} ({str(id)}) {str(conf)} IMMOBILE: {round(minute)}min{round(sec)}s"
                            (w, h), _ = cv2.getTextSize(
                                display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 1)
                            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1)
                            cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 255), 1)
                            cv2.rectangle(overlay, (min(x1, int(width) - w), max(1, y1 - 20)), (min(x1 + w, int(width)), max(21, y1)), (0, 0, 255), cv2.FILLED)
                        cv2.putText(img, display_text, (min(x1, int(width) - w), max(21, y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)
                        cv2.putText(overlay, display_text, (min(x1, int(width) - w), max(21, y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)
            ### blend the filled label backgrounds onto the frame and write it out
            alpha = 0.6
            masked = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
            out.write(masked)
    cap.release()
    out.release()
    cv2.destroyAllWindows()
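

# Example usage (paths are placeholders; the .webm extension matches the vp80 codec):
#   vid_predict("/path/to/fall.mp4", best_model, tracker, "./output", "fall_tracked.webm")
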
if __name__ == '__main__':
    # ckpt_path = "/home/kaelan/Projects/Jupyter/Pytorch/Yolo-Nas/yolov-app/checkpoints/ckpt_latest.pth"
    ckpt_path = "/home/kaelan/Projects/Jupyter/Pytorch/Yolo-Nas/checkpoints_Fall_detection/Fall_yolonas_run2/ckpt_latest.pth"
    best_model = models.get('yolo_nas_s',
                            num_classes=1,
                            checkpoint_path=ckpt_path)
    # best_model.set_dataset_processing_params(
    #     class_names=['Fall-Detected'],
    #     iou=0.35, conf=0.7,
    # )
    best_model = best_model.to("cuda" if torch.cuda.is_available() else "cpu")
    # best_model = models.get("yolo_nas_s", pretrained_weights="coco")
    best_model.eval()
    #### Initialize tracker
    tracker_model = "./checkpoints/ckpt.t7"
    tracker = DeepSort(model_path=tracker_model, max_age=30, nn_budget=100, max_iou_distance=0.7, max_dist=0.2)
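    # ckpt.t7 is the re-ID appearance model used by DeepSORT. Roughly:
    #   max_age          - frames a track survives without a matching detection
    #   nn_budget        - appearance features kept per track for re-ID matching
    #   max_iou_distance - gating threshold for the IoU association step
    #   max_dist         - maximum cosine distance for appearance matching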
    title = "skate.mp4"
    media = "/home/kaelan/Projects/data/videos/" + title
    # Output location/name are placeholders; vp80 encoding pairs with a .webm container.
    out_path = "./output"
    filename = "skate_tracked.webm"
    vid_predict(media, best_model, tracker, out_path, filename)