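"""PyTorch dataset for the Dynamic Replica point-tracking benchmark.

Reads the gzipped per-frame annotations for a split, groups the left-camera
frames by sequence, and returns CoTrackerData samples containing the RGB video,
2D point trajectories, and per-point visibility masks.
"""
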
import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass


@dataclass
class ImageAnnotation:
    # Path to the image file.
    path: str
    # Image size as (height, width).
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    sequence_name: str
    frame_number: int
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    # Holds the relative path of the per-frame trajectory file
    # (accessed below as trajectories["path"]).
    trajectories: Optional[str] = None


class DynamicReplicaDataset(data.Dataset):
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []

        # Load the per-frame annotations for this split and group them by
        # sequence, keeping only the left-camera frames.
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        # Split every sequence into non-overlapping clips of `sample_len` frames
        # (or one clip spanning the whole sequence if sample_len <= 0).
        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # Center crop: top-left corner of a crop_size window inside the frame.
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        # Shift the trajectories into the cropped coordinate frame.
        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        # Load the per-frame payloads: RGB image, 2D trajectories, and
        # per-point visibility.
        for i in range(T):
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())
            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape

        # Randomly subsample at most `traj_per_sample` trajectories.
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        # Mark points that fall outside the (possibly cropped) image as occluded.
        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # Keep only trajectories that are visible in more than 10 frames.
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            # Every surviving point is valid in every frame; the mask shape
            # matches the trajectories that are actually returned.
            valid=torch.ones(T, traj_2d.shape[1]),
            seq_name=sample[0].sequence_name,
        )
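

if __name__ == "__main__":
    # Minimal usage sketch. It assumes a local copy of the Dynamic Replica
    # annotations under "./dynamic_replica_data" (hypothetical path) laid out as
    # ./dynamic_replica_data/valid/frame_annotations_valid.jgz, plus the
    # per-frame trajectory files referenced by those annotations.
    dataset = DynamicReplicaDataset(
        root="./dynamic_replica_data",
        split="valid",
        traj_per_sample=256,
        sample_len=20,
    )
    print(f"{len(dataset)} clips in the valid split")

    # CoTrackerData is constructed above with these keyword fields, so they are
    # assumed to be accessible as attributes on the returned sample.
    item = dataset[0]
    print("video:      ", tuple(item.video.shape))       # (T, 3, H, W) float tensor
    print("trajectory: ", tuple(item.trajectory.shape))  # (T, N, 2) pixel coordinates
    print("visibility: ", tuple(item.visibility.shape))  # (T, N) visibility mask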