Filipstrozik
Enhance Gradio interface instructions for better user guidance on image uploads and parameter adjustments
9a734e7
import io | |
from ast import mod | |
import gradio as gr | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as mpatches | |
import torchvision.transforms as transforms | |
import torch | |
from huggingface_hub import hf_hub_download | |
from ellipse_rcnn import EllipseRCNN | |
# Download the trained weights (model.pth) from the Hugging Face Hub repository
# Filipstrozik/sat-tree-detection-v0 and load them into a fresh EllipseRCNN.
load_state_dict = torch.load(
    hf_hub_download("Filipstrozik/sat-tree-detection-v0", "model.pth"),
    # Inference in this file runs on the CPU, so remap tensors that may have
    # been saved from another device instead of failing on a CPU-only host.
    map_location="cpu",
    # Only unpickle tensors/primitive containers, not arbitrary objects.
    weights_only=True,
)
model = EllipseRCNN()
model.load_state_dict(load_state_dict)
# Switch to evaluation mode; the model is only used for prediction below.
model.eval()
def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the center of an ellipse in 2D cartesian coordinates.

    Args:
        conic_matrix: Floating-point conic matrix (or batch of matrices) whose
            last two dimensions are indexed as a 3x3 matrix.

    Returns:
        A ``(cx, cy)`` tuple. Each tensor keeps every leading (batch)
        dimension of ``conic_matrix``, including dimensions of size 1.
    """
    # Top-left 2x2 block: the quadratic part of the conic.
    quadratic = conic_matrix[..., :2, :2]

    # A pseudoinverse (with a machine-epsilon cutoff on the singular values)
    # is used instead of a plain inverse so near-singular blocks do not raise.
    quadratic_pinv = torch.linalg.pinv(
        quadratic, rcond=torch.finfo(quadratic.dtype).eps
    )

    # First two entries of the last column, negated and kept as a column
    # vector so it can be multiplied by the 2x2 pseudoinverse.
    linear = -conic_matrix[..., :2, 2][..., None]

    # Drop only the trailing column dimension. A bare squeeze() would also
    # remove a batch dimension of size 1 and collapse a single-ellipse batch
    # to 0-d results.
    centers = torch.matmul(quadratic_pinv, linear).squeeze(-1)
    return centers[..., 0], centers[..., 1]
def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the two semi-axes of an ellipse given as a conic matrix.

    The eigenvalues of the top-left 2x2 block are divided by
    ``-det(conic) / det(block)`` and each axis is ``sqrt(1 / value)``.

    Returns:
        Two tensors, one per eigenvalue in the order ``eigvalsh`` yields them
        (ascending). NOTE(review): this is semi-major first only when the
        scale factor is positive - confirm the sign convention of the inputs.
    """
    quadratic = conic_matrix[..., :2, :2]
    eigenvalues = torch.linalg.eigvalsh(quadratic)

    # One scale factor per conic; the trailing axis lets it broadcast over
    # the pair of eigenvalues.
    scale = -torch.det(conic_matrix) / torch.det(quadratic)
    normalized = eigenvalues / scale[..., None]

    first_axis, second_axis = torch.sqrt(1 / normalized).unbind(-1)
    return first_axis, second_axis
def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
    """Return the ellipse angle in radians w.r.t. the x-axis.

    Computed as ``-atan2(2 * C[1, 0], C[1, 1] - C[0, 0]) / 2`` over the last
    two dimensions of ``conic_matrix``; leading dimensions are preserved.
    """
    cross_term = 2 * conic_matrix[..., 1, 0]
    diagonal_gap = conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0]
    return -torch.atan2(cross_term, diagonal_gap) / 2
def get_ellipse_params_from_matrices(ellipse_matrices):
    """Convert conic matrices into rows of ``(a, b, cx, cy, theta)``.

    Returns ``None`` when the first dimension of ``ellipse_matrices`` is
    empty; otherwise a tensor with five columns, one row per ellipse.
    """
    if not ellipse_matrices.shape[0]:
        return None

    semi_axes = ellipse_axes(ellipse_matrices)
    center = conic_center(ellipse_matrices)
    angle = ellipse_angle(ellipse_matrices)

    # Flatten every component to 1-D so they can be stacked as columns.
    columns = [part.view(-1) for part in (*semi_axes, *center, angle)]
    return torch.stack(columns, dim=1).reshape(-1, 5)
def plot_ellipses(
    ellipse_params: torch.Tensor,
    image: torch.Tensor,
    plot_centers: bool = False,
    rim_color: str = "r",
    alpha: float = 0.25,
) -> None:
    """Draw filled ellipses on the current matplotlib axes, then show ``image``.

    Args:
        ellipse_params: Tensor whose last dimension unpacks into
            ``(a, b, cx, cy, theta)``, with ``theta`` in radians as produced
            by ``ellipse_angle``. ``None`` means there is nothing to draw and
            the function returns without touching the figure.
        image: Image handed to ``plt.imshow`` after the patches are added.
        plot_centers: Also mark each ellipse center with a dot.
        rim_color: Matplotlib color used for patches and center dots.
        alpha: Patch opacity.
    """
    if ellipse_params is None:
        return

    a, b, cx, cy, theta = ellipse_params.unbind(-1)

    # Centers are scaled by 4 before drawing.
    # NOTE(review): the axes are not scaled, and a/b are passed as the full
    # width/height although ellipse_axes documents them as semi-axes -
    # confirm the intended coordinate scale before changing either.
    cx = cx * 4
    cy = cy * 4

    # matplotlib.patches.Ellipse expects its rotation in degrees, while theta
    # is in radians; passing it through unchanged left every ellipse almost
    # unrotated.
    theta_degrees = torch.rad2deg(theta)

    axes = plt.gca()
    # Convert to plain Python floats once instead of indexing 0-d tensors.
    rows = zip(
        a.tolist(), b.tolist(), cx.tolist(), cy.tolist(), theta_degrees.tolist()
    )
    for width, height, center_x, center_y, angle in rows:
        ellipse = mpatches.Ellipse(
            (center_x, center_y),
            width=width,
            height=height,
            angle=angle,
            fill=True,
            alpha=alpha,
            color=rim_color,
        )
        axes.add_patch(ellipse)
        if plot_centers:
            plt.scatter(center_x, center_y, c=rim_color, s=10)

    plt.imshow(image)
# Inverse of the per-channel normalization applied during preprocessing.
def invert_normalization(image, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalization.

    Args:
        image: Tensor laid out as ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: Per-channel means, one entry per channel.
        std: Per-channel standard deviations, one entry per channel.

    Returns:
        A new tensor ``image * std + mean`` clamped to ``[0, 1]``. The input
        tensor is left unmodified.
    """
    # Integer inputs cannot hold the fractional statistics, so fall back to
    # float32 for them and let type promotion produce a float result.
    stats_dtype = image.dtype if image.is_floating_point() else torch.float32

    # Shape the statistics as (C, 1, 1) so they broadcast over the channel
    # dimension of both unbatched and batched inputs. Zipping over the first
    # dimension, as before, applied only the first channel's statistics to
    # every channel of a batched tensor.
    mean_t = torch.as_tensor(mean, dtype=stats_dtype, device=image.device)
    std_t = torch.as_tensor(std, dtype=stats_dtype, device=image.device)
    restored = image * std_t.view(-1, 1, 1) + mean_t.view(-1, 1, 1)
    return torch.clamp(restored, 0, 1)
def process_image(image):
    """Turn an input image into a normalized, batched model input.

    The image is resized to 1024x1024, converted to a float tensor and
    normalized with the fixed per-channel mean/std listed below.

    Returns:
        ``(image_tensor, original_size)`` where ``image_tensor`` has a leading
        batch dimension of 1 and ``original_size`` is ``image.size`` as read
        before resizing.
    """
    size_before_resize = image.size

    steps = [
        transforms.Resize((1024, 1024)),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)

    # Prepend the batch dimension the model expects.
    batched = pipeline(image).unsqueeze(0)
    return batched, size_before_resize
def generate_prediction(image, rpn_nms_thresh, score_thresh, nms_thresh):
    """Run tree detection on ``image`` and return the rendered overlay.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was
            uploaded.
        rpn_nms_thresh: Value assigned to ``model.rpn.nms_thresh``.
        score_thresh: Value assigned to ``model.roi_heads.score_thresh``.
        nms_thresh: Value assigned to ``model.roi_heads.nms_thresh``.

    Returns:
        A PIL image of the input with the predicted ellipses drawn on top.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before running detection.")

    # Preprocess image
    image_tensor, original_size = process_image(image)
    image_tensor = image_tensor.to("cpu")

    # Apply the thresholds chosen in the UI to the shared module-level model.
    model.rpn.nms_thresh = rpn_nms_thresh
    model.roi_heads.score_thresh = score_thresh
    model.roi_heads.nms_thresh = nms_thresh

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Invert normalization for display. The batch dimension is dropped first
    # so the per-channel statistics line up with the three channels rather
    # than with the single batch entry.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    inverted_image = (
        invert_normalization(image_tensor[0], mean, std)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )

    # Plot results with ellipses
    figure = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(inverted_image)
        plot_ellipses(
            get_ellipse_params_from_matrices(prediction["ellipse_matrices"]),
            inverted_image,
            plot_centers=True,
            rim_color="red",
            alpha=0.25,
        )
        red_patch = mpatches.Patch(color="red", label="Predicted")
        plt.legend(handles=[red_patch], loc="upper right")

        # The square model input is stretched back to the original
        # proportions. PIL's size is (width, height) and set_aspect takes the
        # displayed height of one data unit relative to its width, so the
        # ratio is height / width.
        original_width, original_height = original_size
        plt.gca().set_aspect(original_height / original_width)
        plt.axis("off")
        plt.tight_layout()

        # Render to an in-memory PNG and hand back a detached copy so the
        # buffer can be released.
        with io.BytesIO() as buf:
            figure.savefig(buf, format="png")
            buf.seek(0)
            with Image.open(buf) as rendered:
                output_image = rendered.copy()
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leak one figure.
        plt.close(figure)

    return output_image
# Define Gradio interface: an input/output image pair, three threshold sliders
# initialised from the loaded model, and a button that runs the detection.
with gr.Blocks() as demo:
    gr.Markdown("## Tree Detection from Satellite Images")
    gr.Markdown(
        "Upload an image and see the detected trees with ellipses. For better predictions, upload a high-resolution image of orthophotomap with zoom level 18."
    )
    gr.Markdown(
        "Try different values for RPN NMS Threshold, ROI Heads Score Threshold, and ROI Heads NMS Threshold to see how they affect the predictions."
    )
    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        image_output = gr.Image(label="Detected Trees")
    # Sample images offered below the controls; paths are relative to the app.
    examples = [
        ["examples/image1.jpg"],
        ["examples/image2.jpg"],
        ["examples/image3.jpg"],
    ]
    with gr.Row():
        # Slider defaults mirror the thresholds the model was constructed with.
        rpn_nms_slider = gr.Slider(
            0.0, 1.0, value=model.rpn.nms_thresh, label="RPN NMS Threshold"
        )
        score_thresh_slider = gr.Slider(
            0.0,
            1.0,
            value=model.roi_heads.score_thresh,
            label="ROI Heads Score Threshold",
        )
        nms_thresh_slider = gr.Slider(
            0.0, 1.0, value=model.roi_heads.nms_thresh, label="ROI Heads NMS Threshold"
        )
    submit_button = gr.Button("Detect Trees")
    # Input order must match generate_prediction's parameter order.
    submit_button.click(
        fn=generate_prediction,
        inputs=[image_input, rpn_nms_slider, score_thresh_slider, nms_thresh_slider],
        outputs=image_output,
    )
    gr.Examples(examples=examples, inputs=image_input, outputs=image_output)
demo.launch()