#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from torch import nn

from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType


def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "diffusion":
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy
elif name == "act":
from lerobot.common.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0":
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi0fast":
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")


def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
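    """Build the config dataclass for the given policy type, forwarding `kwargs` to its constructor.

    Example (a minimal sketch; assumes `ACTConfig` is constructible from its defaults):
        >>> cfg = make_policy_config("act")
        >>> type(cfg).__name__
        'ACTConfig'
    """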
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")


def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
"""Make an instance of a policy class.

    This function exists because (for now) we need to parse features from either a dataset or an environment
in order to properly dimension and instantiate a policy for that dataset or environment.

    Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
provided if ds_meta is not. Defaults to None.

    Raises:
        ValueError: If both or neither of `ds_meta` and `env_cfg` are provided.
        NotImplementedError: If `cfg.type` is "vqbet" and `cfg.device` is "mps" (due to an
            incompatibility between VQBeT and the MPS backend).

    Returns:
        PreTrainedPolicy: The instantiated policy, moved to `cfg.device`.
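
    Example (a minimal sketch; `lerobot/pusht` is just an illustrative dataset repo id, and
    `device` is assumed here to be a valid config field):
        >>> from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
        >>> ds_meta = LeRobotDatasetMetadata("lerobot/pusht")
        >>> cfg = make_policy_config("diffusion", device="cpu")
        >>> policy = make_policy(cfg, ds_meta=ds_meta)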
"""
    if bool(ds_meta) == bool(env_cfg):
        raise ValueError("Exactly one of a dataset metadata or a sim env must be provided.")

# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
# you want this op to be added in priority during the prototype phase of this feature, please comment on
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)

    policy_cls = get_policy_class(cfg.type)

    kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
"You are instantiating a policy from scratch and its features are parsed from an environment "
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
features = env_to_policy_features(env_cfg)

    cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg

    if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
policy = policy_cls.from_pretrained(**kwargs)
else:
# Make a fresh policy.
policy = policy_cls(**kwargs)

    policy.to(cfg.device)
    assert isinstance(policy, nn.Module)

    # policy = torch.compile(policy, mode="reduce-overhead")

    return policy