# Uploaded by Stable-X using huggingface_hub (commit e4bf056, verified; 2.07 kB).
from .arkit import ArkitScene
from .blendedmvs import BlendMVS
from .co3d import Co3d
from .habitat import habitat
from .scannet import Scannet
from .scannetpp import Scannetpp
from .seven_scenes import SevenScenes
from .nrgbd import NRGBD
from .dtu import DTU
from .demo import Demo
from dust3r.datasets.utils.transforms import *
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` with a suitable sampler.

    Sampler selection, in order of preference:

    1. ``dataset.make_sampler(...)``, if the dataset provides it. If that lookup
       or call raises ``AttributeError`` or ``NotImplementedError``, fall back to
       the options below.
    2. ``DistributedSampler``, when a ``torch.distributed`` process group is
       initialized.
    3. ``RandomSampler`` when ``shuffle`` is true, otherwise ``SequentialSampler``.

    Args:
        dataset: A dataset object, or a string holding a Python expression that
            builds one. The string is evaluated with ``eval`` in this module's
            namespace, so the dataset classes imported here are available to it.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Whether samples are shuffled. Forwarded to the sampler.
        drop_last: Whether to drop the final incomplete batch. Forwarded to both
            the sampler and the DataLoader.
        pin_mem: Forwarded to ``DataLoader(pin_memory=...)``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # A dataset may be given as a constructor expression string.
    # SECURITY NOTE: eval() executes arbitrary code. Only pass trusted strings
    # (e.g. from your own training config), never external/user input.
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # No dataset-specific sampler is available (or the dataset declined to
        # build one), so use a standard torch sampler instead.
        # NOTE(review): an AttributeError raised *inside* make_sampler is also
        # caught here and silently triggers this fallback.
        dist = torch.distributed
        # Check is_available() first: torch builds compiled without distributed
        # support may not define is_initialized() at all.
        if dist.is_available() and dist.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
    return data_loader
def build_dataset(dataset, batch_size, num_workers, test=False):
    """Build a train or test DataLoader via ``get_data_loader`` and log its size.

    Args:
        dataset: Dataset object or dataset-expression string, passed through
            to ``get_data_loader``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        test: If truthy, build a test loader: no shuffling, and the final
            partial batch is kept. Otherwise build a training loader: shuffled,
            and the final partial batch is dropped.

    Returns:
        The DataLoader returned by ``get_data_loader``.
    """
    # A conditional expression (rather than indexing a list with the flag)
    # accepts any truthy/falsy value, not just bool/0/1.
    split = 'Test' if test else 'Train'
    print(f'Building {split} Data loader for dataset: ', dataset)
    loader = get_data_loader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_mem=True,
                             shuffle=not test,
                             drop_last=not test)
    # len() of a DataLoader is its number of batches, not its number of
    # samples, so label the logged value accordingly.
    print(f"{split} data loader length (number of batches): ", len(loader))
    return loader