import io
from ast import mod  # NOTE(review): unused; looks like an accidental auto-import

import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torchvision.transforms as transforms
import torch
from huggingface_hub import hf_hub_download
from ellipse_rcnn import EllipseRCNN

# Load the trained weights (model.pth) from the Filipstrozik/sat-tree-detection-v0
# repository on the Hugging Face Hub. weights_only=True restricts torch.load to
# tensors and plain containers rather than arbitrary pickled objects.
load_state_dict = torch.load(
    hf_hub_download("Filipstrozik/sat-tree-detection-v0", "model.pth"),
    weights_only=True,
)
model = EllipseRCNN()
model.load_state_dict(load_state_dict)
model.eval()

# Per-channel (R, G, B) statistics applied by process_image and undone by
# invert_normalization for display. Kept in one place so both stay in sync.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]


def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the (x, y) center of each conic given as a 3x3 matrix.

    Writing the conic as C = [[A, b], [b^T, c]] with A the top-left 2x2 block,
    the center solves A @ center = -b. A pseudoinverse is used so a
    (near-)singular A does not raise.

    Args:
        conic_matrix: Tensor of shape (..., 3, 3).

    Returns:
        Tuple ``(cx, cy)``, each of shape (...).
    """
    # Quadratic part of the conic.
    A = conic_matrix[..., :2, :2]
    # rcond=eps: only singular values below eps * (largest singular value) are
    # treated as zero, i.e. essentially no extra regularization.
    A_pinv = torch.linalg.pinv(A, rcond=torch.finfo(A.dtype).eps)
    # Linear part: first two entries of the last column, as a column vector.
    b = -conic_matrix[..., :2, 2][..., None]
    # squeeze(-1) drops only the trailing column dimension, so a batch of one
    # keeps its batch dimension.
    centers = torch.matmul(A_pinv, b).squeeze(-1)
    return centers[..., 0], centers[..., 1]


def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the two semi-axis lengths of each ellipse given as a 3x3 conic matrix.

    The eigenvalues of the 2x2 quadratic block are divided by
    ``-det(C) / det(A)`` and each semi-axis is ``1 / sqrt(scaled eigenvalue)``.
    ``eigvalsh`` returns eigenvalues in ascending order, so the first output
    corresponds to the smaller eigenvalue.

    Args:
        conic_matrix: Tensor of shape (..., 3, 3).

    Returns:
        Tuple of two tensors, each of shape (...).
    """
    lambdas = (
        torch.linalg.eigvalsh(conic_matrix[..., :2, :2])
        / (-torch.det(conic_matrix) / torch.det(conic_matrix[..., :2, :2]))[..., None]
    )
    axes = torch.sqrt(1 / lambdas)
    return axes[..., 0], axes[..., 1]


def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
    """Return the angle of each ellipse w.r.t. the x-axis, in radians.

    Computed as ``-atan2(2 * C[1, 0], C[1, 1] - C[0, 0]) / 2``, so the result
    lies in [-pi/2, pi/2].
    """
    return (
        -torch.atan2(
            2 * conic_matrix[..., 1, 0],
            conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0],
        )
        / 2
    )


def get_ellipse_params_from_matrices(ellipse_matrices):
    """Convert a batch of conic matrices into per-ellipse parameter rows.

    Args:
        ellipse_matrices: Tensor of conic matrices, presumably (N, 3, 3) as
            produced by the model's ``"ellipse_matrices"`` output.

    Returns:
        Tensor of shape (N, 5) with columns ``(a, b, cx, cy, theta)`` (the two
        semi-axes, the center, and the angle in radians), or ``None`` when
        there are no ellipses.
    """
    if ellipse_matrices.shape[0] == 0:
        return None
    a, b = ellipse_axes(ellipse_matrices)
    cx, cy = conic_center(ellipse_matrices)
    theta = ellipse_angle(ellipse_matrices)
    # Flatten each component to 1-D so stacking yields one row per ellipse.
    return torch.stack([t.reshape(-1) for t in (a, b, cx, cy, theta)], dim=1)


def plot_ellipses(
    ellipse_params: torch.Tensor,
    image: torch.Tensor,
    plot_centers: bool = False,
    rim_color: str = "r",
    alpha: float = 0.25,
) -> None:
    """Draw filled ellipses (and optionally their centers) on the current axes.

    Args:
        ellipse_params: (N, 5) tensor of ``(a, b, cx, cy, theta)`` rows with
            theta in radians, or ``None`` (draws nothing).
        image: Image passed to ``plt.imshow`` after the ellipses are added.
        plot_centers: Also draw a small marker at each ellipse center.
        rim_color: Matplotlib color for the ellipse fill and center markers.
        alpha: Ellipse opacity.
    """
    if ellipse_params is None:
        return
    a, b, cx, cy, theta = ellipse_params.unbind(-1)
    # Centers are scaled by 4 (presumably model coordinates are 1/4 of the
    # displayed 1024-px image -- TODO confirm).
    # NOTE(review): the factor 4 is applied to the centers only, not to a/b,
    # and a/b are semi-axes while Ellipse width/height are full diameters.
    # Confirm the intended on-screen size before changing it.
    cx = cx * 4
    cy = cy * 4
    # matplotlib's Ellipse takes its rotation in degrees; theta is in radians.
    angles = torch.rad2deg(theta)

    ax = plt.gca()
    for a_i, b_i, cx_i, cy_i, angle_i in zip(
        a.tolist(), b.tolist(), cx.tolist(), cy.tolist(), angles.tolist()
    ):
        ax.add_patch(
            mpatches.Ellipse(
                (cx_i, cy_i),
                width=a_i,
                height=b_i,
                angle=angle_i,
                fill=True,
                alpha=alpha,
                color=rim_color,
            )
        )
        if plot_centers:
            plt.scatter(cx_i, cy_i, c=rim_color, s=10)
    plt.imshow(image)


def invert_normalization(image, mean, std):
    """Undo ``transforms.Normalize`` and clamp the result to [0, 1].

    Accepts a single image (C, H, W) or a batch (N, C, H, W). The statistics
    are broadcast along the channel dimension (dim -3), and the input tensor
    is left unmodified.

    Args:
        image: Normalized image tensor with channels at dim -3.
        mean: Per-channel means used for normalization.
        std: Per-channel standard deviations used for normalization.

    Returns:
        New tensor ``clamp(image * std + mean, 0, 1)``.
    """
    # Shape (C, 1, 1) so the stats broadcast over H, W and any batch dim.
    # Iterating the tensor directly would walk the batch dim of (N, C, H, W)
    # input and apply the first channel's stats to every channel.
    mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return torch.clamp(image * std_t + mean_t, 0, 1)


def process_image(image):
    """Resize and normalize a PIL image for the model.

    Args:
        image: PIL image.

    Returns:
        Tuple ``(image_tensor, original_size)``. ``image_tensor`` is a float
        tensor of shape (1, 3, 1024, 1024). ``original_size`` is the input's
        PIL ``(width, height)``.
    """
    original_size = image.size
    transform = transforms.Compose(
        [
            transforms.Resize((1024, 1024)),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
        ]
    )
    # Force 3 channels: Normalize uses 3-element mean/std, so RGBA or
    # grayscale input would otherwise fail.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
    return image_tensor, original_size


def generate_prediction(image, rpn_nms_thresh, score_thresh, nms_thresh):
    """Run tree detection on ``image`` and render the predictions.

    Args:
        image: PIL image from the Gradio input (``None`` if nothing uploaded).
        rpn_nms_thresh: Value assigned to ``model.rpn.nms_thresh``.
        score_thresh: Value assigned to ``model.roi_heads.score_thresh``.
        nms_thresh: Value assigned to ``model.roi_heads.nms_thresh``.

    Returns:
        PIL image of the resized input with predicted ellipses overlaid,
        displayed with the input's original aspect ratio.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    image_tensor, original_size = process_image(image)
    image_tensor = image_tensor.to("cpu")

    # Apply the user-selected thresholds. They are set on the shared global
    # model, so they persist until the next request overwrites them.
    model.rpn.nms_thresh = rpn_nms_thresh
    model.roi_heads.score_thresh = score_thresh
    model.roi_heads.nms_thresh = nms_thresh

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Undo normalization so the image is shown with its original colors.
    inverted_image = (
        invert_normalization(image_tensor, NORMALIZE_MEAN, NORMALIZE_STD)
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )

    with io.BytesIO() as buf:
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(inverted_image)
            plot_ellipses(
                get_ellipse_params_from_matrices(prediction["ellipse_matrices"]),
                inverted_image,
                plot_centers=True,
                rim_color="red",
                alpha=0.25,
            )
            red_patch = mpatches.Patch(color="red", label="Predicted")
            plt.legend(handles=[red_patch], loc="upper right")
            # The image was resized to a square; set_aspect(h / w) makes each
            # y-unit h/w times as tall as an x-unit, restoring the original
            # proportions. PIL sizes are (width, height).
            plt.gca().set_aspect(original_size[1] / original_size[0])
            plt.axis("off")
            plt.tight_layout()
            plt.savefig(buf, format="png")
        finally:
            # pyplot keeps figures alive until closed. Without this, every
            # request would leak a figure in the long-running server.
            plt.close(fig)

        buf.seek(0)
        with Image.open(buf) as rendered:
            # copy() loads the pixel data so the image outlives the buffer.
            return rendered.copy()


# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Tree Detection from Satellite Images")
    gr.Markdown(
        "Upload an image and see the detected trees with ellipses. For better predictions, upload a high-resolution image of orthophotomap with zoom level 18."
    )
    gr.Markdown(
        "Try different values for RPN NMS Threshold, ROI Heads Score Threshold, and ROI Heads NMS Threshold to see how they affect the predictions."
    )

    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        image_output = gr.Image(label="Detected Trees")

    examples = [
        ["examples/image1.jpg"],
        ["examples/image2.jpg"],
        ["examples/image3.jpg"],
    ]

    # Slider defaults come from the thresholds of the freshly loaded model.
    with gr.Row():
        rpn_nms_slider = gr.Slider(
            0.0, 1.0, value=model.rpn.nms_thresh, label="RPN NMS Threshold"
        )
        score_thresh_slider = gr.Slider(
            0.0,
            1.0,
            value=model.roi_heads.score_thresh,
            label="ROI Heads Score Threshold",
        )
        nms_thresh_slider = gr.Slider(
            0.0, 1.0, value=model.roi_heads.nms_thresh, label="ROI Heads NMS Threshold"
        )

    submit_button = gr.Button("Detect Trees")
    submit_button.click(
        fn=generate_prediction,
        inputs=[image_input, rpn_nms_slider, score_thresh_slider, nms_thresh_slider],
        outputs=image_output,
    )

    gr.Examples(examples=examples, inputs=image_input, outputs=image_output)

demo.launch()