File size: 2,485 Bytes
d4e016a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from transformers import PreTrainedModel
import numpy as np
import tensorflow as tf
from PIL import ImageDraw

class YOLOv8TFLiteForDroneDetection(PreTrainedModel):
    """Run a YOLOv8 TFLite drone detector and draw its detections on PIL images.

    The class loads a ``.tflite`` file into a ``tf.lite.Interpreter``, feeds it
    a resized and normalized image, decodes the first output tensor into
    ``[x_min, y_min, x_max, y_max, confidence, class_id]`` rows and draws them.
    """

    # File name used when no explicit model path is supplied.
    DEFAULT_MODEL_FILE = "best_float32.tflite"
    # Edge length used only when the interpreter reports no usable static input shape.
    FALLBACK_INPUT_SIZE = 416

    def __init__(self, config, model_path=None):
        """Create the interpreter and cache its input/output tensor details.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel.__init__``.
            model_path: Optional path to the ``.tflite`` file. Defaults to
                ``DEFAULT_MODEL_FILE`` (resolved relative to the working directory).
        """
        super().__init__(config)
        self.model_path = str(model_path) if model_path is not None else self.DEFAULT_MODEL_FILE
        self.interpreter = tf.lite.Interpreter(model_path=self.model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the detector from a local ``.tflite`` file or a directory containing one.

        Args:
            pretrained_model_name_or_path: Path to a ``.tflite`` file, or a directory
                that contains ``DEFAULT_MODEL_FILE``. ``None`` selects the default file.
                Hub repository ids are NOT downloaded by this method.
            *model_args: Accepted for signature compatibility; unused.
            **kwargs: ``config`` may be supplied here; other keys are ignored.

        Returns:
            A ready-to-use ``YOLOv8TFLiteForDroneDetection`` instance.
        """
        # Local imports keep this block self-contained.
        import os
        from transformers import PretrainedConfig

        # __init__ requires a config; the previous `cls()` call always raised TypeError.
        config = kwargs.get("config")
        if config is None:
            config = PretrainedConfig()

        model_path = pretrained_model_name_or_path
        if model_path is not None and os.path.isdir(model_path):
            model_path = os.path.join(model_path, cls.DEFAULT_MODEL_FILE)
        return cls(config, model_path=model_path)

    def _input_size(self):
        """Return the ``(width, height)`` the model expects for its input image.

        Assumes the input tensor is laid out as (batch, height, width, channels),
        which is the usual TFLite layout -- TODO confirm for this model.
        Falls back to ``FALLBACK_INPUT_SIZE`` when the shape is not 4-D or has
        non-positive (dynamic) dimensions.
        """
        shape = self.input_details[0]['shape']
        if len(shape) == 4:
            height, width = int(shape[1]), int(shape[2])
            if height > 0 and width > 0:
                return width, height
        return self.FALLBACK_INPUT_SIZE, self.FALLBACK_INPUT_SIZE

    def preprocess(self, image):
        """Convert a PIL image into the float32 batch tensor fed to the interpreter.

        Args:
            image: A ``PIL.Image.Image`` in any mode.

        Returns:
            ``np.ndarray`` of dtype float32 with a leading batch axis of size 1 and
            pixel values scaled to ``[0, 1]``.
        """
        # Force three channels so grayscale / RGBA / palette images do not
        # produce a tensor whose channel count differs from RGB.
        image = image.convert("RGB").resize(self._input_size())
        image = np.asarray(image) / 255.0  # Normalize pixel values to [0, 1]
        image = np.expand_dims(image, axis=0)  # Add batch dimension
        return image.astype(np.float32)

    def predict(self, image):
        """Run inference on a PIL image and return every raw output tensor.

        Args:
            image: A ``PIL.Image.Image``.

        Returns:
            List of ``np.ndarray``, one per entry in ``self.output_details``.
        """
        input_data = self.preprocess(image)
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        return [
            self.interpreter.get_tensor(output_detail['index'])
            for output_detail in self.output_details
        ]

    @staticmethod
    def _nms(corners, scores, iou_threshold):
        """Greedy, class-agnostic non-maximum suppression.

        Args:
            corners: ``(N, 4)`` array of ``[x_min, y_min, x_max, y_max]``.
            scores: ``(N,)`` array of confidences.
            iou_threshold: Boxes overlapping a kept box by more than this IoU are dropped.

        Returns:
            List of indices into ``corners`` to keep, highest score first.
        """
        areas = (
            np.clip(corners[:, 2] - corners[:, 0], 0, None)
            * np.clip(corners[:, 3] - corners[:, 1], 0, None)
        )
        order = np.argsort(scores)[::-1]
        kept = []
        while order.size:
            best, rest = order[0], order[1:]
            kept.append(int(best))
            inter_w = np.clip(
                np.minimum(corners[best, 2], corners[rest, 2])
                - np.maximum(corners[best, 0], corners[rest, 0]),
                0, None,
            )
            inter_h = np.clip(
                np.minimum(corners[best, 3], corners[rest, 3])
                - np.maximum(corners[best, 1], corners[rest, 1]),
                0, None,
            )
            inter = inter_w * inter_h
            union = areas[best] + areas[rest] - inter
            # Guard against zero-area boxes producing a division by zero.
            iou = inter / np.maximum(union, 1e-9)
            order = rest[iou <= iou_threshold]
        return kept

    def _decode_detections(self, raw, image_size, conf_threshold, iou_threshold):
        """Turn one image's raw output tensor into drawable boxes.

        Args:
            raw: 2-D array for a single image (batch axis already removed).
            image_size: ``(width, height)`` of the image the boxes will be drawn on.
            conf_threshold: Minimum confidence for a detection to be kept.
            iou_threshold: IoU threshold passed to ``_nms``.

        Returns:
            ``(K, 6)`` float array of ``[x_min, y_min, x_max, y_max, confidence,
            class_id]`` in pixel coordinates of the original image (``K`` may be 0).
        """
        empty = np.zeros((0, 6), dtype=np.float32)
        pred = np.asarray(raw, dtype=np.float32)
        if pred.ndim != 2 or pred.size == 0:
            return empty

        if pred.shape[1] == 6:
            # NOTE(review): rows of width 6 are treated as already-decoded
            # [x_min, y_min, x_max, y_max, confidence, class_id] -- the format the
            # original draw_boxes comment describes. Confirm for NMS-enabled exports.
            corners = pred[:, :4]
            scores = pred[:, 4]
            class_ids = pred[:, 5]
        else:
            # Raw YOLOv8 head output is presumably (4 + num_classes, num_candidates);
            # transpose so each row is one candidate. TODO confirm against the export.
            if pred.shape[0] < pred.shape[1]:
                pred = pred.T
            if pred.shape[1] < 5:
                return empty
            half_w = pred[:, 2] / 2.0
            half_h = pred[:, 3] / 2.0
            # Centre/size -> corner coordinates.
            corners = np.stack(
                [pred[:, 0] - half_w, pred[:, 1] - half_h,
                 pred[:, 0] + half_w, pred[:, 1] + half_h],
                axis=1,
            )
            class_scores = pred[:, 4:]
            scores = class_scores.max(axis=1)
            class_ids = class_scores.argmax(axis=1).astype(np.float32)

        confident = scores >= conf_threshold
        if not confident.any():
            return empty
        corners = corners[confident].copy()
        scores = scores[confident]
        class_ids = class_ids[confident]

        img_w, img_h = image_size
        # Heuristic: coordinates that never exceed ~1 are taken to be normalized to
        # [0, 1]; anything larger is taken to be in model-input pixels. preprocess()
        # resizes without letterboxing, so a plain per-axis scale maps back.
        if np.abs(corners).max() <= 2.0:
            scale_x, scale_y = float(img_w), float(img_h)
        else:
            in_w, in_h = self._input_size()
            scale_x, scale_y = img_w / in_w, img_h / in_h
        corners[:, [0, 2]] = np.clip(corners[:, [0, 2]] * scale_x, 0, img_w)
        corners[:, [1, 3]] = np.clip(corners[:, [1, 3]] * scale_y, 0, img_h)

        kept = self._nms(corners, scores, iou_threshold)
        return np.column_stack([corners[kept], scores[kept], class_ids[kept]])

    def draw_boxes(self, image, boxes):
        """Draw red rectangles on ``image`` (in place) and return it.

        Args:
            image: ``PIL.Image.Image`` to draw on; it is modified.
            boxes: Iterable of sequences whose first four items are
                ``x_min, y_min, x_max, y_max``; any further items (confidence,
                class id) are ignored.

        Returns:
            The same ``image`` object, with the rectangles drawn.
        """
        draw = ImageDraw.Draw(image)
        for box in boxes:
            # Cast to plain floats so NumPy scalars are accepted by Pillow.
            x_min, y_min, x_max, y_max = (float(value) for value in box[:4])
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
        return image

    def detect_and_draw_boxes(self, image, conf_threshold=0.25, iou_threshold=0.45):
        """Detect drones in ``image`` and return a copy with the boxes drawn.

        Args:
            image: Input ``PIL.Image.Image``; it is not modified.
            conf_threshold: Minimum confidence for a detection to be drawn.
            iou_threshold: IoU above which overlapping detections are suppressed.

        Returns:
            A copy of ``image`` with one red rectangle per surviving detection.
        """
        outputs = self.predict(image)
        # Only the first output tensor is used; [0] strips the batch axis (batch size 1).
        boxes = self._decode_detections(
            outputs[0][0], image.size, conf_threshold, iou_threshold
        )
        return self.draw_boxes(image.copy(), boxes)

# Example usage (Image must be imported from PIL in addition to ImageDraw):
# from PIL import Image
# model = YOLOv8TFLiteForDroneDetection.from_pretrained("path_to_tflite_model")
# image = Image.open("drone_image.jpg")
# image_with_boxes = model.detect_and_draw_boxes(image)
# image_with_boxes.show()