Last commit not found
import os | |
import rasterio | |
import geopandas as gpd | |
from shapely.geometry import box | |
from rasterio.mask import mask | |
from PIL import Image | |
import numpy as np | |
import warnings | |
from rasterio.errors import NodataShadowWarning | |
import sys | |
# Silence rasterio's NodataShadowWarning. This runs at import time and the
# warnings filter list is process-global, so the suppression also applies to
# any other code running in the same interpreter once this module is imported.
warnings.simplefilter("ignore", category=NodataShadowWarning)
def _write_status(message):
    """Overwrite the current terminal line with ``message`` (no trailing newline)."""
    sys.stdout.write('\r' + message)
    sys.stdout.flush()


def _crop_to_data(image, fill_value):
    """Trim the outer border of ``image`` that contains only ``fill_value``.

    ``image`` is an (H, W, C) array (bands last). A pixel counts as data when
    ANY band differs from ``fill_value``; a NaN fill value is handled
    explicitly because ``NaN != NaN`` is always true. Only the border is
    removed: fill rows/columns inside the bounding box of the data are kept,
    so crops of non-convex polygons are not distorted.

    Returns the cropped array, or ``None`` when no pixel holds data.
    """
    if isinstance(fill_value, float) and np.isnan(fill_value):
        is_data = ~np.all(np.isnan(image), axis=2)
    else:
        is_data = np.any(image != fill_value, axis=2)
    rows = np.flatnonzero(is_data.any(axis=1))
    cols = np.flatnonzero(is_data.any(axis=0))
    if rows.size == 0:
        return None
    return image[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


def cut_trees(output_dir, geojson_path, tif_path):
    """Crop one PNG per polygon of ``geojson_path`` out of the raster ``tif_path``.

    The polygons are reprojected to the raster's CRS. Each one that intersects
    the raster bounds is masked out with ``rasterio.mask.mask(..., crop=True)``,
    trimmed of its fill-value border and saved as ``tree_<id>.png`` in
    ``output_dir``. ``<id>`` is the polygon's ``id`` attribute, or its row
    index label when the layer has no ``id`` column. Polygons that miss the
    raster or yield no data are skipped. Progress is written to stdout on a
    single, rewritten line.

    Args:
        output_dir: destination folder for the PNGs; created if missing.
        geojson_path: vector file readable by ``geopandas.read_file``.
        tif_path: raster file readable by ``rasterio.open``.
    """
    os.makedirs(output_dir, exist_ok=True)
    gdf = gpd.read_file(geojson_path)
    # Clear the terminal so the single-line progress display starts clean.
    os.system('cls' if os.name == 'nt' else 'clear')

    with rasterio.open(tif_path) as src:
        tif_bounds = box(*src.bounds)
        gdf = gdf.to_crs(src.crs)
        # mask() fills pixels outside the polygon with the dataset's nodata
        # value, or with 0 when the dataset defines none.
        fill_value = src.nodata if src.nodata is not None else 0
        has_id = 'id' in gdf.columns

        n_polygons = len(gdf)
        # Report roughly every 10 %; max(..., 1) avoids a modulo by zero
        # for layers with fewer than 10 polygons.
        report_every = max(n_polygons // 10, 1)
        image_counter = 0
        # Progress uses the row position, not the index label, so that
        # non-integer or non-contiguous indexes still work.
        for pos, (idx, row) in enumerate(gdf.iterrows()):
            progress = f"{round(pos / n_polygons * 100)} % complete --> {pos}/{n_polygons}"
            if pos % report_every == 0:
                _write_status(progress)

            geom = row['geometry']
            if not geom.intersects(tif_bounds):
                _write_status(f"{progress} | Polygon {idx} is outside the image bounds and will be skipped.")
                continue

            try:
                out_image, _ = mask(src, [geom], crop=True)
            except ValueError:
                # rasterio raises ValueError when the shape covers no raster
                # pixel (e.g. it only touches the raster edge).
                _write_status(f"{progress} | Polygon {idx} does not overlap the raster and will be skipped.")
                continue

            # (bands, H, W) -> (H, W, bands), the layout PIL expects.
            out_image = out_image.transpose(1, 2, 0)
            if out_image.size == 0:
                _write_status(f"{progress} | Polygon {idx} resulted in an empty image and will be skipped.")
                continue

            out_image = _crop_to_data(out_image, fill_value)
            if out_image is None:
                _write_status(f"{progress} | Polygon {idx} resulted in an invalid image area and will be skipped.")
                continue

            # PIL cannot build an image from an (H, W, 1) array; drop the band axis.
            if out_image.shape[2] == 1:
                out_image = out_image[:, :, 0]

            name = row['id'] if has_id else idx
            output_path = os.path.join(output_dir, f'tree_{name}.png')
            # NOTE(review): astype(np.uint8) wraps/truncates values outside
            # 0-255 (e.g. uint16 or float rasters) -- confirm inputs are 8-bit.
            Image.fromarray(out_image.astype(np.uint8)).save(output_path)
            image_counter += 1

    print(f'\n {image_counter}/{n_polygons} Tree images have been successfully saved in the "{output_dir}" folder.')
def resize_images(input_folder, output_folder, target_size):
    """Letterbox every PNG in ``input_folder`` onto a black ``target_size`` canvas.

    Each image is shrunk to fit within ``target_size`` while keeping its
    aspect ratio. ``Image.thumbnail`` only ever downsizes, so small images
    keep their size. The result is centred on a black canvas, flattened to
    RGB (transparent pixels become black) and saved under the same file name
    in ``output_folder``.

    Args:
        input_folder: folder scanned (non-recursively) for ``.png`` files;
            the extension match is case-insensitive.
        output_folder: destination folder; created if missing.
        target_size: ``(width, height)`` of the output images.
    """
    os.makedirs(output_folder, exist_ok=True)
    target_size = tuple(target_size)
    counter = 0
    # sorted() makes the processing order (and the progress output) deterministic.
    for filename in sorted(os.listdir(input_folder)):
        if not filename.lower().endswith('.png'):
            continue
        with Image.open(os.path.join(input_folder, filename)) as img:
            img.thumbnail(target_size, Image.LANCZOS)
            # paste() can only use an image with an alpha band as its own
            # mask: an RGB image raises "bad transparency mask", and an L
            # image would turn brightness into transparency. convert()
            # returns a loaded copy that stays valid after the file is closed.
            rgba = img.convert("RGBA")
        paste_pos = ((target_size[0] - rgba.size[0]) // 2,
                     (target_size[1] - rgba.size[1]) // 2)
        canvas = Image.new("RGBA", target_size, (0, 0, 0, 255))
        canvas.paste(rgba, paste_pos, rgba)
        # Flatten onto the opaque black background before saving.
        canvas.convert("RGB").save(os.path.join(output_folder, filename))
        counter += 1
        if counter % 50 == 0:
            print(f"Processed {counter} images", end='\r')
    print(f"Processed a total of {counter} images.")
# Public entry point of this module: this is the function to import.
def generate_tree_images(geojson_path, tif_path, target_size=(224, 224),
                         folder_cut_trees="detected_trees",
                         folder_finished_images="tree_images"):
    """Cut one image per tree polygon out of a raster, then resize them.

    Runs ``cut_trees`` followed by ``resize_images``.

    Args:
        geojson_path: vector file with the tree polygons; its ``id`` column
            names the output files.
        tif_path: raster (e.g. GeoTIFF) containing the trees.
        target_size: ``(width, height)`` of the final images.
        folder_cut_trees: folder for the raw per-tree crops
            (default ``"detected_trees"``).
        folder_finished_images: folder for the resized images, ready to use
            for species recognition (default ``"tree_images"``).

    Returns:
        None. Both folders are created if they do not exist; relative paths
        are resolved against the current working directory.
    """
    cut_trees(output_dir=folder_cut_trees, geojson_path=geojson_path, tif_path=tif_path)
    resize_images(input_folder=folder_cut_trees,
                  output_folder=folder_finished_images,
                  target_size=target_size)