l337chode commited on
Commit
d4e016a
·
verified ·
1 Parent(s): cb4ee8d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import ImageDraw
5
+
6
+ class YOLOv8TFLiteForDroneDetection(PreTrainedModel):
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.model_path = "best_float32.tflite" # Path to the TFLite model file
10
+ self.interpreter = tf.lite.Interpreter(model_path=self.model_path)
11
+ self.interpreter.allocate_tensors()
12
+ self.input_details = self.interpreter.get_input_details()
13
+ self.output_details = self.interpreter.get_output_details()
14
+
15
+ @classmethod
16
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
17
+ # This method is used to load the model from a path or a Hugging Face model hub
18
+ return cls()
19
+
20
+ def preprocess(self, image):
21
+ # Resize and normalize input image
22
+ image = image.resize((416, 416)) # Resize to match the model input size
23
+ image = np.asarray(image) / 255.0 # Normalize pixel values to [0, 1]
24
+ image = np.expand_dims(image, axis=0) # Add batch dimension
25
+ return image.astype(np.float32)
26
+
27
+ def predict(self, image):
28
+ # Perform inference on the input image
29
+ input_data = self.preprocess(image)
30
+ self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
31
+ self.interpreter.invoke()
32
+ outputs = [self.interpreter.get_tensor(output_detail['index']) for output_detail in self.output_details]
33
+ return outputs
34
+
35
+ def draw_boxes(self, image, boxes):
36
+ # Draw bounding boxes on the image
37
+ draw = ImageDraw.Draw(image)
38
+ for box in boxes:
39
+ # Box format: [x_min, y_min, x_max, y_max, confidence, class_id]
40
+ x_min, y_min, x_max, y_max, _, _ = box
41
+ draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
42
+ return image
43
+
44
+ def detect_and_draw_boxes(self, image):
45
+ # Detect drones and draw bounding boxes
46
+ outputs = self.predict(image)
47
+ # Process outputs as needed
48
+ # Example: assuming the first output is bounding box coordinates
49
+ boxes = outputs[0][0] # Assuming batch size is 1
50
+ image_with_boxes = self.draw_boxes(image.copy(), boxes)
51
+ return image_with_boxes
52
+
53
+ # Example usage:
54
+ # model = YOLOv8TFLiteForDroneDetection.from_pretrained("path_to_tflite_model")
55
+ # image = Image.open("drone_image.jpg")
56
+ # image_with_boxes = model.detect_and_draw_boxes(image)
57
+ # image_with_boxes.show()
58
+
59
+