# HC-Net / app.py: get BEV from a front-view image
import warnings

import numpy as np
import torch
import gradio as gr

from models.utils.torch_geometry import get_perspective_transform, warp_perspective

warnings.filterwarnings("ignore")

device = 'cuda' if torch.cuda.is_available() else 'cpu'


@torch.no_grad()
def get_BEV_kitti(front_img, fov, pitch, scale, out_size):
    """Warp a front-view image onto the ground plane to get a bird's-eye-view (BEV) image."""
    out_size = int(out_size)
    Hp, Wp = front_img.shape[:2]
    Wo, Ho = int(Wp * scale), int(Wp * scale)  # square intermediate BEV grid

    fov = fov * torch.pi / 180      # degrees -> radians (used as the half vertical FOV)
    theta = pitch * torch.pi / 180  # camera pitch angle in radians
    f = Hp / 2 / torch.tan(torch.tensor(fov))  # focal length in pixels
    phi = torch.pi / 2 - fov
    delta = torch.pi / 2 + theta - torch.tensor(phi)
    l = torch.sqrt(f ** 2 + (Hp / 2) ** 2)  # distance from the camera center to the bottom image edge
    h = l * torch.sin(delta)   # height component of the bottom-edge ray
    f_ = l * torch.cos(delta)  # ground-plane offset of the bottom-edge ray

    # Project the four corners of the BEV grid back into front-view pixel coordinates.
    frame = torch.from_numpy(front_img).to(device)
    out = torch.zeros((2, 2, 2)).to(device)
    y = (torch.ones((2, 2)).to(device).T * torch.arange(0, Ho, step=Ho - 1).to(device)).T
    x = torch.ones((2, 2)).to(device) * torch.arange(0, Wo, step=Wo - 1).to(device)
    l0 = torch.ones((2, 2)).to(device) * Ho - y
    l1 = torch.ones((2, 2)).to(device) * f_ + l0
    f1_0 = torch.arctan(h / l1)
    f1_1 = torch.ones((2, 2)).to(device) * (torch.pi / 2 + theta) - f1_0
    y_ = l0 * torch.sin(f1_0) / torch.sin(f1_1)
    j_p = torch.ones((2, 2)).to(device) * Hp - y_
    i_p = torch.ones((2, 2)).to(device) * Wp / 2 - (
        f_ + torch.sin(torch.tensor(theta)) * (torch.ones((2, 2)).to(device) * Hp - j_p)
    ) * (Wo / 2 * torch.ones((2, 2)).to(device) - x) / l1
    out[:, :, 0] = i_p.reshape((2, 2))
    out[:, :, 1] = j_p.reshape((2, 2))

    # Estimate the homography from the four corner correspondences.
    four_point_org = out.permute(2, 0, 1)
    four_point_new = torch.stack((x, y), dim=-1).permute(2, 0, 1)
    four_point_org = four_point_org.unsqueeze(0).flatten(2).permute(0, 2, 1)
    four_point_new = four_point_new.unsqueeze(0).flatten(2).permute(0, 2, 1)
    H = get_perspective_transform(four_point_org, four_point_new)

    # Rescale the homography to the requested output size, then warp.
    scale1, scale2 = out_size / Wo, out_size / Ho
    T3 = np.array([[scale1, 0, 0], [0, scale2, 0], [0, 0, 1]])
    Homo = torch.matmul(torch.tensor(T3).unsqueeze(0).to(device).float(), H)
    BEV = warp_perspective(frame.permute(2, 0, 1).unsqueeze(0).float(), Homo, (out_size, out_size))
    BEV = BEV[0].cpu().int().permute(1, 2, 0).numpy().astype(np.uint8)
    return BEV
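
# Standalone usage sketch (illustrative, not part of the demo): it assumes the bundled
# example image './figure/exp1.jpg' and uses PIL for loading; the parameter values
# mirror the gr.Examples row below.
#
#   from PIL import Image
#   img = np.array(Image.open('./figure/exp1.jpg').convert('RGB'))
#   bev = get_BEV_kitti(img, fov=27, pitch=7, scale=6, out_size=1000)
#   Image.fromarray(bev).save('bev.jpg')
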
def KittiBEV():
    torch.cuda.empty_cache()
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # HC-Net: Fine-Grained Cross-View Geo-Localization Using a Correlation-Aware Homography Estimator
            ## Get a BEV image from a front-view image.
            """)
        with gr.Row():
            front_img = gr.Image(label="Front-view Image").style(height=450)
            BEV_output = gr.Image(label="BEV Image").style(height=450)
        fov = gr.Slider(1, 90, value=20, label="FOV")
        pitch = gr.Slider(-180, 180, value=0, label="Pitch")
        scale = gr.Slider(1, 10, value=1.0, label="Scale")
        out_size = gr.Slider(500, 1000, value=500, label="Out size")
        btn = gr.Button(value="Get BEV Image")
        btn.click(get_BEV_kitti, inputs=[front_img, fov, pitch, scale, out_size], outputs=BEV_output, queue=False)
        gr.Markdown(
            """
            ### Note:
            - If you wish to obtain **quantitative localization-error results** for your uploaded data, please provide the ground-truth GPS of the ground image as well as the GPS of the corresponding satellite-image center.
            - When entering GPS coordinates, please make sure they are precise to **at least six decimal places**.
            """)
        gr.Markdown("## Image Examples")
        gr.Examples(
            examples=[['./figure/exp1.jpg', 27, 7, 6, 1000]],
            inputs=[front_img, fov, pitch, scale, out_size],
            outputs=[BEV_output],
            fn=get_BEV_kitti,
            cache_examples=False,
        )
    demo.launch(server_port=7981)

if __name__ == '__main__':
    KittiBEV()