# db2 / app.py
# Uploaded by l337chode ("Upload app.py"), commit d4e016a (verified).
from transformers import PreTrainedModel
import numpy as np
import tensorflow as tf
from PIL import ImageDraw
class YOLOv8TFLiteForDroneDetection(PreTrainedModel):
    """Drone detector backed by a TFLite model, exposed as a ``PreTrainedModel``.

    This class defines no torch parameters of its own; all inference runs
    through a ``tf.lite.Interpreter`` loaded from ``model_path``.
    """

    # TFLite file used when no explicit path is supplied.
    DEFAULT_MODEL_FILE = "best_float32.tflite"
    # Fallback (height, width) when the interpreter's input shape can't be read.
    DEFAULT_INPUT_SIZE = (416, 416)

    def __init__(self, config, model_path=None):
        """Load the TFLite model and allocate its tensors.

        Args:
            config: a ``transformers.PretrainedConfig`` handed to the base class.
            model_path: path to the ``.tflite`` file. Defaults to
                ``DEFAULT_MODEL_FILE``; a bare filename is opened relative to
                the process's working directory.
        """
        super().__init__(config)
        self.model_path = self.DEFAULT_MODEL_FILE if model_path is None else model_path
        self.interpreter = tf.lite.Interpreter(model_path=self.model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Construct the detector from a local TFLite file or directory.

        Resolution of ``pretrained_model_name_or_path``:

        * a path ending in ``.tflite`` is used as the model file;
        * an existing directory is expected to contain ``DEFAULT_MODEL_FILE``;
        * anything else (e.g. a Hub repo id, or ``None``) falls back to
          ``DEFAULT_MODEL_FILE`` in the working directory. Nothing is downloaded.

        Args:
            pretrained_model_name_or_path: file path, directory, or identifier.
            *model_args: accepted for API compatibility; ignored.
            **kwargs: ``config`` (a ``PretrainedConfig``) is used if given;
                other keys are ignored.

        Returns:
            A new instance of ``cls``.
        """
        import os

        from transformers import PretrainedConfig

        # The base-class constructor requires a PretrainedConfig instance, so
        # calling cls() with no config can never succeed.
        config = kwargs.pop("config", None)
        if config is None:
            config = PretrainedConfig()

        model_path = cls.DEFAULT_MODEL_FILE
        if pretrained_model_name_or_path is not None:
            location = os.fspath(pretrained_model_name_or_path)
            if location.endswith(".tflite"):
                model_path = location
            elif os.path.isdir(location):
                model_path = os.path.join(location, cls.DEFAULT_MODEL_FILE)
        return cls(config, model_path=model_path)

    def _input_hw(self):
        """Return the (height, width) expected by the interpreter's first input.

        Falls back to ``DEFAULT_INPUT_SIZE`` when the input details are missing
        or the shape is not 4-D with positive spatial dimensions.
        """
        details = getattr(self, "input_details", None)
        if details:
            shape = details[0].get("shape")
            # preprocess() feeds an NHWC batch (batch, height, width, channels),
            # so dims 1 and 2 are the spatial size the input tensor expects.
            if shape is not None and len(shape) == 4 and shape[1] > 0 and shape[2] > 0:
                return int(shape[1]), int(shape[2])
        return self.DEFAULT_INPUT_SIZE

    def preprocess(self, image, size=None):
        """Convert a PIL image into a float32 batch of one.

        Args:
            image: PIL image in any mode; it is converted to RGB first.
            size: optional ``(width, height)`` to resize to. Defaults to the
                model's input size (see ``_input_hw``), else 416x416.

        Returns:
            np.ndarray of shape (1, height, width, 3), dtype float32, values in [0, 1].
        """
        if size is None:
            height, width = self._input_hw()
            size = (width, height)  # PIL's resize() takes (width, height)
        # Normalize to 3-channel RGB: RGBA / L / P images would otherwise give a
        # 4-channel, 2-D or palette-index array that set_tensor() rejects.
        # NOTE(review): assumes the model was exported for RGB input — confirm.
        rgb = image.convert("RGB").resize(size)
        pixels = np.asarray(rgb) / 255.0  # scale uint8 pixels to [0, 1]
        return np.expand_dims(pixels, axis=0).astype(np.float32)  # add batch dim

    def predict(self, image):
        """Run one forward pass on a PIL image.

        Returns:
            list of numpy arrays, one per entry of ``self.output_details``, in
            the interpreter's output order.
        """
        input_data = self.preprocess(image)
        self.interpreter.set_tensor(self.input_details[0]["index"], input_data)
        self.interpreter.invoke()
        return [
            self.interpreter.get_tensor(detail["index"])
            for detail in self.output_details
        ]

    def draw_boxes(self, image, boxes, conf_threshold=None):
        """Outline each box in red (2 px) on *image* and return the same image.

        Args:
            image: PIL image; it is drawn on in place.
            boxes: iterable of rows ``[x_min, y_min, x_max, y_max, confidence, class_id]``.
                Only the first four entries are required. They are drawn as-is,
                so they must already be in *image*'s pixel coordinates.
            conf_threshold: if not ``None``, skip rows whose confidence (index 4)
                is below this value; rows without a confidence entry are kept.

        Returns:
            The same ``image`` object.
        """
        draw = ImageDraw.Draw(image)
        for box in boxes:
            if conf_threshold is not None and len(box) > 4 and box[4] < conf_threshold:
                continue
            x_min, y_min, x_max, y_max = (float(v) for v in box[:4])
            # Recent Pillow versions raise ValueError when x1 < x0 or y1 < y0;
            # order the corners so a swapped box is drawn instead of crashing.
            x0, x1 = sorted((x_min, x_max))
            y0, y1 = sorted((y_min, y_max))
            draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
        return image

    def detect_and_draw_boxes(self, image, conf_threshold=None):
        """Run detection on *image* and return a copy with the boxes drawn.

        Args:
            image: input PIL image; left unmodified.
            conf_threshold: forwarded to ``draw_boxes``; ``None`` draws every row.

        Returns:
            A copy of ``image`` with the detected boxes outlined.
        """
        outputs = self.predict(image)
        # NOTE(review): assumes outputs[0] is (1, num_boxes, >=4) with rows in the
        # draw_boxes() format, already in the original image's pixel space. Boxes
        # are NOT rescaled from the model's input resolution (see _input_hw) —
        # confirm against self.output_details for this particular export.
        boxes = outputs[0][0]  # drop the batch dimension (batch size 1)
        return self.draw_boxes(image.copy(), boxes, conf_threshold=conf_threshold)
# Example usage:
# from PIL import Image
#
# model = YOLOv8TFLiteForDroneDetection.from_pretrained("path_to_tflite_model")
# image = Image.open("drone_image.jpg")
# image_with_boxes = model.detect_and_draw_boxes(image)
# image_with_boxes.show()