Spaces:
Sleeping
Sleeping
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset | |
class BaseManyViewDataset(BaseStereoViewDataset):
    """Stereo-view dataset base extended with multi-frame index sampling.

    The sampling helpers read ``self.num_frames``, ``self.min_thresh``,
    ``self.max_thresh``, ``self.train_ratio`` and ``self.kf_every``. None of
    these is assigned in this class, so they are presumably provided by the
    parent class or by concrete subclasses (TODO: confirm).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(*args, **kwargs)

    def sample_frames(self, img_idxs, rng):
        """Randomly pick ``self.num_frames`` entries of ``img_idxs`` in order.

        Positions are drawn strictly increasing, with consecutive picks at
        most ``thresh`` positions apart, where
        ``thresh = int(min_thresh + train_ratio * (max_thresh - min_thresh))``.
        With probability one half (a coin flip on ``rng``) the resulting
        list is reversed before being returned.

        Args:
            img_idxs: Indexable sequence of frame identifiers.
            rng: Random source exposing ``choice(sequence)``; presumably a
                ``numpy.random.Generator`` (TODO: confirm against callers).

        Returns:
            A list of ``self.num_frames`` elements taken from ``img_idxs``.

        Raises:
            ValueError: If ``self.num_frames`` is smaller than 1, if
                ``img_idxs`` holds fewer than ``self.num_frames`` entries, or
                if more than one frame is requested while the computed gap
                threshold is below 1 (no strictly increasing step exists).
        """
        num_frames = self.num_frames
        max_gap = int(
            self.min_thresh
            + self.train_ratio * (self.max_thresh - self.min_thresh)
        )
        total = len(img_idxs)

        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        if num_frames > total:
            raise ValueError(
                f"cannot sample {num_frames} frames from a sequence of "
                f"{total} frames"
            )
        # A gap below 1 leaves no candidate for the second frame; the
        # previous implementation retried forever (RecursionError) here.
        if num_frames > 1 and max_gap < 1:
            raise ValueError(
                f"frame gap threshold must be >= 1 to sample {num_frames} "
                f"frames, got {max_gap}"
            )

        positions = list(range(total))
        # Restrict the first pick so enough positions remain after it for
        # the other frames; both bounds are <= total - num_frames + 1 once
        # the checks above have passed.
        first_window = max(
            total // num_frames,
            total - max_gap * (num_frames - 1),
        )
        current = rng.choice(positions[:first_window])
        chosen = [current]

        while len(chosen) < num_frames:
            still_needed = num_frames - len(chosen)
            # Stay within the gap limit while leaving room for the frames
            # that still have to be drawn after this one. Given the
            # validation above this window is never empty, so no retry
            # path is required.
            upper = min(current + max_gap, total - still_needed)
            current = rng.choice(list(range(current + 1, upper + 1)))
            chosen.append(current)

        selected_img_ids = [img_idxs[pos] for pos in chosen]
        # Randomly flip the temporal direction of the sampled clip.
        if rng.choice([True, False]):
            selected_img_ids.reverse()
        return selected_img_ids

    def sample_frame_idx(self, img_idxs, rng, full_video=False):
        """Return the frame identifiers to load for one sample.

        Args:
            img_idxs: Sequence of frame identifiers.
            rng: Random source forwarded to ``sample_frames``; unused when
                ``full_video`` is true.
            full_video: When false (default), draw ``self.num_frames``
                frames via ``sample_frames``. When true, keep every
                ``self.kf_every``-th entry of ``img_idxs`` instead.

        Returns:
            The sampled list, or the strided slice of ``img_idxs``.
        """
        if full_video:
            return img_idxs[::self.kf_every]
        return self.sample_frames(img_idxs, rng)