Spaces:
Sleeping
Sleeping
import torch | |
from torch.utils.data import Dataset | |
import os | |
import cv2 | |
# @Time : 2023-02-13 22:56 | |
# @Author : Wang Zhen | |
# @Email : [email protected] | |
# @File : SatelliteTool.py | |
# @Project : TGRS_seqmatch_2023_1 | |
import numpy as np | |
import random | |
from utils.geo import BoundaryBox, Projection | |
from osm.tiling import TileManager,MapTileManager | |
from pathlib import Path | |
from torchvision import transforms | |
from torch.utils.data import DataLoader | |
class UavMapPair(Dataset):
    """Dataset pairing UAV images with map rasters for a single city.

    Directory layout read under ``root / city``:
        uav/       UAV images, loaded with ``cv2.imread``
        map/       map arrays, loaded with ``np.load``
        map_vis/   stored as ``self.map_vis``; not read by this class
        info.csv   one header row, then one comma-separated row per sample
                   with 13 columns: id, uav_name, map_name, uav_long,
                   uav_lat, map_long, map_lat, tile_size_meters,
                   pixel_per_meter, u, v, yaw, dis

    Args:
        root: dataset root directory (``str`` or ``Path``).
        city: sub-directory of ``root`` holding this city's data.
        training: if True, apply a random square center crop to each UAV
            image before color conversion and ``transform``.
        transform: optional callable applied to the RGB UAV image
            (a numpy array as returned by ``cv2.cvtColor``).
    """

    def __init__(
        self,
        root: Path,
        city: str,
        training: bool = False,
        transform=None,
    ):
        super().__init__()
        # Accept plain strings as well as Path objects.
        root = Path(root)
        self.uav_image_path = root / city / 'uav'
        self.map_path = root / city / 'map'
        self.map_vis = root / city / 'map_vis'
        info_path = root / city / 'info.csv'
        # ndmin=2 keeps the table 2-D even when the CSV holds a single
        # sample; without it loadtxt returns one 1-D row, and both __len__
        # and the row unpacking in __getitem__ break.
        self.info = np.loadtxt(
            str(info_path), dtype=str, delimiter=",", skiprows=1, ndmin=2
        )
        self.transform = transform
        self.training = training

    def random_center_crop(self, image):
        """Return a centered square crop of random size.

        The side length is drawn uniformly from
        ``[min(h, w) // 2, min(h, w)]`` (inclusive). The lower bound is
        clamped to one pixel so non-empty images never yield an empty crop.
        """
        height, width = image.shape[:2]
        side = min(height, width)
        # Randomly choose the crop size (at least 1 px for non-empty images).
        low = min(max(side // 2, 1), side)
        crop_size = random.randint(low, side)
        # Top-left corner that centers the crop.
        start_x = (width - crop_size) // 2
        start_y = (height - crop_size) // 2
        return image[start_y:start_y + crop_size, start_x:start_x + crop_size]

    def __getitem__(self, index: int):
        """Load sample ``index``.

        Returns:
            dict with
              'map': the map array as an int64 tensor,
              'image': the RGB UAV image (after ``transform`` if one is set),
              'roll_pitch_yaw': float32 tensor ``(0, 0, yaw)``,
              'pixels_per_meter': float32 scalar tensor,
              'uv': float32 tensor ``(u, v)``.

        Raises:
            FileNotFoundError: if the UAV image is missing or unreadable.
        """
        (_, uav_name, map_name,
         uav_long, uav_lat,
         map_long, map_lat,
         tile_size_meters, pixel_per_meter,
         u, v, yaw, dis) = self.info[index]

        image_file = self.uav_image_path / uav_name
        uav_image = cv2.imread(str(image_file))
        # cv2.imread reports failure by returning None rather than raising.
        if uav_image is None:
            raise FileNotFoundError(f"Could not read UAV image: {image_file}")
        if self.training:
            uav_image = self.random_center_crop(uav_image)
        # OpenCV loads images as BGR; downstream code expects RGB.
        uav_image = cv2.cvtColor(uav_image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            uav_image = self.transform(uav_image)

        map_raster = np.load(str(self.map_path / map_name))
        return {
            'map': torch.from_numpy(np.ascontiguousarray(map_raster)).long(),
            # as_tensor accepts both a numpy array and an already-built
            # tensor (e.g. from ToTensor) without torch.tensor's copy warning.
            'image': torch.as_tensor(uav_image),
            'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float(),
            'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float(),
            'uv': torch.tensor([float(u), float(v)]).float(),
        }

    def __len__(self):
        """Return the number of samples (data rows in info.csv)."""
        return len(self.info)
if __name__ == '__main__':
    # Smoke test: build the NewYork split and iterate over it once.
    root = Path('/root/DATASET/OrienterNet/UavMap/')
    city = 'NewYork'
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = UavMapPair(
        root=root,
        city=city,
        # Passed explicitly: the constructor has no default for it in the
        # original signature. False skips the random-crop augmentation.
        training=False,
        transform=transform,
    )
    dataloader = DataLoader(dataset, batch_size=3)
    # enumerate yields (index, batch).
    for i, batch in enumerate(dataloader):
        pass
        # Optional visual check of the first sample in the batch. Requires
        # PIL's Image, matplotlib.pyplot as plt, and the Colormap /
        # plot_images helpers, none of which are imported in this file.
        #
        # # Convert the PyTorch tensor to a PIL image
        # pil_image = Image.fromarray(batch['image'][0].permute(1, 2, 0).byte().numpy())
        # # Convert the PyTorch tensor to a NumPy array and display it
        # numpy_array = batch['image'][0].numpy()
        # plt.imshow(numpy_array.transpose(1, 2, 0))
        # plt.axis('off')
        # plt.show()
        #
        # map_viz, label = Colormap.apply(batch['map'][0])
        # map_viz = map_viz * 255
        # map_viz = map_viz.astype(np.uint8)
        # plot_images([map_viz], titles=["OpenStreetMap raster"])