hma / datasets /extern /frodobot.py
LeroyWaa's picture
draft
246c106
raw
history blame
5.23 kB
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import random
import os
import time
import sys
import numpy as np
import IPython
import torch
from tqdm import tqdm
from collections import OrderedDict
import os
import PIL.Image
import numpy as np
from typing import Union, List
from pathlib import Path
import re
CURRENT_DIR = os.path.dirname(__file__)
import cv2
from os.path import expanduser
import pickle
import cv2
from matplotlib import pyplot as plt
import pandas as pd
import json
RESOLUTION = (480, 480)
DATA = "/home/liruiw/Projects/frodobot/"
# https://colab.research.google.com/#scrollTo=50ce529a-a20a-4852-9a5a-114b52b98f2e&fileId=https%3A//huggingface.co/datasets/frodobots/FrodoBots-2K/blob/main/helpercode.ipynb
# #### control data
import pandas as pd
# print(f"{dataset_dir}/control_data_{ride_id}.csv")
def convert_img_dataset(
dataset_dir="/home/liruiw/Projects/frodobot/output_rides_22",
env_names=None,
gui=False,
episode_num_pertask=2000,
**kwargs,
):
# convert to a list of episodes that can be added to replay buffer
for eps_file in os.listdir(dataset_dir)[:50]: # 50 trajectories
dataset_dir_ = os.path.join(dataset_dir, eps_file)
if os.path.isdir(dataset_dir_):
ride_id = dataset_dir_.split("_")[-2]
print(dataset_dir_)
##### control data
control = pd.read_csv(f"{dataset_dir_}/control_data_{ride_id}.csv")
control_data_dict = control.set_index("timestamp").T.to_dict("list")
control_sorted_keys = sorted(list(control_data_dict.keys()))
##### IMU data
gyro_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "gyroscope"]]
gyro_data_dict = gyro_data.set_index("timestamp").T.to_dict("list")
gyro_sorted_keys = sorted(list(gyro_data_dict.keys()))
compass_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "compass"]]
compass_data_dict = compass_data.set_index("timestamp").T.to_dict("list")
compass_sorted_keys = sorted(list(compass_data_dict.keys()))
accel_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "accelerometer"]]
accel_data_dict = accel_data.set_index("timestamp").T.to_dict("list")
accel_sorted_keys = sorted(list(accel_data_dict.keys()))
##### Camera data
camera_data = pd.read_csv(f"{dataset_dir_}/front_camera_timestamps_{ride_id}.csv")
camera_data_dict = camera_data.set_index("timestamp").T.to_dict("list")
camera_sorted_keys = sorted(list(camera_data_dict.keys()))
images = sorted(os.listdir(f"{dataset_dir_}/front_camera/"))
# #### front camera video
# front_camera = f"{dataset_dir}/recordings/0f0e8539d249f38e3ae7b18660f5af8c_ride_39572__uid_s_1000__uid_e_video_20240502221408754.ts"
languages = "drive around to play" # dummy
steps = []
SUBSAMPLE_IDX = 5
for idx, control_t in enumerate(control_sorted_keys):
# enumerate along actions and only pick matched timesteps
action = control_data_dict[control_t]
camera_t = camera_sorted_keys[np.argmin(np.array(camera_sorted_keys) - control_t)]
camera_path = images[camera_data_dict[camera_t][0]]
img = cv2.resize(cv2.imread(f"{dataset_dir_}/front_camera/{camera_path}"), None, fx=0.5, fy=0.5)
gyro = gyro_data_dict[gyro_sorted_keys[np.argmin(np.array(gyro_sorted_keys) - control_t)]]
first_three_strings = eval(gyro[0])[0][:3]
gyro_array = np.array(first_three_strings, dtype=float)
compass = compass_data_dict[compass_sorted_keys[np.argmin(np.array(compass_sorted_keys) - control_t)]]
first_three_strings = eval(compass[0])[0][:3]
compass_array = np.array(first_three_strings, dtype=float)
accel = accel_data_dict[accel_sorted_keys[np.argmin(np.array(accel_sorted_keys) - control_t)]]
first_three_strings = eval(accel[0])[0][:3]
accel_array = np.array(first_three_strings, dtype=float)
prop = np.concatenate((gyro_array, compass_array, accel_array))
step = {
"observation": {"state": prop, "image": img},
"action": action,
"language_instruction": languages,
}
steps.append(OrderedDict(step))
data_dict = {"steps": steps}
yield data_dict
class RolloutRunner:
"""evaluate policy rollouts"""
def __init__(self, env_names, episode_num, save_video=False):
self.env_names = env_names
self.episode_num = episode_num
self.envs = []
self.scene_files = []
self.save_video = save_video
@torch.no_grad()
def run(self, policy, save_video=False, gui=False, video_postfix="", seed=233, env_name=None, **kwargs):
pass