import warnings

import numpy as np
import torch
import gradio as gr

from models.utils.torch_geometry import get_perspective_transform, warp_perspective

warnings.filterwarnings("ignore")

# Pick the compute device once at import time so get_BEV_kitti can use it
# even when this module is imported rather than run as a script.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
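# Note: get_perspective_transform / warp_perspective appear to be differentiable
# PyTorch counterparts of OpenCV's getPerspectiveTransform / warpPerspective,
# with the same call pattern as kornia's functions of the same names.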
def get_BEV_kitti(front_img, fov, pitch, scale, out_size):
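    """Warp a front-view image onto the ground plane (bird's-eye view).

    Args:
        front_img: H x W x 3 RGB image (NumPy array).
        fov: field-of-view angle in degrees; the focal length is derived as
            f = (H/2) / tan(fov), i.e. fov acts as the half vertical FOV.
        pitch: camera pitch angle in degrees.
        scale: BEV canvas size as a multiple of the image width.
        out_size: side length of the square output image in pixels.

    Returns:
        An out_size x out_size x 3 uint8 BEV image.
    """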
    Hp, Wp = front_img.shape[:2]                   # front-view image height/width
    Wo, Ho = int(Wp * scale), int(Wp * scale)      # square BEV canvas, scaled from the image width
    fov = torch.tensor(fov * torch.pi / 180)       # field-of-view angle in radians
    theta = torch.tensor(pitch * torch.pi / 180)   # camera pitch angle in radians
    f = Hp / 2 / torch.tan(fov)                    # focal length in pixels: f = (Hp/2) / tan(fov)
    phi = torch.pi / 2 - fov
    delta = torch.pi / 2 + theta - phi             # angle of the bottom-edge viewing ray below the horizontal
    l = torch.sqrt(f ** 2 + (Hp / 2) ** 2)         # distance from the camera center to the bottom image edge
    h = l * torch.sin(delta)                       # vertical component of that ray
    f_ = l * torch.cos(delta)                      # horizontal (ground-plane) component of that ray
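    # Sanity check of the geometry above with hypothetical numbers (KITTI
    # frames are 375 px tall): for Hp = 375, fov = 20 deg and pitch = 0,
    # f = 187.5 / tan(20 deg) ~ 515 px and delta = fov, so the bottom-edge
    # ray drops h = l*sin(delta) below the camera over a ground distance
    # of f_ = l*cos(delta).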
    # Map the four corners of the BEV canvas back to front-view pixels.
    frame = torch.from_numpy(front_img).to(device)
    out = torch.zeros((2, 2, 2)).to(device)        # (i, j) front-view coords of each BEV corner
    # Corner grid of the BEV canvas: rows {0, Ho-1}, columns {0, Wo-1}.
    y = (torch.ones((2, 2)).to(device).T * torch.arange(0, Ho, step=Ho - 1).to(device)).T
    x = torch.ones((2, 2)).to(device) * torch.arange(0, Wo, step=Wo - 1).to(device)
    l0 = torch.ones((2, 2)).to(device) * Ho - y    # ground distance from the canvas bottom
    l1 = torch.ones((2, 2)).to(device) * f_ + l0
    f1_0 = torch.arctan(h / l1)                    # elevation angle of each ground point
    f1_1 = torch.ones((2, 2)).to(device) * (torch.pi / 2 + theta) - f1_0
    y_ = l0 * torch.sin(f1_0) / torch.sin(f1_1)    # projected height on the image plane
    j_p = torch.ones((2, 2)).to(device) * Hp - y_  # row coordinate in the front view
    i_p = torch.ones((2, 2)).to(device) * Wp / 2 \
        - (f_ + torch.sin(theta) * (torch.ones((2, 2)).to(device) * Hp - j_p)) \
        * (Wo / 2 * torch.ones((2, 2)).to(device) - x) / l1  # column coordinate in the front view
    out[:, :, 0] = i_p.reshape((2, 2))
    out[:, :, 1] = j_p.reshape((2, 2))
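    # Four point correspondences determine the 3x3 homography exactly:
    # H has eight degrees of freedom (it is defined up to scale), and each
    # (source, target) pair contributes two linear constraints.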
    four_point_org = out.permute(2, 0, 1)
    four_point_new = torch.stack((x, y), dim=-1).permute(2, 0, 1)
    four_point_org = four_point_org.unsqueeze(0).flatten(2).permute(0, 2, 1)
    four_point_new = four_point_new.unsqueeze(0).flatten(2).permute(0, 2, 1)
    H = get_perspective_transform(four_point_org, four_point_new)
    # Rescale the homography from the BEV canvas to the requested output size.
    scale1, scale2 = out_size / Wo, out_size / Ho
    T3 = np.array([[scale1, 0, 0], [0, scale2, 0], [0, 0, 1]])
    Homo = torch.matmul(torch.tensor(T3).unsqueeze(0).to(device).float(), H)
    # Warp the front view into the bird's-eye view.
    BEV = warp_perspective(frame.permute(2, 0, 1).unsqueeze(0).float(), Homo, (out_size, out_size))
    BEV = BEV[0].cpu().int().permute(1, 2, 0).numpy().astype(np.uint8)
    return BEV
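# Minimal usage sketch without the Gradio UI (assumes OpenCV is installed and
# that ./figure/exp1.jpg exists; the parameter values are taken from the
# gr.Examples list below):
#
#   import cv2
#   img = cv2.cvtColor(cv2.imread('./figure/exp1.jpg'), cv2.COLOR_BGR2RGB)
#   bev = get_BEV_kitti(img, fov=27, pitch=7, scale=6, out_size=1000)
#   cv2.imwrite('bev.png', cv2.cvtColor(bev, cv2.COLOR_RGB2BGR))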
def KittiBEV():
    torch.cuda.empty_cache()
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # HC-Net: Fine-Grained Cross-View Geo-Localization Using a Correlation-Aware Homography Estimator
            ## Get a BEV image from a front-view image.
            """)
        with gr.Row():
            front_img = gr.Image(label="Front-view Image").style(height=450)
            BEV_output = gr.Image(label="BEV Image").style(height=450)
        fov = gr.Slider(1, 90, value=20, label="FOV")
        pitch = gr.Slider(-180, 180, value=0, label="Pitch")
        scale = gr.Slider(1, 10, value=1.0, label="Scale")
        out_size = gr.Slider(500, 1000, value=500, label="Out size")
        btn = gr.Button(value="Get BEV Image")
        btn.click(get_BEV_kitti, inputs=[front_img, fov, pitch, scale, out_size],
                  outputs=BEV_output, queue=False)
        gr.Markdown(
            """
            ### Note:
            - If you wish to obtain **quantitative localization error results** for your uploaded data, please supply the real GPS coordinates of the ground image as well as the GPS coordinates of the center of the corresponding satellite image.
            - When entering GPS coordinates, please make sure they are precise to **at least six decimal places**.
            """)
gr.Markdown("## Image Examples") | |
gr.Examples( | |
examples=[['./figure/exp1.jpg', 27, 7, 6, 1000], | |
['./figure/exp2.jpg', 17.5, 0.8, 4, 1000]], | |
inputs= [front_img, fov, pitch, scale, out_size], | |
outputs=[BEV_output], | |
fn=get_BEV_kitti, | |
cache_examples=False, | |
) | |
demo.launch(server_port=7981) | |
if __name__ == '__main__':
    KittiBEV()