import json
import math
import os
import random
from pathlib import Path

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset as TorchDataset

from datasets.encode_openx_dataset import DATA_FREQ_TABLE
from genie.config import GenieConfig
from genie.st_mask_git import cosine_schedule


def normalize_actions(actions):
    """
    Compute the per-dimension mean and std of the actions.
    The actual normalization is applied inside the network.
    """
    mean = np.mean(actions, axis=0).tolist()
    std = np.std(actions, axis=0).tolist()
    return actions, [mean, std]
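
# Illustrative usage of `normalize_actions` (a minimal sketch; the array shape below
# is an assumption for illustration, not taken from a real dataset):
#
#     dummy = np.random.randn(100, 7).astype(np.float32)   # [num_frames, action_dim]
#     dummy, (mean, std) = normalize_actions(dummy)
#     assert len(mean) == len(std) == 7                     # per-dimension statistics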


class RawImageDataset(TorchDataset):
    """Loads raw uint8 tokens as a memmap-backed array."""

    def __init__(
        self,
        data_dir,
        window_size,
        stride=1,
        filter_interrupts=True,
        filter_overlaps=False,
        use_actions=False,
        max_traj_num=1000000,
        compute_stride_from_freq_table=True,
        natural_hz=2,
        datio_noise_ratio=0.0,
        domain=None,
    ):
""" | |
Args: | |
data_dir: directory with the same format as `data/train_v0` and `data/val_v0`. | |
Notably, has `video.bin` and `metadata.json` | |
window_size: number of frames per "video" sequence | |
stride: frame skip | |
filter_interrupts: Under 3% of training frame sequences are the concatenation of two different clips. | |
If filter_interrupts is True, will filter out these sequences using the segment ids. | |
filter_overlaps: If False (default), one frame will appear in multiple examples; | |
e.g. frame 0 might appear as the first frame in example 0 and also the second frame in example 15. | |
If True, will filter out examples so that each frame appears at most once in the dataset. | |
use_actions: If True, will load the actions from the `actions` folder for the models | |
""" | |
        data_dir = Path(data_dir)
        with open(data_dir / "metadata.json") as f:
            self.metadata = json.load(f)

        # TODO: assert not quantized in metadata
        shape = (self.metadata["num_images"], self.metadata["h"], self.metadata["w"], 3)
        video_tokens_path, segment_ids_path, action_tokens_path = [
            data_dir / f"{name}.bin" for name in ["video", "segment_ids", "actions"]
        ]
        token_dtype = np.dtype(self.metadata.get("token_dtype", "uint8"))
        self.data = np.memmap(video_tokens_path, mode="r", shape=shape, dtype=token_dtype)

        self.window_size, self.stride = window_size, stride
        self.datio_noise_ratio = datio_noise_ratio

        if domain is not None:  # TODO: remove
            self.name = domain
        else:
            self.name = self.metadata["name"]

        if compute_stride_from_freq_table:
            self.stride = max(DATA_FREQ_TABLE.get(self.name, 1) // natural_hz, 1)
        self.n_action = self.metadata.get("action_dim", 1) * self.stride

        # actions/ - a folder of action arrays stored in np.float32 format. For frame i,
        # the corresponding action is given by joint_pos[i], driving_command[i], neck_desired[i].
        if use_actions:
            actions = []

            # Hack for the separate per-component action files in the 1x datasets
            for action_file in sorted((data_dir / "actions").iterdir()):
                actions.append(np.memmap(action_file, dtype=np.float32, mode="r").reshape(len(self.data), -1))

            self.actions = np.concatenate(actions, axis=-1)
            self.actions, self.action_stat = normalize_actions(self.actions)
        if os.path.isfile(segment_ids_path):
            self.segment_ids = np.memmap(
                segment_ids_path,
                dtype=np.int32,
                mode="r",
                shape=(self.metadata["num_images"],),
            )
        else:
            self.segment_ids = None
            if filter_interrupts:
                raise NotImplementedError("Cannot filter interrupted sequences without segment ids.")

        # Number of frames between the first and last frames of a video sequence (excluding one endpoint frame)
        self.video_len = (self.window_size - 1) * self.stride

        self.valid_start_inds = []
        for start_ind in range(len(self.data) - self.video_len - self.stride):
            # Assuming `segment_ids` is monotonically increasing, a sequence is interrupted (or too short)
            # if the first and last frames have different segment ids.
            if not (filter_interrupts and self.segment_ids[start_ind] != self.segment_ids[start_ind + self.video_len]):
                self.valid_start_inds.append(start_ind)

            if len(self.valid_start_inds) >= max_traj_num:
                break
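
        # For example (illustrative numbers): with video_len = 6, the window starting at
        # frame 10 is kept only when segment_ids[10] == segment_ids[16], i.e. the whole
        # window comes from a single uninterrupted clip.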
        if filter_overlaps:
            # Instead of using a sliding window, use each frame at most once
            filtered_start_inds = []
            for start_ind in self.valid_start_inds:
                overlapping_start_inds = {start_ind - i * self.stride for i in range(1, self.window_size)}
                # all sequences from `overlapping_start_inds` will also contain `start_ind`,
                # so exclude the sequence starting at `start_ind` if any of `overlapping_start_inds` is already in use
                for existing_start_ind in filtered_start_inds[-self.window_size * self.stride:]:
                    # Bound could be improved
                    if existing_start_ind in overlapping_start_inds:
                        break
                else:
                    filtered_start_inds.append(start_ind)

            self.valid_start_inds = filtered_start_inds
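
            # Illustrative example (not from the real data): with window_size = 3 and
            # stride = 2, keeping start index 10 (frames 10, 12, 14) rules out candidate
            # starts 12 and 14, since those frames already appear in the kept example.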
print(f"Loaded {len(self)} sequences from {data_dir} {self.stride=} {self.window_size=} {self.n_action=}") | |

    def __len__(self):
        return len(self.valid_start_inds)

    def __getitem__(self, idx):
        """
        Returns a flattened sequence of tokens representing `self.window_size` frames,
        spaced `self.stride` apart.
        """
        start_ind = self.valid_start_inds[idx]
        x = self.data[start_ind : start_ind + self.video_len + 1 : self.stride].copy()
        x = torch.FloatTensor(x)

        # Reconstruction task: the input ids and the labels are the same
        attention_mask = torch.ones_like(x)
        data_dict = {
            "images": x,
            "labels": x,  # Do we need labels/attention mask?
            "attention_mask": attention_mask,
            "h": self.metadata["h"],
            "w": self.metadata["w"],
        }
        if hasattr(self, "actions"):
            # We want all actions within the stride to predict the next frame at the end of the stride,
            # so concatenate the actions from [window_size, d_action] to [window_size, d_action * stride].
            data_dict["action_ids"] = self.actions[start_ind : start_ind + self.video_len + self.stride].reshape(
                self.window_size, -1
            )
            data_dict["action_ids"] = torch.from_numpy(data_dict["action_ids"].astype(np.float32))

        data_dict["domain"] = self.name
        return data_dict
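

if __name__ == "__main__":
    # Minimal smoke-test sketch. The directory below is an assumption (it follows the
    # `data/train_v0` layout described in the docstring) and is not shipped with this file.
    dataset = RawImageDataset(
        "data/train_v0",
        window_size=16,
        filter_interrupts=False,
        use_actions=False,
    )
    sample = dataset[0]
    # `images` has shape [window_size, h, w, 3]; `labels` is identical for the
    # reconstruction objective, and `domain` names the source dataset.
    print(sample["images"].shape, sample["domain"])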