# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig

# Let cuDNN auto-tune convolution algorithms: input shapes are fixed for the
# whole benchmark, so autotuning pays off after the warmup iterations.
torch.backends.cudnn.benchmark = True
def main():
    """Benchmark the average per-iteration latency of PI0 policy action selection.

    Loads one episode of a LeRobot dataset, builds the pretrained PI0 policy,
    then times `select_action` + `reset` over repeated iterations using CUDA
    events, printing the average time in milliseconds.
    """
    device = "cuda"
    dataset_repo_id = "danaaubakirova/koch_test"
    # model_name = "pi0_base"
    # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
    ckpt_torch_dir = "lerobot/pi0"

    dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=1,
    )
    # A single sample is enough: the same batch is reused every iteration.
    batch = next(iter(dataloader))

    # Move every tensor in the batch to the GPU as float32.
    for k in batch:
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(device=device, dtype=torch.float32)

    cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
    cfg.pretrained_path = ckpt_torch_dir
    policy = make_policy(cfg, ds_meta=dataset.meta)
    # policy = torch.compile(policy, mode="reduce-overhead")

    warmup_iters = 10
    benchmark_iters = 30

    # Warmup: let cuDNN autotuning, lazy CUDA init, and caching allocators
    # settle so they don't skew the measured iterations.
    for _ in range(warmup_iters):
        torch.cuda.synchronize()
        policy.select_action(batch)
        # reset() clears the policy's internal action queue so every iteration
        # performs a full inference rather than popping a cached action.
        policy.reset()
        torch.cuda.synchronize()

    # Benchmark with CUDA events for accurate on-device timing.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(benchmark_iters):
        policy.select_action(batch)
        policy.reset()
    end_event.record()

    # Synchronize so end_event has actually completed before reading the time.
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_per_iter = elapsed_time_ms / benchmark_iters
    print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
if __name__ == "__main__":
    # inference_mode disables autograd tracking/version counters for the whole
    # run — strictly faster than no_grad for pure-inference benchmarking.
    with torch.inference_mode():
        main()