from __future__ import annotations

import os

import torch

from .common import BaseDataset


def balance_with_actions(samples, increase_factor=5, exceptions=None):
    """Oversample samples whose driving command is not in `exceptions`.

    Every qualifying sample appears `increase_factor` times in the result;
    samples with an exempted command keep a single copy.
    """
    if exceptions is None:
        exceptions = [2, 3]
    samples_to_add = []
    if increase_factor > 1:
        for sample in samples:
            if sample["cmd"] not in exceptions:
                # Append increase_factor - 1 extra references to this sample.
                samples_to_add.extend([sample] * (increase_factor - 1))
    return samples + samples_to_add
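

# Hypothetical example of the oversampling above: with increase_factor=3 and
# the default exceptions=[2, 3], a sample with cmd 0 gains two extra copies
# while a sample with cmd 2 keeps a single one:
#
#   balance_with_actions([{"cmd": 0}, {"cmd": 2}], increase_factor=3)
#   # -> [{"cmd": 0}, {"cmd": 2}, {"cmd": 0}, {"cmd": 0}]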


def resample_complete_samples(samples, increase_factor=5):
    """Oversample samples that carry a complete set of annotations.

    A sample is complete when its speed and angle lists are non-empty, its
    goal depth is positive, and its goal point lies inside the 1600x900
    image plane.
    """
    samples_to_add = []
    if increase_factor > 1:
        for sample in samples:
            if (sample["speed"] and sample["angle"] and sample["z"] > 0
                    and 0 < sample["goal"][0] < 1600 and 0 < sample["goal"][1] < 900):
                samples_to_add.extend([sample] * (increase_factor - 1))
    return samples + samples_to_add
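

# Hypothetical example of the completeness filter above: with increase_factor=2
# the fully annotated sample is duplicated once, doubling its sampling weight,
# while the sample with an empty speed list keeps weight 1:
#
#   complete = {"speed": [5.0], "angle": [1.0], "z": 10.0, "goal": (800, 450)}
#   partial = {"speed": [], "angle": [1.0], "z": 10.0, "goal": (800, 450)}
#   resample_complete_samples([complete, partial], increase_factor=2)
#   # -> [complete, partial, complete]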


class NuScenesDataset(BaseDataset):
    def __init__(self, data_root="data/nuscenes", anno_file="annos/nuScenes.json",
                 target_height=320, target_width=576, num_frames=25):
        if not os.path.exists(data_root):
            raise ValueError(f"Cannot find dataset {data_root}")
        if not os.path.exists(anno_file):
            raise ValueError(f"Cannot find annotation {anno_file}")
        super().__init__(data_root, anno_file, target_height, target_width, num_frames)
        print("nuScenes loaded:", len(self))
        # Oversample rare driving commands, then fully annotated samples.
        self.samples = balance_with_actions(self.samples, increase_factor=5)
        print("nuScenes balanced:", len(self))
        self.samples = resample_complete_samples(self.samples, increase_factor=2)
        print("nuScenes resampled:", len(self))
        # Action-conditioning selector, cycled in __getitem__: 0 -> trajectory,
        # 1 -> command, 2 -> speed/angle, 3 -> goal point.
        self.action_mod = 0

    def get_image_path(self, sample_dict, current_index):
        return os.path.join(self.data_root, sample_dict["frames"][current_index])

    def build_data_dict(self, image_seq, sample_dict):
        # Conditioning-frame noise augmentation is disabled: with
        # cond_aug = 0, cond_frames reduces to the unmodified first frame.
        # log_cond_aug = self.log_cond_aug_dist.sample()
        # cond_aug = torch.exp(log_cond_aug)
        cond_aug = torch.tensor([0.0])
        data_dict = {
            "img_seq": torch.stack(image_seq),
            "motion_bucket_id": torch.tensor([127]),
            "fps_id": torch.tensor([9]),
            "cond_frames_without_noise": image_seq[0],
            "cond_frames": image_seq[0] + cond_aug * torch.randn_like(image_seq[0]),
            "cond_aug": cond_aug,
        }
        # Attach exactly one action condition, selected by self.action_mod.
        if self.action_mod == 0:
            # Skip the first two trajectory values.
            data_dict["trajectory"] = torch.tensor(sample_dict["traj"][2:])
        elif self.action_mod == 1:
            data_dict["command"] = torch.tensor(sample_dict["cmd"])
        elif self.action_mod == 2:
            # The speed annotation might be empty for this scene.
            if sample_dict["speed"]:
                data_dict["speed"] = torch.tensor(sample_dict["speed"][1:])
            # The angle annotation might be empty for this scene.
            if sample_dict["angle"]:
                # Normalize the steering angle by the constant 780.
                data_dict["angle"] = torch.tensor(sample_dict["angle"][1:]) / 780
        elif self.action_mod == 3:
            # The goal point might be invalid: attach it only when it is in
            # front of the camera (z > 0) and inside the 1600x900 image.
            if sample_dict["z"] > 0 and 0 < sample_dict["goal"][0] < 1600 and 0 < sample_dict["goal"][1] < 900:
                # Normalize pixel coordinates to [0, 1].
                data_dict["goal"] = torch.tensor([
                    sample_dict["goal"][0] / 1600,
                    sample_dict["goal"][1] / 900
                ])
        else:
            raise ValueError(f"Unexpected action_mod: {self.action_mod}")
        return data_dict

    def __getitem__(self, index):
        sample_dict = self.samples[index]
        # Cycle through the four action-conditioning modalities across fetches.
        self.action_mod = (self.action_mod + index) % 4
        image_seq = []
        for i in range(self.num_frames):
            img_path = self.get_image_path(sample_dict, i)
            image = self.preprocess_image(img_path)
            image_seq.append(image)
        return self.build_data_dict(image_seq, sample_dict)
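

# Minimal usage sketch, assuming the default data/annotation paths exist and
# that BaseDataset is a torch-style Dataset exposing __len__, preprocess_image,
# and self.samples as used above. Run from the package context because of the
# relative import:
#
#   from torch.utils.data import DataLoader
#
#   dataset = NuScenesDataset()
#   loader = DataLoader(dataset, batch_size=1, shuffle=True)
#   batch = next(iter(loader))
#   batch["img_seq"].shape  # (1, num_frames, C, H, W)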