# Weed-Detection-App / detection.py
# Uploaded by Timmyafolami ("Upload 11 files", commit 15c7eff, verified).
import os
import cv2
import zipfile
import shapefile
import numpy as np
from shapely.geometry import Polygon
from io import BytesIO
from PIL import Image
import rasterio
from rasterio.windows import Window
from ultralytics import YOLO
from db_bucket import upload_file_to_bucket # Import the bucket upload function
# ---------------------------------------------------------------------------
# Paths and configuration
# ---------------------------------------------------------------------------
# Directory that receives the slice images annotated with bounding boxes.
path_to_store_bounding_boxes = 'detect/'
# Base name (no extension) shared by the .shp/.shx/.dbf shapefile outputs.
path_to_save_shapefile = 'weed_detections'
# Edge length, in pixels, of the square tiles the GeoTIFF is cut into.
slice_size = 3000

# YOLO model, loaded once at import time (update the path to your weights).
model = YOLO('new_yolov8_best.pt')
# Human-readable labels, indexed by the model's integer class id.
class_names = [
    "citrus area",
    "trees",
    "weeds",
    "weeds and trees",
]
def initialize_directories():
    """Create the working directories used by the detection pipeline.

    Ensures the annotated-output directory (``path_to_store_bounding_boxes``)
    and the ``slices`` scratch directory exist. Directories that already
    exist are left untouched.
    """
    for directory in (path_to_store_bounding_boxes, "slices"):
        os.makedirs(directory, exist_ok=True)
async def slice_geotiff(file_path, slice_size=3000):
    """Cut a GeoTIFF into square PNG tiles, keeping each tile's geotransform.

    Tiles are written to ``slices/slice_{row}_{col}.png``, where ``row`` and
    ``col`` are the pixel offsets of the tile's top-left corner in the
    source raster. Tiles on the right and bottom edges are clipped to the
    image bounds, so they may be smaller than ``slice_size``.

    Args:
        file_path: Path to a raster readable by ``rasterio.open``.
        slice_size: Tile edge length in pixels; must be positive.

    Returns:
        A list of ``(png_path, window_transform)`` tuples in row-major order.
        ``window_transform`` is the affine transform of that tile's window,
        derived from the dataset transform.

    Raises:
        ValueError: If ``slice_size`` is not positive.

    Note:
        Declared ``async`` for the caller's convenience; nothing here is
        awaited, so all file I/O blocks the event loop.
    """
    if slice_size <= 0:
        raise ValueError(f"slice_size must be positive, got {slice_size}")

    # cleanup() deletes this directory; don't depend on
    # initialize_directories() having been called before every run.
    os.makedirs("slices", exist_ok=True)

    slices = []
    with rasterio.open(file_path) as dataset:
        img_width = dataset.width
        img_height = dataset.height
        transform = dataset.transform
        for i in range(0, img_height, slice_size):
            for j in range(0, img_width, slice_size):
                # Clip edge tiles explicitly rather than relying on rasterio
                # to crop an out-of-range window. Only the size changes; the
                # offsets, and therefore the tile transform, are unaffected.
                window = Window(
                    j,
                    i,
                    min(slice_size, img_width - j),
                    min(slice_size, img_height - i),
                )
                transform_window = rasterio.windows.transform(window, transform)
                slice_data = dataset.read(window=window)
                # dataset.read returns (bands, rows, cols); PIL wants HWC.
                slice_img = Image.fromarray(slice_data.transpose(1, 2, 0))
                slice_filename = f"slices/slice_{i}_{j}.png"
                slice_img.save(slice_filename)
                slices.append((slice_filename, transform_window))
    return slices
def create_shapefile_with_latlon(bboxes, shapefile_path="weed_detections.shp"):
    """Write detection boxes as polygons to a shapefile.

    Args:
        bboxes: Iterable of ``(x1, y1, x2, y2, transform)`` tuples. The first
            four values are the column/row pixel coordinates of the box edges
            inside a slice. ``transform`` is that slice's affine geotransform,
            used to map those pixel coordinates to map coordinates.
        shapefile_path: Output path handed to ``shapefile.Writer``.

    Each polygon gets one record whose character field ``id`` is
    ``weed_<index>``. This function writes no .prj file, so the coordinate
    reference system is not recorded alongside the geometry.
    """
    # Declare the geometry type up front so that a run with zero detections
    # still yields an (empty) polygon shapefile rather than an untyped one.
    with shapefile.Writer(shapefile_path, shapeType=shapefile.POLYGON) as w:
        w.field('id', 'C')
        for idx, (x1, y1, x2, y2, transform) in enumerate(bboxes):
            # Box coordinates are pixel *edges*, so map them with
            # offset='ul'. offset='center' would shift every polygon by half
            # a pixel to the right and down.
            top_left = rasterio.transform.xy(transform, y1, x1, offset='ul')
            top_right = rasterio.transform.xy(transform, y1, x2, offset='ul')
            bottom_right = rasterio.transform.xy(transform, y2, x2, offset='ul')
            bottom_left = rasterio.transform.xy(transform, y2, x1, offset='ul')
            # Explicitly closed ring: the first vertex is repeated at the end.
            ring = [top_left, top_right, bottom_right, bottom_left, top_left]
            w.poly([ring])
            w.record(f'weed_{idx}')
async def detect_weeds_in_slices(slices, *, conf_threshold=0.2, iou_threshold=0.4, imgsz=640):
    """Run the YOLO model on each slice, save annotated copies and a shapefile.

    For every ``(slice_filename, transform)`` pair, as returned by
    ``slice_geotiff``, this function:

    - runs ``model.predict`` on the PNG;
    - draws every box whose class name is ``"weeds"`` in (255, 0, 255)
      (magenta in RGB) on a copy of the slice;
    - writes that copy into ``path_to_store_bounding_boxes``;
    - collects the box together with the slice transform.

    All collected boxes are then written to ``<path_to_save_shapefile>.shp``.

    Args:
        slices: Iterable of ``(png_path, affine_transform)`` tuples.
        conf_threshold: Minimum confidence passed to ``model.predict``.
        iou_threshold: NMS IoU threshold passed to ``model.predict``.
        imgsz: Inference image size passed to ``model.predict``.

    Note:
        Declared ``async`` but nothing is awaited; inference and file I/O
        block the event loop.
    """
    os.makedirs(path_to_store_bounding_boxes, exist_ok=True)

    weed_bboxes = []
    for slice_filename, transform in slices:
        # Read through a context manager so the file handle is released.
        with Image.open(slice_filename) as pil_img:
            img_array = np.array(pil_img)

        results = model.predict(
            slice_filename, imgsz=imgsz, conf=conf_threshold, iou=iou_threshold
        )[0]
        for box in results.boxes:
            label = int(box.cls[0].item())
            # Skip class ids the local label list doesn't cover; they cannot
            # be "weeds" and would otherwise raise IndexError.
            if label >= len(class_names) or class_names[label] != "weeds":
                continue
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            cv2.rectangle(img_array, (x1, y1), (x2, y2), (255, 0, 255), 3)
            weed_bboxes.append((x1, y1, x2, y2, transform))

        detected_image_path = os.path.join(
            path_to_store_bounding_boxes, os.path.basename(slice_filename)
        )
        # The array came from PIL (RGB order); cv2.imwrite expects BGR.
        cv2.imwrite(detected_image_path, cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))

    # Write to the same base path that create_zip() and cleanup() look for.
    create_shapefile_with_latlon(weed_bboxes, f"{path_to_save_shapefile}.shp")
async def create_zip():
    """Bundle the shapefile components into ``weed_detections.zip``.

    Adds whichever of ``<path_to_save_shapefile>.shp``, ``.shx`` and ``.dbf``
    exist, stored under their base names (no directory prefix). Any existing
    archive at the same path is overwritten.

    Returns:
        The path of the written archive.
    """
    zip_file_path = "weed_detections.zip"
    # Use DEFLATE so the download is compressed; the default ZIP_STORED
    # copies the bytes uncompressed.
    with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zip_file:
        for ext in ('shp', 'shx', 'dbf'):
            file_name = f"{path_to_save_shapefile}.{ext}"
            if os.path.exists(file_name):
                zip_file.write(file_name, os.path.basename(file_name))
    return zip_file_path
def _remove_file(path):
    """Delete the file at ``path``; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def cleanup():
    """Remove every artifact produced by a detection run.

    Deletes:

    - the zip archive;
    - the uploaded GeoTIFF;
    - the shapefile components;
    - the ``slices`` directory;
    - the bounding-box output directory.

    Missing files or directories are ignored, so the function can be called
    repeatedly.
    """
    import shutil

    files = ["weed_detections.zip", "uploaded_geotiff.tif"]
    files += [f"{path_to_save_shapefile}.{ext}" for ext in ('shp', 'shx', 'dbf')]
    for file_name in files:
        _remove_file(file_name)

    # Use the configured output directory instead of a hard-coded "detect",
    # so cleanup stays in sync with where annotated images are written.
    # rmtree also copes with nested entries, which os.rmdir would reject.
    for directory in ("slices", path_to_store_bounding_boxes):
        if os.path.isdir(directory):
            shutil.rmtree(directory)