|
from transformers import PreTrainedModel |
|
import numpy as np |
|
import tensorflow as tf |
|
from PIL import ImageDraw |
|
|
|
class YOLOv8TFLiteForDroneDetection(PreTrainedModel):
    """Run a TFLite detection model on PIL images and draw the resulting boxes.

    The TFLite interpreter is created once in ``__init__`` and reused by every
    call to ``predict``.
    """

    # File name used when no explicit model path is supplied.
    DEFAULT_MODEL_PATH = "best_float32.tflite"
    # (width, height) every image is resized to before inference.
    INPUT_SIZE = (416, 416)

    def __init__(self, config, model_path=None):
        """Load the TFLite model and cache its input/output tensor details.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
            model_path: Path of the ``.tflite`` file to load. Defaults to
                ``DEFAULT_MODEL_PATH``, resolved against the current working
                directory.
        """
        super().__init__(config)
        self.model_path = (
            self.DEFAULT_MODEL_PATH if model_path is None else model_path
        )
        self.interpreter = tf.lite.Interpreter(model_path=self.model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build an instance from a local ``.tflite`` file or directory.

        Nothing is downloaded. ``pretrained_model_name_or_path`` is used only
        when it names an existing ``.tflite`` file, or a directory containing
        ``DEFAULT_MODEL_PATH``; otherwise the default file name is loaded from
        the current working directory.

        Args:
            pretrained_model_name_or_path: Local file or directory to look in.
            *model_args: Accepted for signature compatibility; ignored.
            **kwargs: ``config`` may carry a ready-made configuration object;
                every other keyword is ignored.

        Returns:
            A new instance of ``cls``.
        """
        # Local imports keep this block self-contained.
        import os

        config = kwargs.pop("config", None)
        if config is None:
            # ``__init__`` requires a config; the previous ``cls()`` call
            # always failed with TypeError because none was passed.
            from transformers import PretrainedConfig

            config = PretrainedConfig()

        model_path = cls.DEFAULT_MODEL_PATH
        if pretrained_model_name_or_path is not None:
            candidate = str(pretrained_model_name_or_path)
            if os.path.isdir(candidate):
                candidate = os.path.join(candidate, cls.DEFAULT_MODEL_PATH)
            if candidate.lower().endswith(".tflite") and os.path.isfile(candidate):
                model_path = candidate

        return cls(config, model_path=model_path)

    def preprocess(self, image):
        """Turn a PIL image into a float32 batch of one, scaled to [0, 1].

        Args:
            image: PIL image in any mode.

        Returns:
            ``np.float32`` array of shape ``(1, height, width, 3)`` where
            ``(width, height)`` is ``INPUT_SIZE``.
        """
        # Force three channels so RGBA / grayscale / palette images do not
        # produce a tensor with the wrong channel count.
        image = image.convert("RGB").resize(self.INPUT_SIZE)
        pixels = np.asarray(image) / 255.0
        batch = np.expand_dims(pixels, axis=0)
        return batch.astype(np.float32)

    def predict(self, image):
        """Run one inference pass and return every output tensor.

        Args:
            image: PIL image; it is passed through ``preprocess`` first.

        Returns:
            List with one array per entry of ``self.output_details``, in the
            same order.
        """
        input_data = self.preprocess(image)
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        return [
            self.interpreter.get_tensor(detail['index'])
            for detail in self.output_details
        ]

    def draw_boxes(self, image, boxes):
        """Draw a red rectangle on ``image`` for every entry of ``boxes``.

        The image is modified in place and also returned.

        Args:
            image: PIL image to draw on.
            boxes: Iterable of 6-element rows. The first four values are used
                as ``x_min, y_min, x_max, y_max``; the last two are ignored
                (presumably score and class id - TODO confirm).

        Returns:
            The same ``image`` object.

        Raises:
            ValueError: If a row does not contain exactly six values.
        """
        draw = ImageDraw.Draw(image)
        for box in boxes:
            x_min, y_min, x_max, y_max, _, _ = box
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
        return image

    def detect_and_draw_boxes(self, image):
        """Run detection on ``image`` and return a copy with boxes drawn.

        Args:
            image: PIL image; it is not modified.

        Returns:
            A copy of ``image`` with one rectangle per detection row.
        """
        outputs = self.predict(image)
        # NOTE(review): assumes the first output tensor is laid out as
        # (batch, num_boxes, 6) with corner coordinates already in the pixel
        # space of ``image`` - confirm against the exported model.
        boxes = outputs[0][0]
        return self.draw_boxes(image.copy(), boxes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|