from __future__ import annotations

import os

import torch

from .common import BaseDataset

def balance_with_actions(samples, increase_factor=5, exceptions=None):
    """Oversample by action: every sample whose "cmd" is not in `exceptions`
    is duplicated so that it appears `increase_factor` times in total."""
    if exceptions is None:
        exceptions = [2, 3]
    sample_to_add = list()
    if increase_factor > 1:
        for each_sample in samples:
            if each_sample["cmd"] not in exceptions:
                for _ in range(increase_factor - 1):
                    sample_to_add.append(each_sample)
    return samples + sample_to_add

def resample_complete_samples(samples, increase_factor=5):
    """Oversample the fully annotated samples: non-empty speed and angle
    sequences, positive depth (z > 0), and a goal point that lies inside
    the 1600x900 image plane."""
    sample_to_add = list()
    if increase_factor > 1:
        for each_sample in samples:
            if (each_sample["speed"] and each_sample["angle"] and each_sample["z"] > 0
                    and 0 < each_sample["goal"][0] < 1600 and 0 < each_sample["goal"][1] < 900):
                for _ in range(increase_factor - 1):
                    sample_to_add.append(each_sample)
    return samples + sample_to_add
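
# For illustration (hypothetical toy samples): with increase_factor=3 and the
# default exceptions [2, 3], a sample with "cmd" 0 ends up appearing three
# times in the result, while a sample with "cmd" 2 stays unique:
#
#   balance_with_actions([{"cmd": 0}, {"cmd": 2}], increase_factor=3)
#   # -> [{"cmd": 0}, {"cmd": 2}, {"cmd": 0}, {"cmd": 0}]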

class NuScenesDataset(BaseDataset):
    def __init__(self, data_root="data/nuscenes", anno_file="annos/nuScenes.json",
                 target_height=320, target_width=576, num_frames=25):
        if not os.path.exists(data_root):
            raise ValueError(f"Cannot find dataset {data_root}")
        if not os.path.exists(anno_file):
            raise ValueError(f"Cannot find annotation {anno_file}")
        super().__init__(data_root, anno_file, target_height, target_width, num_frames)
        print("nuScenes loaded:", len(self))
        self.samples = balance_with_actions(self.samples, increase_factor=5)
        print("nuScenes balanced:", len(self))
        self.samples = resample_complete_samples(self.samples, increase_factor=2)
        print("nuScenes resampled:", len(self))
        # cycles over the four action conditions used in build_data_dict:
        # 0 = trajectory, 1 = command, 2 = speed & angle, 3 = goal
        self.action_mod = 0

    def get_image_path(self, sample_dict, current_index):
        return os.path.join(self.data_root, sample_dict["frames"][current_index])

    def build_data_dict(self, image_seq, sample_dict):
        # fixed conditioning augmentation; the sampled variant is kept for reference
        # log_cond_aug = self.log_cond_aug_dist.sample()
        # cond_aug = torch.exp(log_cond_aug)
        cond_aug = torch.tensor([0.0])
        data_dict = {
            "img_seq": torch.stack(image_seq),
            "motion_bucket_id": torch.tensor([127]),
            "fps_id": torch.tensor([9]),
            "cond_frames_without_noise": image_seq[0],
            "cond_frames": image_seq[0] + cond_aug * torch.randn_like(image_seq[0]),
            "cond_aug": cond_aug
        }
        if self.action_mod == 0:
            data_dict["trajectory"] = torch.tensor(sample_dict["traj"][2:])
        elif self.action_mod == 1:
            data_dict["command"] = torch.tensor(sample_dict["cmd"])
        elif self.action_mod == 2:
            # speed might be empty for this scene
            if sample_dict["speed"]:
                data_dict["speed"] = torch.tensor(sample_dict["speed"][1:])
            # angle might be empty for this scene
            if sample_dict["angle"]:
                # divide by 780 to normalize the raw steering angle
                data_dict["angle"] = torch.tensor(sample_dict["angle"][1:]) / 780
        elif self.action_mod == 3:
            # the goal point might be invalid; keep it only if the depth is
            # positive and the point lies inside the 1600x900 image plane
            if sample_dict["z"] > 0 and 0 < sample_dict["goal"][0] < 1600 and 0 < sample_dict["goal"][1] < 900:
                # normalize pixel coordinates to [0, 1]
                data_dict["goal"] = torch.tensor([
                    sample_dict["goal"][0] / 1600,
                    sample_dict["goal"][1] / 900
                ])
        else:
            raise ValueError(f"Unexpected action_mod {self.action_mod}")
        return data_dict

    def __getitem__(self, index):
        sample_dict = self.samples[index]
        # rotate the conditioning mode so different actions are used across draws
        self.action_mod = (self.action_mod + index) % 4
        image_seq = list()
        for i in range(self.num_frames):
            img_path = self.get_image_path(sample_dict, i)
            image = self.preprocess_image(img_path)
            image_seq.append(image)
        return self.build_data_dict(image_seq, sample_dict)
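
# Minimal usage sketch (illustrative, not from the original file). Because of
# the relative import of BaseDataset, this module must be imported as part of
# its package; the package path below is a placeholder:
#
#   from torch.utils.data import DataLoader
#   from mypackage.data.nuscenes import NuScenesDataset  # hypothetical path
#
#   dataset = NuScenesDataset(data_root="data/nuscenes", anno_file="annos/nuScenes.json")
#   loader = DataLoader(dataset, batch_size=1, shuffle=True)
#   batch = next(iter(loader))
#   # batch["img_seq"] is (1, num_frames, C, H, W), assuming preprocess_image
#   # returns C x H x W tensors; the action keys vary with the cycling action_mod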