fall-detection / model_tools.py
Kaelan
real hard
056020b
raw
history blame
1.75 kB
import cv2
import datetime
from matplotlib.colors import hsv_to_rgb
import torch
import numpy as np
from super_gradients.training import models
from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
from super_gradients.training.pipelines.pipelines import DetectionPipeline
from deep_sort_torch.deep_sort.deep_sort import DeepSort
import os
# make sure to set IOU and confidence in the pipeline constructor
def get_color(number):
    """Derive a deterministic RGB colour from an integer.

    Each HSV channel is produced by multiplying ``number`` by a fixed
    factor and wrapping it with a modulus, so the same number always maps
    to the same colour.

    Returns:
        A list with one int per RGB channel, obtained by scaling the
        output of ``hsv_to_rgb`` by 255 and truncating.
    """
    # (multiplier, modulus, normaliser) for hue, saturation and value.
    # The multipliers are arbitrary and can be tuned freely.
    channel_spec = (
        (30, 180, 179),
        (103, 256, 255),
        (50, 256, 255),
    )

    # hsv_to_rgb expects normalised inputs, hence the division.
    normalised_hsv = [
        ((number * factor) % modulus) / scale
        for factor, modulus, scale in channel_spec
    ]

    return [int(component * 255) for component in hsv_to_rgb(normalised_hsv)]
def img_predict(media, model, out_path, filename, *, conf=0.70):
    """Run detection on ``media`` and save the rendered predictions.

    Args:
        media: Input forwarded unchanged to ``model.predict``.
        model: Object exposing ``predict(media, conf=..., fuse_model=...)``
            whose result exposes ``save(output_folder=..., ...)``.
        out_path: Folder handed to the prediction's ``save`` call.
        filename: Accepted for backward compatibility but not used; only
            ``out_path`` is passed to ``save``, so output file names are
            decided by the prediction object itself.
        conf: Confidence threshold forwarded to ``model.predict``.
            Defaults to 0.70, the previously hard-coded value.

    Returns:
        None. The function is called for its side effect of writing files.
    """
    predictions = model.predict(media, conf=conf, fuse_model=False)
    predictions.save(output_folder=out_path, box_thickness=2, show_confidence=True)
def get_prediction(model, image_in, pipeline):
    """Run one image through ``model`` using ``pipeline``'s pre/post-processing.

    Args:
        model: Callable detector (typically a ``torch.nn.Module``). The input
            tensor is placed on the device of its first parameter.
        image_in: Single input image (RGB, per the original note). A copy is
            preprocessed, so the caller's array is not handed to the
            preprocessor directly.
        pipeline: Object providing ``image_processor.preprocess_image``,
            ``image_processor.postprocess_predictions`` and
            ``_decode_model_output``.

    Returns:
        Whatever ``pipeline.image_processor.postprocess_predictions`` returns
        for the first decoded prediction (presumably a DetectionPrediction;
        confirm against the pipeline implementation).
    """
    # Preprocess
    preprocessed_image, processing_metadata = pipeline.image_processor.preprocess_image(
        image=image_in.copy()
    )

    # Follow the model's device rather than assuming CUDA, so CPU-only hosts
    # and models on a non-default device work too.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        # No parameters to inspect: keep the old CUDA preference when possible.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Predict
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension for a single image.
        torch_input = torch.Tensor(preprocessed_image).unsqueeze(0).to(device)
        model_output = model(torch_input)
        # NOTE: relies on a private pipeline method; may break across versions.
        prediction = pipeline._decode_model_output(model_output, model_input=torch_input)

    # Postprocess: only the first (and only) batch element is used.
    return pipeline.image_processor.postprocess_predictions(
        predictions=prediction[0], metadata=processing_metadata
    )