|
import io |
|
from ast import mod |
|
import gradio as gr |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as mpatches |
|
import torchvision.transforms as transforms |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from ellipse_rcnn import EllipseRCNN |
|
|
|
|
|
|
|
# Download the pretrained checkpoint from the Hugging Face Hub and deserialize it.
# map_location="cpu" keeps loading working on hosts without a GPU even when the
# checkpoint was saved from CUDA tensors; inference in this app runs on the CPU.
load_state_dict = torch.load(
    hf_hub_download("Filipstrozik/sat-tree-detection-v0", "model.pth"),
    map_location="cpu",
    weights_only=True,
)

# Build the detector, restore its weights and put it in evaluation mode, since
# this module only ever runs inference.
model = EllipseRCNN()
model.load_state_dict(load_state_dict)
model.eval()
|
|
|
|
|
def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the center of an ellipse described by a conic matrix.

    The center ``c`` is obtained as ``pinv(A) @ (-b)``, where ``A`` is the
    upper-left 2x2 block of the conic matrix and ``b`` holds the first two
    entries of its last column. The pseudo-inverse is used instead of a plain
    inverse so that a singular ``A`` does not raise.

    Args:
        conic_matrix: Tensor whose last two dimensions form the 3x3 conic
            matrix; any number of leading batch dimensions is accepted.

    Returns:
        A ``(x, y)`` tuple of tensors, each shaped like the batch dimensions
        of ``conic_matrix`` (0-d for a single unbatched matrix).
    """
    quadratic = conic_matrix[..., :2, :2]
    quadratic_pinv = torch.linalg.pinv(
        quadratic, rcond=torch.finfo(quadratic.dtype).eps
    )

    # Keep a trailing axis so the product below is a matrix-vector solve.
    linear = -conic_matrix[..., :2, 2][..., None]

    # Drop only that trailing axis. A bare ``squeeze()`` would also remove
    # batch axes of length one, so a batch holding a single ellipse would come
    # back as 0-d tensors instead of shape ``(1,)``.
    centers = torch.matmul(quadratic_pinv, linear).squeeze(-1)

    return centers[..., 0], centers[..., 1]
|
|
|
|
|
def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the semi-major and semi-minor axes of an ellipse conic matrix.

    The eigenvalues of the upper-left 2x2 block are divided by
    ``-det(C) / det(A)`` and each axis is ``sqrt(1 / eigenvalue)``.
    ``torch.linalg.eigvalsh`` yields eigenvalues in ascending order, so the
    first returned axis is the longer one.

    Args:
        conic_matrix: Tensor whose last two dimensions form the 3x3 conic
            matrix; leading batch dimensions are preserved.

    Returns:
        A ``(semi_major, semi_minor)`` tuple of tensors.
    """
    quadratic = conic_matrix[..., :2, :2]
    scale = -torch.det(conic_matrix) / torch.det(quadratic)
    eigenvalues = torch.linalg.eigvalsh(quadratic) / scale[..., None]
    semi_axes = torch.sqrt(1 / eigenvalues)
    semi_major, semi_minor = semi_axes[..., 0], semi_axes[..., 1]
    return semi_major, semi_minor
|
|
|
|
|
def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
    """Return the ellipse rotation angle in radians w.r.t. the x-axis.

    Computed as ``-atan2(2 * C[1, 0], C[1, 1] - C[0, 0]) / 2`` over the last
    two dimensions of ``conic_matrix``; leading batch dimensions are kept.
    """
    cross_term = 2 * conic_matrix[..., 1, 0]
    diagonal_gap = conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0]
    doubled_angle = torch.atan2(cross_term, diagonal_gap)
    return -doubled_angle / 2
|
|
|
|
|
def get_ellipse_params_from_matrices(ellipse_matrices):
    """Convert a batch of conic matrices into ``(a, b, cx, cy, theta)`` rows.

    Args:
        ellipse_matrices: Tensor of conic matrices indexed along dimension 0.

    Returns:
        ``None`` when the batch is empty; otherwise a tensor of shape
        ``(N, 5)`` whose columns are the two axes from ``ellipse_axes``, the
        center from ``conic_center`` and the angle from ``ellipse_angle``.
    """
    if ellipse_matrices.shape[0] == 0:
        return None

    first_axis, second_axis = ellipse_axes(ellipse_matrices)
    center_x, center_y = conic_center(ellipse_matrices)
    angle = ellipse_angle(ellipse_matrices)

    # Flatten every component to 1-D so they can be stacked column-wise.
    columns = [
        component.view(-1)
        for component in (first_axis, second_axis, center_x, center_y, angle)
    ]
    return torch.stack(columns, dim=1).reshape(-1, 5)
|
|
|
|
|
def plot_ellipses(
    ellipse_params: torch.Tensor,
    image: torch.Tensor,
    plot_centers: bool = False,
    rim_color: str = "r",
    alpha: float = 0.25,
) -> None:
    """Draw filled ellipses on the current matplotlib axes, then show ``image``.

    Args:
        ellipse_params: Tensor of rows ``(a, b, cx, cy, theta)`` with ``theta``
            in radians, or ``None`` when there is nothing to draw (the function
            then returns without touching the figure).
        image: Image passed to ``plt.imshow`` after the ellipses are added.
        plot_centers: When true, also mark every ellipse center with a dot.
        rim_color: Matplotlib color used for the ellipses and center dots.
        alpha: Opacity of the filled ellipses.
    """
    if ellipse_params is None:
        return

    a, b, cx, cy, theta = ellipse_params.unbind(-1)

    # NOTE(review): only the centers are scaled by 4, and matplotlib interprets
    # width/height as full axis lengths while ``ellipse_axes`` documents
    # semi-axes. The sizes are passed through unchanged here; confirm the
    # coordinate frame the model predicts in before adjusting them.
    centers_x = (cx * 4).tolist()
    centers_y = (cy * 4).tolist()
    widths = a.tolist()
    heights = b.tolist()

    # matplotlib's Ellipse expects its rotation in degrees, whereas the
    # parameters carry radians, so convert before drawing.
    angles_deg = torch.rad2deg(theta).tolist()

    axes = plt.gca()
    for x, y, width, height, angle in zip(
        centers_x, centers_y, widths, heights, angles_deg
    ):
        axes.add_patch(
            mpatches.Ellipse(
                (x, y),
                width=width,
                height=height,
                angle=angle,
                fill=True,
                alpha=alpha,
                color=rim_color,
            )
        )

    if plot_centers:
        # One scatter call for all centers instead of one per ellipse.
        plt.scatter(centers_x, centers_y, c=rim_color, s=10)

    plt.imshow(image)
|
|
|
|
|
|
|
def invert_normalization(image, mean, std):
    """Undo a per-channel normalization and clamp the result to ``[0, 1]``.

    Computes ``image * std + mean`` channel by channel. The input tensor is
    left untouched; a new tensor is returned.

    Args:
        image: Float tensor shaped ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: Sequence of ``C`` per-channel means used for normalization.
        std: Sequence of ``C`` per-channel standard deviations.

    Returns:
        The de-normalized tensor, same shape as ``image``, clamped to [0, 1].
    """
    # Reshape the statistics to (C, 1, 1) so they broadcast over the channel
    # axis of both unbatched and batched inputs. Iterating the tensor with
    # ``zip`` would pair the batch axis of a 4-D input with the first
    # mean/std only, applying one channel's statistics to every channel.
    mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return torch.clamp(image * std_t + mean_t, 0, 1)
|
|
|
|
|
def process_image(image):
    """Turn an input image into a normalized, batched tensor for the model.

    The image is resized to 1024x1024, converted to a float tensor and
    normalized with the per-channel mean/std listed below; a leading batch
    axis of length one is then added.

    Args:
        image: Image exposing ``.size`` and accepted by
            ``transforms.PILToTensor`` (a PIL image).

    Returns:
        A ``(image_tensor, original_size)`` tuple, where ``original_size`` is
        ``image.size`` as read before any resizing.
    """
    original_size = image.size

    steps = [
        transforms.Resize((1024, 1024)),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)

    batched = pipeline(image).unsqueeze(0)
    return batched, original_size
|
|
|
|
|
def generate_prediction(image, rpn_nms_thresh, score_thresh, nms_thresh):
    """Run the detector on ``image`` and return a rendering of the detections.

    Args:
        image: PIL image supplied by the UI, or ``None`` when nothing has
            been uploaded.
        rpn_nms_thresh: Value assigned to ``model.rpn.nms_thresh``.
        score_thresh: Value assigned to ``model.roi_heads.score_thresh``.
        nms_thresh: Value assigned to ``model.roi_heads.nms_thresh``.

    Returns:
        A PIL image containing the input picture with the predicted ellipses
        drawn on top.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before running detection.")

    image_tensor, original_size = process_image(image)
    image_tensor = image_tensor.to("cpu")

    # NOTE: the thresholds are written onto the shared module-level model, so
    # they persist across calls and concurrent requests can affect each other.
    model.rpn.nms_thresh = rpn_nms_thresh
    model.roi_heads.score_thresh = score_thresh
    model.roi_heads.nms_thresh = nms_thresh

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # Drop the batch axis before de-normalizing so that each of the three
    # channels is paired with its own mean/std entry.
    inverted_image = (
        invert_normalization(image_tensor.squeeze(0), mean, std)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )

    figure = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(inverted_image)
        plot_ellipses(
            get_ellipse_params_from_matrices(prediction["ellipse_matrices"]),
            inverted_image,
            plot_centers=True,
            rim_color="red",
            alpha=0.25,
        )
        red_patch = mpatches.Patch(color="red", label="Predicted")
        plt.legend(handles=[red_patch], loc="upper right")

        # The picture was resized to a square for the model. PIL reports size
        # as (width, height) and set_aspect takes the height-to-width ratio of
        # a data unit, so height / width restores the original proportions.
        width, height = original_size
        plt.gca().set_aspect(height / width)
        plt.axis("off")
        plt.tight_layout()

        with io.BytesIO() as buf:
            figure.savefig(buf, format="png")
            buf.seek(0)
            with Image.open(buf) as rendered:
                # copy() loads the pixels so the result outlives the buffer.
                output_image = rendered.copy()
    finally:
        # Always release the figure; pyplot otherwise keeps every figure
        # alive and memory grows with each request.
        plt.close(figure)

    return output_image
|
|
|
|
|
|
|
# Gradio UI: an input/output image pair, three threshold sliders initialised
# from the loaded model, a button that runs the detection, and example images.
with gr.Blocks() as demo:
    gr.Markdown("## Tree Detection from Satellite Images")
    gr.Markdown("Upload an image and see the detected trees with ellipses.")

    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        image_output = gr.Image(label="Detected Trees")

    with gr.Row():
        rpn_nms_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=model.rpn.nms_thresh,
            label="RPN NMS Threshold",
        )
        score_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=model.roi_heads.score_thresh,
            label="ROI Heads Score Threshold",
        )
        nms_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=model.roi_heads.nms_thresh,
            label="ROI Heads NMS Threshold",
        )

    submit_button = gr.Button("Detect Trees")
    submit_button.click(
        fn=generate_prediction,
        inputs=[image_input, rpn_nms_slider, score_thresh_slider, nms_thresh_slider],
        outputs=image_output,
    )

    # Sample inputs shown below the button; each entry is one row of inputs.
    examples = [
        ["examples/image1.jpg"],
        ["examples/image2.jpg"],
        ["examples/image3.jpg"],
    ]
    gr.Examples(examples=examples, inputs=image_input, outputs=image_output)

# Start the web server as soon as the module is executed.
demo.launch()
|
|