#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib.resources
import json
import logging
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any

import datasets
import jsonlines
import numpy as np
import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms

from lerobot.common.datasets.backward_compatibility import (
    V21_MESSAGE,
    BackwardCompatibilityError,
    ForwardCompatibilityError,
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk

INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"

DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## {}
"""

DEFAULT_FEATURES = {
    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
    "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
    "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
    "index": {"dtype": "int64", "shape": (1,), "names": None},
    "task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.

    For example:
    ```
    >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
    >>> print(flatten_dict(dct))
    {"a/b": 1, "a/c/d": 2, "e": 3}
    ```
    """
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
def unflatten_dict(d: dict, sep: str = "/") -> dict:
    outdict = {}
    for key, value in d.items():
        parts = key.split(sep)
        d = outdict
        for part in parts[:-1]:
            if part not in d:
                d[part] = {}
            d = d[part]
        d[parts[-1]] = value
    return outdict
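# Illustrative round-trip (added example, not part of the original module): `unflatten_dict`
# inverts `flatten_dict` for the same separator, e.g.
#   flatten_dict({"a": {"b": 1, "c": {"d": 2}}, "e": 3})  ->  {"a/b": 1, "a/c/d": 2, "e": 3}
#   unflatten_dict({"a/b": 1, "a/c/d": 2, "e": 3})        ->  {"a": {"b": 1, "c": {"d": 2}}, "e": 3}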
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
    split_keys = flattened_key.split(sep)
    getter = obj[split_keys[0]]
    if len(split_keys) == 1:
        return getter

    for key in split_keys[1:]:
        getter = getter[key]

    return getter
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
    serialized_dict = {}
    for key, value in flatten_dict(stats).items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            serialized_dict[key] = value.tolist()
        elif isinstance(value, np.generic):
            serialized_dict[key] = value.item()
        elif isinstance(value, (int, float)):
            serialized_dict[key] = value
        else:
            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
    return unflatten_dict(serialized_dict)
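# Illustrative example (added, not part of the original module): tensors and arrays become
# plain lists so the stats can be dumped to JSON, e.g.
#   serialize_dict({"observation.state": {"mean": np.array([0.1, 0.2])}})
#   ->  {"observation.state": {"mean": [0.1, 0.2]}}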
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
    # Embed image bytes into the table before saving to parquet
    format = dataset.format
    dataset = dataset.with_format("arrow")
    dataset = dataset.map(embed_table_storage, batched=False)
    dataset = dataset.with_format(**format)
    return dataset
def load_json(fpath: Path) -> Any:
    with open(fpath) as f:
        return json.load(f)


def write_json(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with open(fpath, "w") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def load_jsonlines(fpath: Path) -> list[Any]:
    with jsonlines.open(fpath, "r") as reader:
        return list(reader)


def write_jsonlines(data: list[dict], fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "w") as writer:
        writer.write_all(data)


def append_jsonlines(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "a") as writer:
        writer.write(data)
def write_info(info: dict, local_dir: Path):
    write_json(info, local_dir / INFO_PATH)


def load_info(local_dir: Path) -> dict:
    info = load_json(local_dir / INFO_PATH)
    for ft in info["features"].values():
        ft["shape"] = tuple(ft["shape"])
    return info


def write_stats(stats: dict, local_dir: Path):
    serialized_stats = serialize_dict(stats)
    write_json(serialized_stats, local_dir / STATS_PATH)


def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)


def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
    if not (local_dir / STATS_PATH).exists():
        return None
    stats = load_json(local_dir / STATS_PATH)
    return cast_stats_to_numpy(stats)
def write_task(task_index: int, task: str, local_dir: Path):
    task_dict = {
        "task_index": task_index,
        "task": task,
    }
    append_jsonlines(task_dict, local_dir / TASKS_PATH)


def load_tasks(local_dir: Path) -> tuple[dict, dict]:
    tasks = load_jsonlines(local_dir / TASKS_PATH)
    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
    return tasks, task_to_task_index
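# Illustrative example (added, task text is hypothetical): each line of meta/tasks.jsonl is a
# record like {"task_index": 0, "task": "Pick up the cube"}; `load_tasks` then returns
#   tasks              = {0: "Pick up the cube"}
#   task_to_task_index = {"Pick up the cube": 0}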
def write_episode(episode: dict, local_dir: Path):
    append_jsonlines(episode, local_dir / EPISODES_PATH)


def load_episodes(local_dir: Path) -> dict:
    episodes = load_jsonlines(local_dir / EPISODES_PATH)
    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
    # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
    # is a dictionary of stats and not an integer.
    episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
    append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)


def load_episodes_stats(local_dir: Path) -> dict:
    episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
    return {
        item["episode_index"]: cast_stats_to_numpy(item["stats"])
        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
    }


def backward_compatible_episodes_stats(
    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
    return dict.fromkeys(episodes, stats)
def load_image_as_numpy(
    fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
    img = PILImage.open(fpath).convert("RGB")
    img_array = np.array(img, dtype=dtype)
    if channel_first:  # (H, W, C) -> (C, H, W)
        img_array = np.transpose(img_array, (2, 0, 1))
    if np.issubdtype(dtype, np.floating):
        img_array /= 255.0
    return img_array
def hf_transform_to_torch(items_dict: dict[str, list]):
    """Transform that converts items from a Hugging Face dataset (pyarrow) to torch tensors.
    Importantly, images are converted from PIL, which corresponds to a channel-last
    representation (h w c) of uint8 type, to a torch image representation with channel
    first (c h w) of float32 type in range [0,1].
    """
    for key in items_dict:
        first_item = items_dict[key][0]
        if isinstance(first_item, PILImage.Image):
            to_tensor = transforms.ToTensor()
            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
        elif first_item is None:
            pass
        else:
            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
    return items_dict
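# Minimal usage sketch (added; assumes the standard `datasets.Dataset.set_transform` API):
# registering this function as the dataset transform makes indexing return torch tensors,
# with images as float32 CHW tensors in [0, 1] instead of PIL images, e.g.
#   hf_dataset.set_transform(hf_transform_to_torch)
#   item = hf_dataset[0]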
def is_valid_version(version: str) -> bool:
    try:
        packaging.version.parse(version)
        return True
    except packaging.version.InvalidVersion:
        return False


def check_version_compatibility(
    repo_id: str,
    version_to_check: str | packaging.version.Version,
    current_version: str | packaging.version.Version,
    enforce_breaking_major: bool = True,
) -> None:
    v_check = (
        packaging.version.parse(version_to_check)
        if not isinstance(version_to_check, packaging.version.Version)
        else version_to_check
    )
    v_current = (
        packaging.version.parse(current_version)
        if not isinstance(current_version, packaging.version.Version)
        else current_version
    )
    if v_check.major < v_current.major and enforce_breaking_major:
        raise BackwardCompatibilityError(repo_id, v_check)
    elif v_check.minor < v_current.minor:
        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
    """Returns available valid versions (branches and tags) on the given repo."""
    api = HfApi()
    repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
    repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
    repo_versions = []
    for ref in repo_refs:
        with contextlib.suppress(packaging.version.InvalidVersion):
            repo_versions.append(packaging.version.parse(ref))

    return repo_versions
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
    """
    Returns the version if available on the repo, or the latest compatible one.
    Otherwise, raises a backward/forward compatibility error, or a `RevisionNotFoundError`
    if the repo carries no version tag at all.
    """
    target_version = (
        packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
    )
    hub_versions = get_repo_versions(repo_id)

    if not hub_versions:
        raise RevisionNotFoundError(
            f"""Your dataset must be tagged with a codebase version.
            Assuming _version_ is the codebase_version value in the info.json, you can run this:
            ```python
            from huggingface_hub import HfApi

            hub_api = HfApi()
            hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
            ```
            """
        )

    if target_version in hub_versions:
        return f"v{target_version}"

    compatibles = [
        v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
    ]
    if compatibles:
        return_version = max(compatibles)
        if return_version < target_version:
            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
        return f"v{return_version}"

    lower_major = [v for v in hub_versions if v.major < target_version.major]
    if lower_major:
        raise BackwardCompatibilityError(repo_id, max(lower_major))

    upper_versions = [v for v in hub_versions if v > target_version]
    assert len(upper_versions) > 0
    raise ForwardCompatibilityError(repo_id, min(upper_versions))
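# Illustrative behavior sketch (added; the repo id is hypothetical):
#   get_safe_version("user/my_dataset", "v2.1")
# returns "v2.1" when that tag or branch exists on the Hub; otherwise it falls back to the
# newest tag with the same major version and a lower-or-equal minor version (with a warning),
# and raises a backward/forward compatibility error when only other major versions exist.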
def get_hf_features_from_features(features: dict) -> datasets.Features:
    hf_features = {}
    for key, ft in features.items():
        if ft["dtype"] == "video":
            continue
        elif ft["dtype"] == "image":
            hf_features[key] = datasets.Image()
        elif ft["shape"] == (1,):
            hf_features[key] = datasets.Value(dtype=ft["dtype"])
        elif len(ft["shape"]) == 1:
            hf_features[key] = datasets.Sequence(
                length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
            )
        elif len(ft["shape"]) == 2:
            hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
        elif len(ft["shape"]) == 3:
            hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
        elif len(ft["shape"]) == 4:
            hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
        elif len(ft["shape"]) == 5:
            hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
        else:
            raise ValueError(f"Corresponding feature is not valid: {ft}")

    return datasets.Features(hf_features)
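# Illustrative mapping (added; the feature name is hypothetical):
#   {"observation.state": {"dtype": "float32", "shape": (6,), "names": None}}
# becomes
#   datasets.Features({"observation.state": datasets.Sequence(length=6, feature=datasets.Value("float32"))})
# while "video" features are skipped, since they are stored as mp4 files rather than in the parquet table.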
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
    camera_ft = {}
    if robot.cameras:
        camera_ft = {
            key: {"dtype": "video" if use_videos else "image", **ft}
            for key, ft in robot.camera_features.items()
        }
    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
    # TODO(aliberts): Implement "type" in dataset features and simplify this
    policy_features = {}
    for key, ft in features.items():
        shape = ft["shape"]
        if ft["dtype"] in ["image", "video"]:
            type = FeatureType.VISUAL
            if len(shape) != 3:
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")

            names = ft["names"]
            # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                shape = (shape[2], shape[0], shape[1])
        elif key == "observation.environment_state":
            type = FeatureType.ENV
        elif key.startswith("observation"):
            type = FeatureType.STATE
        elif key == "action":
            type = FeatureType.ACTION
        else:
            continue

        policy_features[key] = PolicyFeature(
            type=type,
            shape=shape,
        )

    return policy_features
def create_empty_dataset_info(
    codebase_version: str,
    fps: int,
    robot_type: str,
    features: dict,
    use_videos: bool,
) -> dict:
    return {
        "codebase_version": codebase_version,
        "robot_type": robot_type,
        "total_episodes": 0,
        "total_frames": 0,
        "total_tasks": 0,
        "total_videos": 0,
        "total_chunks": 0,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "fps": fps,
        "splits": {},
        "data_path": DEFAULT_PARQUET_PATH,
        "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
        "features": features,
    }
def get_episode_data_index(
    episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
    if episodes is not None:
        episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

    cumulative_lengths = list(accumulate(episode_lengths.values()))
    return {
        "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
        "to": torch.LongTensor(cumulative_lengths),
    }
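# Worked example (added, not part of the original module): for two episodes of lengths 3 and 2,
# the cumulative lengths are [3, 5], so the returned index is
#   {"from": tensor([0, 3]), "to": tensor([3, 5])}
# i.e. episode i spans frames from[i] (inclusive) to to[i] (exclusive).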
def check_timestamps_sync(
    timestamps: np.ndarray,
    episode_indices: np.ndarray,
    episode_data_index: dict[str, np.ndarray],
    fps: int,
    tolerance_s: float,
    raise_value_error: bool = True,
) -> bool:
    """
    This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
    to account for possible numerical error.

    Args:
        timestamps (np.ndarray): Array of timestamps in seconds.
        episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
        episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
            which identifies indices for the end of each episode.
        fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
        tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
        raise_value_error (bool): Whether to raise a ValueError if the check fails.

    Returns:
        bool: True if all checked timestamp differences lie within tolerance, False otherwise.

    Raises:
        ValueError: If the check fails and `raise_value_error` is True.
    """
    if timestamps.shape != episode_indices.shape:
        raise ValueError(
            "timestamps and episode_indices should have the same shape. "
            f"Found {timestamps.shape=} and {episode_indices.shape=}."
        )

    # Consecutive differences
    diffs = np.diff(timestamps)
    within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s

    # Mask to ignore differences at the boundaries between episodes
    mask = np.ones(len(diffs), dtype=bool)
    ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
    mask[ignored_diffs] = False
    filtered_within_tolerance = within_tolerance[mask]

    # Check if all remaining diffs are within tolerance
    if not np.all(filtered_within_tolerance):
        # Track original indices before masking
        original_indices = np.arange(len(diffs))
        filtered_indices = original_indices[mask]
        outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
        outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]

        outside_tolerances = []
        for idx in outside_tolerance_indices:
            entry = {
                "timestamps": [timestamps[idx], timestamps[idx + 1]],
                "diff": diffs[idx],
                "episode_index": episode_indices[idx].item()
                if hasattr(episode_indices[idx], "item")
                else episode_indices[idx],
            }
            outside_tolerances.append(entry)

        if raise_value_error:
            raise ValueError(
                f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
                This might be due to synchronization issues during data collection.
                \n{pformat(outside_tolerances)}"""
            )
        return False

    return True
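# Numeric illustration (added, not part of the original module): at fps=30 with tolerance_s=1e-4,
# timestamps [0.0, 0.0333, 0.0667, 0.1] pass (each diff is within 1e-4 of 1/30), whereas a
# dropped frame producing a diff of ~2/30 inside an episode would be reported and raise a
# ValueError; diffs that cross episode boundaries are masked out and never checked.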
def check_delta_timestamps(
    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
    """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
    This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
    actual timestamps from the dataset.
    """
    outside_tolerance = {}
    for key, delta_ts in delta_timestamps.items():
        within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
        if not all(within_tolerance):
            outside_tolerance[key] = [
                ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
            ]

    if len(outside_tolerance) > 0:
        if raise_value_error:
            raise ValueError(
                f"""
                The following delta_timestamps are found outside of tolerance range.
                Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
                their values accordingly.
                \n{pformat(outside_tolerance)}
                """
            )
        return False

    return True
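# Illustrative example (added; the key name is hypothetical): with fps=10 and tolerance_s=1e-4,
# delta_timestamps={"action": [-0.1, 0.0, 0.1]} passes since every value is a multiple of 1/10,
# whereas a value of 0.15 would be rejected because it falls halfway between two frames.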
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    delta_indices = {}
    for key, delta_ts in delta_timestamps.items():
        delta_indices[key] = [round(d * fps) for d in delta_ts]

    return delta_indices
def cycle(iterable):
    """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.

    See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
    """
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
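# Minimal usage sketch (added; assumes a standard torch DataLoader): wrap a dataloader to draw
# batches indefinitely; a fresh iterator (and hence a reshuffle, if shuffle=True) is created
# each time the underlying loader is exhausted.
#   dl_iter = cycle(dataloader)
#   batch = next(dl_iter)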
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
    """Create a branch on an existing Hugging Face repo. Delete the branch if it already
    exists before creating it.
    """
    api = HfApi()

    branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
    refs = [branch.ref for branch in branches]
    ref = f"refs/heads/{branch}"
    if ref in refs:
        api.delete_branch(repo_id, repo_type=repo_type, branch=branch)

    api.create_branch(repo_id, repo_type=repo_type, branch=branch)
def create_lerobot_dataset_card(
    tags: list | None = None,
    dataset_info: dict | None = None,
    **kwargs,
) -> DatasetCard:
    """
    Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
    Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
    """
    card_tags = ["LeRobot"]

    if tags:
        card_tags += tags
    if dataset_info:
        dataset_structure = "[meta/info.json](meta/info.json):\n"
        dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
        kwargs = {**kwargs, "dataset_structure": dataset_structure}
    card_data = DatasetCardData(
        license=kwargs.get("license"),
        tags=card_tags,
        task_categories=["robotics"],
        configs=[
            {
                "config_name": "default",
                "data_files": "data/*/*.parquet",
            }
        ],
    )

    card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()

    return DatasetCard.from_template(
        card_data=card_data,
        template_str=card_template,
        **kwargs,
    )
class IterableNamespace(SimpleNamespace):
    """
    A namespace object that supports both dictionary-like iteration and dot notation access.
    Automatically converts nested dictionaries into IterableNamespaces.

    This class extends SimpleNamespace to provide:
    - Dictionary-style iteration over keys
    - Access to items via both dot notation (obj.key) and brackets (obj["key"])
    - Dictionary-like methods: items(), keys(), values()
    - Recursive conversion of nested dictionaries

    Args:
        dictionary: Optional dictionary to initialize the namespace
        **kwargs: Additional keyword arguments passed to SimpleNamespace

    Examples:
        >>> data = {"name": "Alice", "details": {"age": 25}}
        >>> ns = IterableNamespace(data)
        >>> ns.name
        'Alice'
        >>> ns.details.age
        25
        >>> list(ns.keys())
        ['name', 'details']
        >>> for key, value in ns.items():
        ...     print(f"{key}: {value}")
        name: Alice
        details: IterableNamespace(age=25)
    """

    def __init__(self, dictionary: dict[str, Any] | None = None, **kwargs):
        super().__init__(**kwargs)
        if dictionary is not None:
            for key, value in dictionary.items():
                if isinstance(value, dict):
                    setattr(self, key, IterableNamespace(value))
                else:
                    setattr(self, key, value)

    def __iter__(self) -> Iterator[str]:
        return iter(vars(self))

    def __getitem__(self, key: str) -> Any:
        return vars(self)[key]

    def items(self):
        return vars(self).items()

    def values(self):
        return vars(self).values()

    def keys(self):
        return vars(self).keys()
def validate_frame(frame: dict, features: dict):
    optional_features = {"timestamp"}
    expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
    actual_features = set(frame.keys())

    error_message = validate_features_presence(actual_features, expected_features, optional_features)

    if "task" in frame:
        error_message += validate_feature_string("task", frame["task"])

    common_features = actual_features & (expected_features | optional_features)
    for name in common_features - {"task"}:
        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])

    if error_message:
        raise ValueError(error_message)
def validate_features_presence(
    actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
    error_message = ""
    missing_features = expected_features - actual_features
    extra_features = actual_features - (expected_features | optional_features)

    if missing_features or extra_features:
        error_message += "Feature mismatch in `frame` dictionary:\n"
        if missing_features:
            error_message += f"Missing features: {missing_features}\n"
        if extra_features:
            error_message += f"Extra features: {extra_features}\n"

    return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
    expected_dtype = feature["dtype"]
    expected_shape = feature["shape"]
    if is_valid_numpy_dtype_string(expected_dtype):
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    elif expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
    elif expected_dtype == "string":
        return validate_feature_string(name, value)
    else:
        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_dtype = value.dtype
        actual_shape = value.shape

        if actual_dtype != np.dtype(expected_dtype):
            error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"

        if actual_shape != expected_shape:
            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
    else:
        error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"

    return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[int], value: np.ndarray | PILImage.Image):
    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_shape = value.shape
        c, h, w = expected_shape
        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
    elif isinstance(value, PILImage.Image):
        pass
    else:
        error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"

    return error_message


def validate_feature_string(name: str, value: str):
    if not isinstance(value, str):
        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
    return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

    if "task" not in episode_buffer:
        raise ValueError("task key not found in episode_buffer")

    if episode_buffer["episode_index"] != total_episodes:
        # TODO(aliberts): Add option to use existing episode_index
        raise NotImplementedError(
            "You might have manually provided the episode_buffer with an episode_index that doesn't "
            "match the total number of episodes already in the dataset. This is not supported for now."
        )

    if episode_buffer["size"] == 0:
        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
    if not buffer_keys == set(features):
        raise ValueError(
            "Features from `episode_buffer` don't match the ones in `features`. "
            f"In episode_buffer not in features: {buffer_keys - set(features)}. "
            f"In features not in episode_buffer: {set(features) - buffer_keys}"
        )