Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# @Author : xuelun | |
import os | |
from tqdm import tqdm | |
from argparse import ArgumentParser | |
from torch.utils.data import DataLoader | |
from datasets.walk import cfg | |
from datasets.walk.walk import WALKDataset | |
def propagate(loader, seq): | |
for i, _ in enumerate(tqdm( | |
loader, ncols=80, bar_format="{l_bar}{bar:3}{r_bar}", total=len(loader), | |
desc=f'[ {seq[:min(10, len(seq)-1)]:<10} ] [ {len(loader):<5} ]')): | |
continue | |
def init_dataset(seq_name_): | |
train_cfg = cfg.DATASET.TRAIN | |
base_input = { | |
'df': 8, | |
'mode': 'train', | |
'augment_fn': None, | |
'PROPAGATING': True, | |
'seq_name': seq_name_, | |
'max_resize': [1280, 720], | |
'padding': cfg.DATASET.TRAIN.PADDING, | |
'max_samples': cfg.DATASET.TRAIN.MAX_SAMPLES, | |
'min_overlap_score': cfg.DATASET.TRAIN.MIN_OVERLAP_SCORE, | |
'max_overlap_score': cfg.DATASET.TRAIN.MAX_OVERLAP_SCORE | |
} | |
cfg_input = { | |
k: getattr(train_cfg, k) | |
for k in [ | |
'DATA_ROOT', 'NPZ_ROOT', 'STEP', 'PIX_THR', 'FIX_MATCHES', 'SOURCE_ROOT', | |
'MAX_CANDIDATE_MATCHES', 'MIN_FINAL_MATCHES', 'MIN_FILTER_MATCHES', | |
'VIDEO_IMAGE_ROOT', 'PROPAGATE_ROOT', 'PSEUDO_LABELS' | |
] | |
} | |
# 合并配置 | |
input_ = { | |
**base_input, | |
**cfg_input, | |
'root_dir': cfg_input['DATA_ROOT'], | |
'npz_root': cfg_input['NPZ_ROOT'] | |
} | |
dataset = WALKDataset(**input_) | |
return dataset | |
# noinspection PyUnusedLocal | |
def collate_fn(batch): | |
return None | |
if __name__ == '__main__': | |
parser = ArgumentParser() | |
parser.add_argument('seq_names', type=str, nargs='+') | |
args = parser.parse_args() | |
if os.path.isfile(args.seq_names[0]): | |
with open(args.seq_names[0], 'r') as f: | |
seq_names = [line.strip() for line in f.readlines()] | |
else: | |
seq_names = args.seq_names | |
for seq_name in seq_names: | |
dataset_ = init_dataset(seq_name) | |
loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 3, | |
'pin_memory': True, 'drop_last': False} | |
loader_ = DataLoader(dataset_, collate_fn=collate_fn, **loader_params) | |
propagate(loader_, seq_name) | |