from transformers import PreTrainedModel, PretrainedConfig
import copy
import os

import numpy as np
import tensorflow as tf
from PIL import ImageDraw


class YOLOv8TFLiteForDroneDetection(PreTrainedModel):
    """Run a YOLOv8 drone detector exported to TensorFlow Lite and draw its boxes.

    The class subclasses ``PreTrainedModel`` only so it can be created through
    ``from_pretrained``. All inference is done by a ``tf.lite.Interpreter``;
    no PyTorch weights are loaded.
    """

    config_class = PretrainedConfig
    # File loaded when no path is configured, or when the configured path is a
    # directory (presumably the name written by the Ultralytics TFLite export).
    DEFAULT_MODEL_FILE = "best_float32.tflite"
    # (width, height) used only if the interpreter reports no usable NHWC shape.
    DEFAULT_INPUT_SIZE = (416, 416)

    def __init__(self, config):
        """Create and allocate the TFLite interpreter described by *config*.

        Args:
            config: A ``PretrainedConfig``. If it has a ``model_path`` attribute,
                that file is loaded. If the path is a directory,
                ``DEFAULT_MODEL_FILE`` inside it is loaded. Otherwise
                ``DEFAULT_MODEL_FILE`` is loaded from the working directory.
        """
        super().__init__(config)
        self.model_path = self._resolve_model_path(getattr(config, "model_path", None))
        self.interpreter = tf.lite.Interpreter(model_path=self.model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # (width, height) the network expects, in the order PIL's resize() takes.
        self.input_size = self._read_input_size()

    @classmethod
    def _resolve_model_path(cls, path):
        """Map a configured path (file, directory or None) to a .tflite file path."""
        if path is None:
            return cls.DEFAULT_MODEL_FILE
        path = os.fspath(path)
        if os.path.isdir(path):
            return os.path.join(path, cls.DEFAULT_MODEL_FILE)
        return path

    def _read_input_size(self):
        """Return the model's input (width, height), or DEFAULT_INPUT_SIZE."""
        shape = self.input_details[0]["shape"]
        # preprocess() feeds NHWC tensors: (batch, height, width, channels).
        if len(shape) == 4 and shape[1] > 0 and shape[2] > 0:
            return int(shape[2]), int(shape[1])
        return self.DEFAULT_INPUT_SIZE

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the detector from a local ``.tflite`` file or a directory holding one.

        Only local paths are supported; Hub model ids are not downloaded.
        ``model_args`` and every keyword except ``config`` are accepted for
        signature compatibility with ``PreTrainedModel.from_pretrained`` and are
        otherwise ignored.

        Args:
            pretrained_model_name_or_path: Path to the ``.tflite`` file or its directory.
            config: Optional ``PretrainedConfig`` to start from. It is copied, so
                the caller's object is not modified.
        """
        config = kwargs.pop("config", None)
        # The original called cls() without a config, which always raised TypeError.
        config = PretrainedConfig() if config is None else copy.deepcopy(config)
        if pretrained_model_name_or_path is not None:
            config.model_path = os.fspath(pretrained_model_name_or_path)
        return cls(config)

    def preprocess(self, image):
        """Turn a PIL image into a float32 NHWC batch of one, scaled to [0, 1]."""
        # convert("RGB") ensures 3 channels for RGBA / grayscale / palette inputs.
        # NOTE(review): assumes a 3-channel RGB model, as YOLOv8 exports use — confirm.
        rgb = image.convert("RGB").resize(self.input_size)
        array = np.asarray(rgb, dtype=np.float32) / 255.0
        return np.expand_dims(array, axis=0)  # add batch dimension

    def predict(self, image):
        """Run the interpreter on *image* and return every raw output tensor."""
        input_data = self.preprocess(image)
        self.interpreter.set_tensor(self.input_details[0]["index"], input_data)
        self.interpreter.invoke()
        return [self.interpreter.get_tensor(detail["index"]) for detail in self.output_details]

    def postprocess(self, outputs, image_size, conf_threshold=0.25, iou_threshold=0.45):
        """Decode raw outputs into boxes in original-image pixel coordinates.

        Args:
            outputs: The list returned by ``predict``; only ``outputs[0]`` is used.
            image_size: ``(width, height)`` of the original image (``PIL.Image.size``).
            conf_threshold: Minimum confidence for a detection to be kept.
            iou_threshold: IoU above which overlapping same-class boxes are
                suppressed. Applied only to raw (not pre-decoded) outputs.

        Returns:
            float32 array of shape ``(K, 6)``, one row per detection:
            ``[x_min, y_min, x_max, y_max, confidence, class_id]``.

        Raises:
            ValueError: If ``outputs[0]`` has a layout this method cannot decode.
        """
        pred = np.asarray(outputs[0], dtype=np.float32)
        if pred.ndim == 3:
            pred = pred[0]  # preprocess() always builds a batch of one
        if pred.ndim != 2:
            raise ValueError(f"Unsupported detection output shape: {np.shape(outputs[0])}")

        if pred.shape[1] == 6:
            # Already-decoded rows, the format this class originally assumed:
            # [x_min, y_min, x_max, y_max, confidence, class_id].
            xyxy, scores, class_ids = pred[:, :4], pred[:, 4], pred[:, 5]
            needs_nms = False
        else:
            # NOTE(review): assumes the raw YOLOv8 detect head, laid out as
            # (4 + num_classes, num_anchors) with rows cx, cy, w, h followed by
            # one score per class — confirm against the exported model.
            rows = pred.T if pred.shape[0] < pred.shape[1] else pred
            if rows.shape[1] < 5:
                raise ValueError(f"Unsupported detection output shape: {np.shape(outputs[0])}")
            class_scores = rows[:, 4:]
            class_ids = class_scores.argmax(axis=1).astype(np.float32)
            scores = class_scores.max(axis=1)
            cx, cy, w, h = rows[:, 0], rows[:, 1], rows[:, 2], rows[:, 3]
            xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
            needs_nms = True

        keep = scores >= conf_threshold
        xyxy, scores, class_ids = xyxy[keep], scores[keep], class_ids[keep]
        if scores.size == 0:
            return np.zeros((0, 6), dtype=np.float32)

        img_w, img_h = image_size
        in_w, in_h = self.input_size
        # Depending on the exporter, coordinates are either normalised to [0, 1]
        # or given in model-input pixels. Heuristic: if no coordinate exceeds 2,
        # treat them as normalised.
        if float(np.abs(xyxy).max()) <= 2.0:
            scale = np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
        else:
            scale = np.array([img_w / in_w, img_h / in_h] * 2, dtype=np.float32)
        xyxy = xyxy * scale
        xyxy[:, [0, 2]] = np.clip(xyxy[:, [0, 2]], 0, img_w)
        xyxy[:, [1, 3]] = np.clip(xyxy[:, [1, 3]], 0, img_h)

        if needs_nms:
            kept = self._nms(xyxy, scores, class_ids, iou_threshold)
            xyxy, scores, class_ids = xyxy[kept], scores[kept], class_ids[kept]

        return np.column_stack([xyxy, scores, class_ids]).astype(np.float32)

    @staticmethod
    def _nms(boxes, scores, class_ids, iou_threshold):
        """Return indices of boxes kept by greedy per-class non-maximum suppression."""
        # Shift each class into its own coordinate range so that boxes of
        # different classes never overlap and are suppressed independently.
        offset = class_ids.astype(np.float64)[:, None] * (float(boxes.max()) + 1.0)
        shifted = boxes.astype(np.float64) + offset
        areas = np.clip(shifted[:, 2] - shifted[:, 0], 0, None) * np.clip(
            shifted[:, 3] - shifted[:, 1], 0, None
        )
        order = np.argsort(-scores, kind="stable")
        keep = []
        while order.size:
            best, rest = order[0], order[1:]
            keep.append(best)
            ix1 = np.maximum(shifted[best, 0], shifted[rest, 0])
            iy1 = np.maximum(shifted[best, 1], shifted[rest, 1])
            ix2 = np.minimum(shifted[best, 2], shifted[rest, 2])
            iy2 = np.minimum(shifted[best, 3], shifted[rest, 3])
            inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
            union = areas[best] + areas[rest] - inter
            iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
            order = rest[iou <= iou_threshold]
        return np.asarray(keep, dtype=np.intp)

    def draw_boxes(self, image, boxes):
        """Draw each box outline in red on *image* (modified in place) and return it.

        Each box is ``[x_min, y_min, x_max, y_max, ...]`` in *image* pixel
        coordinates. Trailing fields such as confidence and class id are ignored.
        """
        draw = ImageDraw.Draw(image)
        for box in boxes:
            x_min, y_min, x_max, y_max = (float(v) for v in box[:4])
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
        return image

    def detect_and_draw_boxes(self, image, conf_threshold=0.25, iou_threshold=0.45):
        """Detect drones in *image* and return a copy with their boxes drawn.

        Args:
            image: PIL image to run detection on; it is not modified.
            conf_threshold: Minimum confidence for a detection to be drawn.
            iou_threshold: NMS IoU threshold (see ``postprocess``).
        """
        outputs = self.predict(image)
        boxes = self.postprocess(outputs, image.size, conf_threshold, iou_threshold)
        return self.draw_boxes(image.copy(), boxes)


# Example usage:
# from PIL import Image
# model = YOLOv8TFLiteForDroneDetection.from_pretrained("best_float32.tflite")
# image = Image.open("drone_image.jpg")
# image_with_boxes = model.detect_and_draw_boxes(image)
# image_with_boxes.show()