# fall-detection / app.py
import datetime
import os

import cv2
import gradio as gr
import numpy as np
import torch

import spaces
from super_gradients.training import models
from super_gradients.training.pipelines.pipelines import DetectionPipeline
from deep_sort_torch.deep_sort.deep_sort import DeepSort
from model_tools import get_prediction, get_color
# Compatibility shim: NumPy 1.24 removed the deprecated np.float/np.int/
# np.object/np.bool aliases, which older deep-sort code still references.
np.float = float
np.int = int
np.object = object
np.bool = bool
upload_dir = os.path.join(os.getcwd(), "uploads/")
# One entry per input component (a single video path each); two-element rows
# would not match the single gr.Video input below.
examples = [
    [upload_dir + "cafe_fall.mp4"],   # fall in a cafe
    [upload_dir + "slip.mp4"],        # run and fall
    [upload_dir + "skate.mp4"],       # skate and fall
    [upload_dir + "kitchen.mp4"],     # fall in a kitchen
    [upload_dir + "studycam.mp4"],    # staged experiment fall
]
ckpt_path = os.path.join(os.getcwd(), "checkpoints/best181-8376/ckpt_latest.pth")
best_model = models.get('yolo_nas_s',
                        num_classes=1,
                        checkpoint_path=ckpt_path)
best_model = best_model.to("cuda" if torch.cuda.is_available() else "cpu")
# best_model = models.get("yolo_nas_s", pretrained_weights="coco")
best_model.eval()
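# num_classes=1 matches the single 'Fall-Detected' class used in vid_predict below.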
#### Initialize tracker
tracker_model = os.path.join(os.getcwd(), "checkpoints/ckpt.t7")
tracker = DeepSort(model_path=tracker_model, max_age=30, nn_budget=100,
                   max_iou_distance=0.7, max_dist=0.2)
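# DeepSORT parameters: max_age is how many frames a track survives without a
# matching detection, nn_budget caps the appearance features stored per track,
# max_iou_distance gates the IoU matching stage, and max_dist is the cosine
# distance threshold for appearance matching.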
out_path = upload_dir
filename = 'demo.webm'  # VP8-encoded output, so use a .webm container
description = ("YOLO model that detects whether a person is falling or has fallen, "
               "with DeepSORT tracking how long the subject has been down. "
               "If the duration crosses a threshold of 5s, the bounding box turns red "
               "and the subject is labelled IMMOBILE.")
@spaces.GPU
def vid_predict(media):
    pipeline = DetectionPipeline(
        model=best_model,
        image_processor=best_model._image_processor,
        post_prediction_callback=best_model.get_post_prediction_callback(
            iou=0.25, conf=0.70,
            nms_top_k=100,              # raw boxes kept before NMS; adjust to the scene
            max_predictions=50,         # final boxes kept per frame; adjust to the scene
            multi_label_per_box=False,  # single-class model: one label per box
            class_agnostic_nms=False),
        class_names=best_model._class_names,
    )
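    # Threshold notes: conf=0.70 keeps only high-confidence fall boxes before
    # NMS, and iou=0.25 suppresses overlapping boxes aggressively, which suits
    # scenes with few overlapping people.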
print("Running Predict")
save_to = os.path.join(out_path, filename)
cap = cv2.VideoCapture(media)
if cap.isOpened():
width = cap.get(3) # float `widtqh`
print('width',width)
height = cap.get(4)
print('Height',height)
fps = cap.get(cv2.CAP_PROP_FPS)
# or
fps = cap.get(5)
print('fps:', fps) # float `fps`
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
# or frame_count = cap.get(7)
print('frames count:', frame_count) # float `frame_count`
out = cv2.VideoWriter(save_to, cv2.VideoWriter_fourcc(*'VP08'), fps, (640,640))
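    # 'VP08' encodes VP8, matching the .webm container chosen above; the
    # writer's frame size (640, 640) must equal every frame passed to
    # out.write(), hence the fixed resize inside the loop below.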
    fall_records = {}
    frame_id = 0
    while True:
        frame_id += 1
        if frame_id > frame_count:
            break
        print('frame_id', frame_id)
        ret, img = cap.read()
        if not ret:
            ### stop as soon as a frame fails to decode instead of resizing None
            break
        img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_AREA)
        width, height = img.shape[1], img.shape[0]
        ### convert BGR to RGB for model prediction
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        overlay = img.copy()
        ### per-frame buffers for the tracker
        detects = []
        conffs = []
        print("START")
        model_predictions = get_prediction(best_model, img_rgb, pipeline)
        print(model_predictions)
        classnames = ['Fall-Detected']
        results = model_predictions
        bboxes = results.bboxes_xyxy
        if len(bboxes) >= 1:
            confs = results.confidence
            labels = results.labels
            for bbox, conf, label in zip(bboxes, confs, labels):
                label = int(label)
                conf = np.round(conf, decimals=2)
                x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
                ### convert xyxy corners to the centre/width/height format the tracker expects
                bw = abs(x1 - x2)
                bh = abs(y1 - y2)
                cx, cy = x1 + bw // 2, y1 + bh // 2
                detects.append([cx, cy, bw, bh])
                conffs.append([float(conf)])
            ### Tracker
            xywhs = torch.tensor(detects)
            conffs = torch.tensor(conffs)
            tracker_results = tracker.update(xywhs, conffs, img_rgb)
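            # In this DeepSORT port, update() returns one row per confirmed
            # track as [x1, y1, x2, y2, track_id]. The tracks returned may be
            # fewer than, and ordered differently from, this frame's
            # detections, so zipping them with conffs/labels below is only an
            # approximate pairing.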
            ### check track records: restart a fall timer if more than 3s of
            ### wall-clock time has passed since that track was last seen fallen
            now = datetime.datetime.now()
            if len(fall_records) >= 1:
                fall_records = {tid: item if (now - item['present']).total_seconds() <= 3.0
                                else {'start': now, 'present': now}
                                for tid, item in fall_records.items()}
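            # Note: these timers measure wall-clock time while the video is
            # processed, not video time, so the 5s threshold depends on
            # processing speed. A frame-based timer (frames_seen / fps) would
            # tie the threshold to video time instead.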
            if len(tracker_results) >= 1:
                for track, conf, label in zip(tracker_results, conffs, labels):
                    conf = conf.numpy()[0]
                    duration = 0
                    minute = 0
                    sec = 0
                    x1, y1, x2, y2, track_id = track
                    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                    if track_id in fall_records:
                        ### update last-seen time and compute how long this track has been down
                        present = datetime.datetime.now()
                        fall_records[track_id].update({'present': present})
                        duration = fall_records[track_id]['present'] - fall_records[track_id]['start']
                        duration = int(duration.total_seconds())
                        ### flag the subject once it has been down for 5s or more
                        fall_records[track_id]['status'] = 'IMMOBILE' if duration >= 5 else None
                        print(f"Frame:{frame_id} ID: {track_id} Conf: {conf} Duration: {duration} Status: {fall_records[track_id]['status']}")
                        print(fall_records[track_id])
                        minute, sec = divmod(duration, 60)
                    else:
                        ### first time this track is seen fallen: start its timer
                        start = datetime.datetime.now()
                        fall_records[track_id] = {'start': start, 'present': start}
                    classname = classnames[int(label)]
                    color = get_color(track_id * 20)
                    if duration < 5:
                        display_text = f"{classname} ({track_id}) {conf} Elapsed: {minute}min{sec}s"
                        box_color = color
                    else:
                        display_text = f"{classname} ({track_id}) {conf} IMMOBILE: {minute}min{sec}s"
                        box_color = (0, 0, 255)  # red once immobile
                    (w, h), _ = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 1)
                    cv2.rectangle(img, (x1, y1), (x2, y2), box_color, 1)
                    cv2.rectangle(overlay, (x1, y1), (x2, y2), box_color, 1)
                    ### filled label background, clamped so the text stays inside the frame
                    cv2.rectangle(overlay, (min(x1, width - w), max(1, y1 - 20)),
                                  (min(x1 + w, width), max(21, y1)), box_color, cv2.FILLED)
                    cv2.putText(img, display_text, (min(x1, width - w), max(21, y1)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)
                    cv2.putText(overlay, display_text, (min(x1, width - w), max(21, y1)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)
        ### blend the filled overlay onto the frame and write it out
        alpha = 0.6
        masked = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
        out.write(masked)
    cap.release()
    out.release()
    cv2.destroyAllWindows()
    return save_to
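
# Hypothetical local smoke test (assumes the example clip exists in uploads/):
# bypasses the Gradio UI and runs the pipeline on one clip directly.
# result = vid_predict(upload_dir + "cafe_fall.mp4")
# print("annotated video written to", result)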
def run():
    demo = gr.Interface(fn=vid_predict,
                        inputs=gr.Video(format='mp4'),
                        outputs=gr.Video(),
                        examples=examples,
                        description=description,
                        cache_examples=False,
                        title='Fall detection and tracking with DeepSORT')
    demo.launch(server_port=7860)

if __name__ == "__main__":
    run()