#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import contextlib import importlib.resources import json import logging from collections.abc import Iterator from itertools import accumulate from pathlib import Path from pprint import pformat from types import SimpleNamespace from typing import Any import datasets import jsonlines import numpy as np import packaging.version import torch from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage from torchvision import transforms from lerobot.common.datasets.backward_compatibility import ( V21_MESSAGE, BackwardCompatibilityError, ForwardCompatibilityError, ) from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.utils.utils import is_valid_numpy_dtype_string from lerobot.configs.types import DictLike, FeatureType, PolicyFeature DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk INFO_PATH = "meta/info.json" EPISODES_PATH = "meta/episodes.jsonl" STATS_PATH = "meta/stats.json" EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" TASKS_PATH = "meta/tasks.jsonl" DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" DATASET_CARD_TEMPLATE = """ --- # Metadata will go there --- This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). ## {} """ DEFAULT_FEATURES = { "timestamp": {"dtype": "float32", "shape": (1,), "names": None}, "frame_index": {"dtype": "int64", "shape": (1,), "names": None}, "episode_index": {"dtype": "int64", "shape": (1,), "names": None}, "index": {"dtype": "int64", "shape": (1,), "names": None}, "task_index": {"dtype": "int64", "shape": (1,), "names": None}, } def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. For example: ``` >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` >>> print(flatten_dict(dct)) {"a/b": 1, "a/c/d": 2, "e": 3} """ items = [] for k, v in d.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k if isinstance(v, dict): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: items.append((new_key, v)) return dict(items) def unflatten_dict(d: dict, sep: str = "/") -> dict: outdict = {} for key, value in d.items(): parts = key.split(sep) d = outdict for part in parts[:-1]: if part not in d: d[part] = {} d = d[part] d[parts[-1]] = value return outdict def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any: split_keys = flattened_key.split(sep) getter = obj[split_keys[0]] if len(split_keys) == 1: return getter for key in split_keys[1:]: getter = getter[key] return getter def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: serialized_dict = {} for key, value in flatten_dict(stats).items(): if isinstance(value, (torch.Tensor, np.ndarray)): serialized_dict[key] = value.tolist() elif isinstance(value, np.generic): serialized_dict[key] = value.item() elif isinstance(value, (int, float)): serialized_dict[key] = value else: raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") return unflatten_dict(serialized_dict) def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: # Embed image bytes into the table before saving to parquet format = dataset.format dataset = dataset.with_format("arrow") dataset = dataset.map(embed_table_storage, batched=False) dataset = dataset.with_format(**format) return dataset def load_json(fpath: Path) -> Any: with open(fpath) as f: return json.load(f) def write_json(data: dict, fpath: Path) -> None: fpath.parent.mkdir(exist_ok=True, parents=True) with open(fpath, "w") as f: json.dump(data, f, indent=4, ensure_ascii=False) def load_jsonlines(fpath: Path) -> list[Any]: with jsonlines.open(fpath, "r") as reader: return list(reader) def write_jsonlines(data: dict, fpath: Path) -> None: fpath.parent.mkdir(exist_ok=True, parents=True) with jsonlines.open(fpath, "w") as writer: writer.write_all(data) def append_jsonlines(data: dict, fpath: Path) -> None: fpath.parent.mkdir(exist_ok=True, parents=True) with jsonlines.open(fpath, "a") as writer: writer.write(data) def write_info(info: dict, local_dir: Path): write_json(info, local_dir / INFO_PATH) def load_info(local_dir: Path) -> dict: info = load_json(local_dir / INFO_PATH) for ft in info["features"].values(): ft["shape"] = tuple(ft["shape"]) return info def write_stats(stats: dict, local_dir: Path): serialized_stats = serialize_dict(stats) write_json(serialized_stats, local_dir / STATS_PATH) def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: if not (local_dir / STATS_PATH).exists(): return None stats = load_json(local_dir / STATS_PATH) return cast_stats_to_numpy(stats) def write_task(task_index: int, task: dict, local_dir: Path): task_dict = { "task_index": task_index, "task": task, } append_jsonlines(task_dict, local_dir / TASKS_PATH) def load_tasks(local_dir: Path) -> tuple[dict, dict]: tasks = load_jsonlines(local_dir / TASKS_PATH) tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} task_to_task_index = {task: task_index for task_index, task in tasks.items()} return tasks, task_to_task_index def write_episode(episode: dict, local_dir: Path): append_jsonlines(episode, local_dir / EPISODES_PATH) def load_episodes(local_dir: Path) -> dict: episodes = load_jsonlines(local_dir / EPISODES_PATH) return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]` # is a dictionary of stats and not an integer. episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)} append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH) def load_episodes_stats(local_dir: Path) -> dict: episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH) return { item["episode_index"]: cast_stats_to_numpy(item["stats"]) for item in sorted(episodes_stats, key=lambda x: x["episode_index"]) } def backward_compatible_episodes_stats( stats: dict[str, dict[str, np.ndarray]], episodes: list[int] ) -> dict[str, dict[str, np.ndarray]]: return dict.fromkeys(episodes, stats) def load_image_as_numpy( fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True ) -> np.ndarray: img = PILImage.open(fpath).convert("RGB") img_array = np.array(img, dtype=dtype) if channel_first: # (H, W, C) -> (C, H, W) img_array = np.transpose(img_array, (2, 0, 1)) if np.issubdtype(dtype, np.floating): img_array /= 255.0 return img_array def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to a channel last representation (h w c) of uint8 type, to a torch image representation with channel first (c h w) of float32 type in range [0,1]. """ for key in items_dict: first_item = items_dict[key][0] if isinstance(first_item, PILImage.Image): to_tensor = transforms.ToTensor() items_dict[key] = [to_tensor(img) for img in items_dict[key]] elif first_item is None: pass else: items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] return items_dict def is_valid_version(version: str) -> bool: try: packaging.version.parse(version) return True except packaging.version.InvalidVersion: return False def check_version_compatibility( repo_id: str, version_to_check: str | packaging.version.Version, current_version: str | packaging.version.Version, enforce_breaking_major: bool = True, ) -> None: v_check = ( packaging.version.parse(version_to_check) if not isinstance(version_to_check, packaging.version.Version) else version_to_check ) v_current = ( packaging.version.parse(current_version) if not isinstance(current_version, packaging.version.Version) else current_version ) if v_check.major < v_current.major and enforce_breaking_major: raise BackwardCompatibilityError(repo_id, v_check) elif v_check.minor < v_current.minor: logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check)) def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: """Returns available valid versions (branches and tags) on given repo.""" api = HfApi() repo_refs = api.list_repo_refs(repo_id, repo_type="dataset") repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags] repo_versions = [] for ref in repo_refs: with contextlib.suppress(packaging.version.InvalidVersion): repo_versions.append(packaging.version.parse(ref)) return repo_versions def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str: """ Returns the version if available on repo or the latest compatible one. Otherwise, will throw a `CompatibilityError`. """ target_version = ( packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version ) hub_versions = get_repo_versions(repo_id) if not hub_versions: raise RevisionNotFoundError( f"""Your dataset must be tagged with a codebase version. Assuming _version_ is the codebase_version value in the info.json, you can run this: ```python from huggingface_hub import HfApi hub_api = HfApi() hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset") ``` """ ) if target_version in hub_versions: return f"v{target_version}" compatibles = [ v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor ] if compatibles: return_version = max(compatibles) if return_version < target_version: logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}") return f"v{return_version}" lower_major = [v for v in hub_versions if v.major < target_version.major] if lower_major: raise BackwardCompatibilityError(repo_id, max(lower_major)) upper_versions = [v for v in hub_versions if v > target_version] assert len(upper_versions) > 0 raise ForwardCompatibilityError(repo_id, min(upper_versions)) def get_hf_features_from_features(features: dict) -> datasets.Features: hf_features = {} for key, ft in features.items(): if ft["dtype"] == "video": continue elif ft["dtype"] == "image": hf_features[key] = datasets.Image() elif ft["shape"] == (1,): hf_features[key] = datasets.Value(dtype=ft["dtype"]) elif len(ft["shape"]) == 1: hf_features[key] = datasets.Sequence( length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) ) elif len(ft["shape"]) == 2: hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) elif len(ft["shape"]) == 3: hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) elif len(ft["shape"]) == 4: hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) elif len(ft["shape"]) == 5: hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) else: raise ValueError(f"Corresponding feature is not valid: {ft}") return datasets.Features(hf_features) def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: camera_ft = {} if robot.cameras: camera_ft = { key: {"dtype": "video" if use_videos else "image", **ft} for key, ft in robot.camera_features.items() } return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: # TODO(aliberts): Implement "type" in dataset features and simplify this policy_features = {} for key, ft in features.items(): shape = ft["shape"] if ft["dtype"] in ["image", "video"]: type = FeatureType.VISUAL if len(shape) != 3: raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") names = ft["names"] # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) shape = (shape[2], shape[0], shape[1]) elif key == "observation.environment_state": type = FeatureType.ENV elif key.startswith("observation"): type = FeatureType.STATE elif key == "action": type = FeatureType.ACTION else: continue policy_features[key] = PolicyFeature( type=type, shape=shape, ) return policy_features def create_empty_dataset_info( codebase_version: str, fps: int, robot_type: str, features: dict, use_videos: bool, ) -> dict: return { "codebase_version": codebase_version, "robot_type": robot_type, "total_episodes": 0, "total_frames": 0, "total_tasks": 0, "total_videos": 0, "total_chunks": 0, "chunks_size": DEFAULT_CHUNK_SIZE, "fps": fps, "splits": {}, "data_path": DEFAULT_PARQUET_PATH, "video_path": DEFAULT_VIDEO_PATH if use_videos else None, "features": features, } def get_episode_data_index( episode_dicts: dict[dict], episodes: list[int] | None = None ) -> dict[str, torch.Tensor]: episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} if episodes is not None: episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} cumulative_lengths = list(accumulate(episode_lengths.values())) return { "from": torch.LongTensor([0] + cumulative_lengths[:-1]), "to": torch.LongTensor(cumulative_lengths), } def check_timestamps_sync( timestamps: np.ndarray, episode_indices: np.ndarray, episode_data_index: dict[str, np.ndarray], fps: int, tolerance_s: float, raise_value_error: bool = True, ) -> bool: """ This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance to account for possible numerical error. Args: timestamps (np.ndarray): Array of timestamps in seconds. episode_indices (np.ndarray): Array indicating the episode index for each timestamp. episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to', which identifies indices for the end of each episode. fps (int): Frames per second. Used to check the expected difference between consecutive timestamps. tolerance_s (float): Allowed deviation from the expected (1/fps) difference. raise_value_error (bool): Whether to raise a ValueError if the check fails. Returns: bool: True if all checked timestamp differences lie within tolerance, False otherwise. Raises: ValueError: If the check fails and `raise_value_error` is True. """ if timestamps.shape != episode_indices.shape: raise ValueError( "timestamps and episode_indices should have the same shape. " f"Found {timestamps.shape=} and {episode_indices.shape=}." ) # Consecutive differences diffs = np.diff(timestamps) within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s # Mask to ignore differences at the boundaries between episodes mask = np.ones(len(diffs), dtype=bool) ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode mask[ignored_diffs] = False filtered_within_tolerance = within_tolerance[mask] # Check if all remaining diffs are within tolerance if not np.all(filtered_within_tolerance): # Track original indices before masking original_indices = np.arange(len(diffs)) filtered_indices = original_indices[mask] outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0] outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] outside_tolerances = [] for idx in outside_tolerance_indices: entry = { "timestamps": [timestamps[idx], timestamps[idx + 1]], "diff": diffs[idx], "episode_index": episode_indices[idx].item() if hasattr(episode_indices[idx], "item") else episode_indices[idx], } outside_tolerances.append(entry) if raise_value_error: raise ValueError( f"""One or several timestamps unexpectedly violate the tolerance inside episode range. This might be due to synchronization issues during data collection. \n{pformat(outside_tolerances)}""" ) return False return True def check_delta_timestamps( delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True ) -> bool: """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance. This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be actual timestamps from the dataset. """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] if not all(within_tolerance): outside_tolerance[key] = [ ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within ] if len(outside_tolerance) > 0: if raise_value_error: raise ValueError( f""" The following delta_timestamps are found outside of tolerance range. Please make sure they are multiples of 1/{fps} +/- tolerance and adjust their values accordingly. \n{pformat(outside_tolerance)} """ ) return False return True def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: delta_indices = {} for key, delta_ts in delta_timestamps.items(): delta_indices[key] = [round(d * fps) for d in delta_ts] return delta_indices def cycle(iterable): """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. """ iterator = iter(iterable) while True: try: yield next(iterator) except StopIteration: iterator = iter(iterable) def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None: """Create a branch on a existing Hugging Face repo. Delete the branch if it already exists before creating it. """ api = HfApi() branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches refs = [branch.ref for branch in branches] ref = f"refs/heads/{branch}" if ref in refs: api.delete_branch(repo_id, repo_type=repo_type, branch=branch) api.create_branch(repo_id, repo_type=repo_type, branch=branch) def create_lerobot_dataset_card( tags: list | None = None, dataset_info: dict | None = None, **kwargs, ) -> DatasetCard: """ Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md. Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. """ card_tags = ["LeRobot"] if tags: card_tags += tags if dataset_info: dataset_structure = "[meta/info.json](meta/info.json):\n" dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n" kwargs = {**kwargs, "dataset_structure": dataset_structure} card_data = DatasetCardData( license=kwargs.get("license"), tags=card_tags, task_categories=["robotics"], configs=[ { "config_name": "default", "data_files": "data/*/*.parquet", } ], ) card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text() return DatasetCard.from_template( card_data=card_data, template_str=card_template, **kwargs, ) class IterableNamespace(SimpleNamespace): """ A namespace object that supports both dictionary-like iteration and dot notation access. Automatically converts nested dictionaries into IterableNamespaces. This class extends SimpleNamespace to provide: - Dictionary-style iteration over keys - Access to items via both dot notation (obj.key) and brackets (obj["key"]) - Dictionary-like methods: items(), keys(), values() - Recursive conversion of nested dictionaries Args: dictionary: Optional dictionary to initialize the namespace **kwargs: Additional keyword arguments passed to SimpleNamespace Examples: >>> data = {"name": "Alice", "details": {"age": 25}} >>> ns = IterableNamespace(data) >>> ns.name 'Alice' >>> ns.details.age 25 >>> list(ns.keys()) ['name', 'details'] >>> for key, value in ns.items(): ... print(f"{key}: {value}") name: Alice details: IterableNamespace(age=25) """ def __init__(self, dictionary: dict[str, Any] = None, **kwargs): super().__init__(**kwargs) if dictionary is not None: for key, value in dictionary.items(): if isinstance(value, dict): setattr(self, key, IterableNamespace(value)) else: setattr(self, key, value) def __iter__(self) -> Iterator[str]: return iter(vars(self)) def __getitem__(self, key: str) -> Any: return vars(self)[key] def items(self): return vars(self).items() def values(self): return vars(self).values() def keys(self): return vars(self).keys() def validate_frame(frame: dict, features: dict): optional_features = {"timestamp"} expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} actual_features = set(frame.keys()) error_message = validate_features_presence(actual_features, expected_features, optional_features) if "task" in frame: error_message += validate_feature_string("task", frame["task"]) common_features = actual_features & (expected_features | optional_features) for name in common_features - {"task"}: error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) if error_message: raise ValueError(error_message) def validate_features_presence( actual_features: set[str], expected_features: set[str], optional_features: set[str] ): error_message = "" missing_features = expected_features - actual_features extra_features = actual_features - (expected_features | optional_features) if missing_features or extra_features: error_message += "Feature mismatch in `frame` dictionary:\n" if missing_features: error_message += f"Missing features: {missing_features}\n" if extra_features: error_message += f"Extra features: {extra_features}\n" return error_message def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) elif expected_dtype in ["image", "video"]: return validate_feature_image_or_video(name, expected_shape, value) elif expected_dtype == "string": return validate_feature_string(name, value) else: raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") def validate_feature_numpy_array( name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray ): error_message = "" if isinstance(value, np.ndarray): actual_dtype = value.dtype actual_shape = value.shape if actual_dtype != np.dtype(expected_dtype): error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" if actual_shape != expected_shape: error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" else: error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" return error_message def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): actual_shape = value.shape c, h, w = expected_shape if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" elif isinstance(value, PILImage.Image): pass else: error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" return error_message def validate_feature_string(name: str, value: str): if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" return "" def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict): if "size" not in episode_buffer: raise ValueError("size key not found in episode_buffer") if "task" not in episode_buffer: raise ValueError("task key not found in episode_buffer") if episode_buffer["episode_index"] != total_episodes: # TODO(aliberts): Add option to use existing episode_index raise NotImplementedError( "You might have manually provided the episode_buffer with an episode_index that doesn't " "match the total number of episodes already in the dataset. This is not supported for now." ) if episode_buffer["size"] == 0: raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") buffer_keys = set(episode_buffer.keys()) - {"task", "size"} if not buffer_keys == set(features): raise ValueError( f"Features from `episode_buffer` don't match the ones in `features`." f"In episode_buffer not in features: {buffer_keys - set(features)}" f"In features not in episode_buffer: {set(features) - buffer_keys}" )