|
"""Image dataset.""" |
|
|
|
import os
import pickle
import traceback
import warnings

import numpy as np
from torch.utils.data import Dataset

from cliport import tasks
from cliport.tasks import cameras
from cliport.utils import utils
|
|
|
|
|
PIXEL_SIZE = 0.003125 |
|
CAMERA_CONFIG = cameras.RealSenseD415.CONFIG |
|
BOUNDS = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]]) |
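# With these bounds and pixel size, the top-down workspace heightmap covers
# 1.0 m (y) x 0.5 m (x) at 0.003125 m/pixel, i.e. 320 x 160 pixels, which
# matches the (320, 160, 6) input shape used by the datasets below.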
|
|
|
|
|
TASK_NAMES = sorted(tasks.names.keys())[::-1]
|
|
|
|
|
class RavensDataset(Dataset): |
|
"""A simple image dataset class.""" |
|
|
|
    def __init__(self, path, cfg, n_demos=0, augment=False):
        """A single-task RGB-D image dataset.

        Args:
            path: dataset directory containing the per-field subfolders
                ('color', 'depth', 'action', 'reward', 'info').
            cfg: config dict (the 'dataset' and 'train' sections are read).
            n_demos: number of demo episodes to randomly sample into the
                training set (0 leaves the sample set empty).
            augment: whether to apply SE(2) data augmentation to samples.
        """
|
self._path = path |
|
|
|
self.cfg = cfg |
|
self.sample_set = [] |
|
self.max_seed = -1 |
|
self.n_episodes = 0 |
|
self.images = self.cfg['dataset']['images'] |
|
self.cache = self.cfg['dataset']['cache'] |
|
self.n_demos = n_demos |
|
self.augment = augment |
|
|
|
self.aug_theta_sigma = self.cfg['dataset']['augment']['theta_sigma'] if 'augment' in self.cfg['dataset'] else 60 |
|
        self.pix_size = PIXEL_SIZE
        self.in_shape = (320, 160, 6)
        self.cam_config = CAMERA_CONFIG
        self.bounds = BOUNDS
|
|
|
|
|
        # Count episodes on disk by scanning the 'action' pickles, and track
        # the largest seed seen so far.
        action_path = os.path.join(self._path, 'action')
        if os.path.exists(action_path):
            for fname in sorted(os.listdir(action_path)):
                if '.pkl' in fname:
                    seed = int(fname[(fname.find('-') + 1):-4])
                    self.n_episodes += 1
                    self.max_seed = max(self.max_seed, seed)
|
|
|
self._cache = {} |
|
|
|
        if self.n_demos > 0:
            if self.n_demos > self.n_episodes:
                warnings.warn(
                    f"Requested training on {self.n_demos} demos, but only "
                    f"{self.n_episodes} demos exist in the dataset path: {self._path}. "
                    f"Falling back to {self.n_episodes} demos.")
                self.n_demos = self.n_episodes
            episodes = np.random.choice(range(self.n_episodes), self.n_demos, False)
            self.set(episodes)
|
|
|
|
|
def add(self, seed, episode): |
|
"""Add an episode to the dataset. |
|
|
|
Args: |
|
seed: random seed used to initialize the episode. |
|
episode: list of (obs, act, reward, info) tuples. |
|
""" |
|
color, depth, action, reward, info = [], [], [], [], [] |
|
for obs, act, r, i in episode: |
|
color.append(obs['color']) |
|
depth.append(obs['depth']) |
|
action.append(act) |
|
reward.append(r) |
|
info.append(i) |
|
|
|
color = np.uint8(color) |
|
depth = np.float32(depth) |
|
|
|
def dump(data, field): |
|
field_path = os.path.join(self._path, field) |
|
if not os.path.exists(field_path): |
|
os.makedirs(field_path) |
|
fname = f'{self.n_episodes:06d}-{seed}.pkl' |
|
with open(os.path.join(field_path, fname), 'wb') as f: |
|
pickle.dump(data, f) |
|
|
|
dump(color, 'color') |
|
dump(depth, 'depth') |
|
dump(action, 'action') |
|
dump(reward, 'reward') |
|
dump(info, 'info') |
|
|
|
self.n_episodes += 1 |
|
self.max_seed = max(self.max_seed, seed) |
|
|
|
def set(self, episodes): |
|
"""Limit random samples to specific fixed set.""" |
|
self.sample_set = episodes |
|
|
|
    def load(self, episode_id, images=True, cache=False):
        """Load episode `episode_id` from disk; returns (episode, seed), or None if not found."""
|
|
|
def load_field(episode_id, field, fname): |
|
|
|
|
|
if cache: |
|
if episode_id in self._cache: |
|
if field in self._cache[episode_id]: |
|
return self._cache[episode_id][field] |
|
else: |
|
self._cache[episode_id] = {} |
|
|
|
|
|
path = os.path.join(self._path, field) |
|
data = pickle.load(open(os.path.join(path, fname), 'rb')) |
|
if cache: |
|
self._cache[episode_id][field] = data |
|
return data |
|
|
|
|
|
seed = None |
|
path = os.path.join(self._path, 'action') |
|
for fname in sorted(os.listdir(path)): |
|
if f'{episode_id:06d}' in fname: |
|
seed = int(fname[(fname.find('-') + 1):-4]) |
|
|
|
|
|
color = load_field(episode_id, 'color', fname) |
|
depth = load_field(episode_id, 'depth', fname) |
|
action = load_field(episode_id, 'action', fname) |
|
reward = load_field(episode_id, 'reward', fname) |
|
info = load_field(episode_id, 'info', fname) |
|
|
|
|
|
episode = [] |
|
for i in range(len(action)): |
|
obs = {'color': color[i], 'depth': depth[i]} if images else {} |
|
episode.append((obs, action[i], reward[i], info[i])) |
|
return episode, seed |
|
|
|
        # No matching episode file was found; implicitly return None.
        print(f'{episode_id:06d} not found in', path)
|
|
|
    def get_image(self, obs, cam_config=None):
        """Fuse multi-view RGB-D observations into a single 6-channel top-down image
        (RGB color heightmap plus the depth heightmap repeated three times)."""
        if cam_config is None:
            cam_config = self.cam_config
|
|
|
|
|
cmap, hmap = utils.get_fused_heightmap( |
|
obs, cam_config, self.bounds, self.pix_size) |
|
img = np.concatenate((cmap, |
|
hmap[Ellipsis, None], |
|
hmap[Ellipsis, None], |
|
hmap[Ellipsis, None]), axis=2) |
|
assert img.shape == self.in_shape, img.shape |
|
return img |
|
|
|
    def process_sample(self, datum, augment=True):
        """Turn a raw (obs, act, reward, info) tuple into a training sample dict."""
|
|
|
(obs, act, _, info) = datum |
|
img = self.get_image(obs) |
|
|
|
|
|
|
|
|
|
        # Placeholders used when `act` is None (e.g. the final observation of an episode).
        p0, p1 = np.zeros(1), np.zeros(1)
        p0_theta, p1_theta = np.zeros(1), np.zeros(1)
        perturb_params = np.zeros(5)
|
|
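        # Convert the pick ('pose0') and place ('pose1') poses from world
        # coordinates to (row, col) pixel locations on the heightmap and to
        # planar rotation angles. p1_theta is stored relative to p0_theta, and
        # p0_theta is then zeroed (the pick rotation is folded into the place).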
|
if act: |
|
p0_xyz, p0_xyzw = act['pose0'] |
|
p1_xyz, p1_xyzw = act['pose1'] |
|
p0 = utils.xyz_to_pix(p0_xyz, self.bounds, self.pix_size) |
|
p0_theta = -np.float32(utils.quatXYZW_to_eulerXYZ(p0_xyzw)[2]) |
|
p1 = utils.xyz_to_pix(p1_xyz, self.bounds, self.pix_size) |
|
p1_theta = -np.float32(utils.quatXYZW_to_eulerXYZ(p1_xyzw)[2]) |
|
p1_theta = p1_theta - p0_theta |
|
p0_theta = 0 |
|
|
|
|
|
if augment: |
|
img, _, (p0, p1), perturb_params = utils.perturb(img, [p0, p1], theta_sigma=self.aug_theta_sigma) |
|
|
|
if self.cfg['train']['data_augmentation']: |
|
|
|
|
|
|
|
            # Split the 6-channel input into RGB and (repeated) depth channels.
            color = img[..., :3]
            depth = img[..., 3:]
|
original_color = color.copy() |
|
original_depth = depth.copy() |
|
|
|
            # Photometric augmentation, each step applied with probability 0.9:
            # chromatic shift and pixel noise on RGB, additive noise on depth.
            from cliport.utils.dataaug import chromatic_transform, add_noise, add_noise_depth
            if np.random.rand(1) > 0.1:
                color = chromatic_transform(color.astype(np.uint8))
            if np.random.rand(1) > 0.1:
                color = add_noise(color)
            if np.random.rand(1) > 0.1:
                depth = add_noise_depth(depth)
|
|
|
|
|
|
            # Re-assemble the augmented color/depth channels into the sample image.
            color = color.astype(np.float32)
            img = np.concatenate((color, depth), axis=-1)
|
|
|
        sample = {
            'img': img.copy(),
            'p0': np.array(p0).copy(), 'p0_theta': np.array(p0_theta).copy(),
            'p1': np.array(p1).copy(), 'p1_theta': np.array(p1_theta).copy(),
            'perturb_params': np.array(perturb_params).copy()
        }
|
|
|
|
|
        if info and 'lang_goal' in info:
            sample['lang_goal'] = info['lang_goal']
        else:
            warnings.warn("No language goal. Defaulting to 'task completed.'")
            sample['lang_goal'] = "task completed."
|
|
|
return sample |
|
|
|
    def process_goal(self, goal, perturb_params):
        """Turn a raw goal tuple into a goal dict, re-applying the sample's perturbation."""
|
|
|
(obs, act, _, info) = goal |
|
img = self.get_image(obs) |
|
|
|
|
|
|
|
|
|
        # Goal images carry no pick/place labels; keep placeholder zeros.
        p0, p1 = np.zeros(1), np.zeros(1)
        p0_theta, p1_theta = np.zeros(1), np.zeros(1)
|
|
|
|
|
|
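        # Apply the same perturbation that was applied to the paired sample image
        # so the goal image stays geometrically aligned with it.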
|
if perturb_params is not None and len(perturb_params) > 1: |
|
img = utils.apply_perturbation(img, perturb_params) |
|
|
|
        sample = {
            'img': img.copy(),
            'p0': p0, 'p0_theta': np.array(p0_theta).copy(),
            'p1': p1, 'p1_theta': np.array(p1_theta).copy(),
            'perturb_params': np.array(perturb_params).copy()
        }
|
|
|
|
|
        if info and 'lang_goal' in info:
            sample['lang_goal'] = info['lang_goal']
        else:
            warnings.warn("No language goal. Defaulting to 'task completed.'")
            sample['lang_goal'] = "task completed."
|
|
|
return sample |
|
|
|
def __len__(self): |
|
return len(self.sample_set) |
|
|
|
def __getitem__(self, idx): |
|
|
|
|
|
|
|
|
|
|
|
episode_id = self.sample_set[idx] |
|
res = self.load(episode_id, self.images, self.cache) |
|
        if res is None:
            print("Failed to load episode", episode_id, "from", self._path)
            print("Exception:", str(traceback.format_exc()))
            # Fall back to the first sample rather than crashing the loader.
            return self[0]
|
|
|
episode, _ = res |
|
|
|
is_sequential_task = '-seq' in self._path.split("/")[-1] |
|
|
|
|
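        # Pick a random step i as the sample; for sequential ('-seq') tasks the
        # goal is the next step, otherwise it is the final step of the episode.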
|
i = np.random.choice(range(len(episode)-1)) |
|
g = i+1 if is_sequential_task else -1 |
|
sample, goal = episode[i], episode[g] |
|
|
|
|
|
sample = self.process_sample(sample, augment=self.augment) |
|
goal = self.process_goal(goal, perturb_params=sample['perturb_params']) |
|
return sample, goal |
|
|
|
|
|
class RavensMultiTaskDataset(RavensDataset): |
|
|
|
|
|
def __init__(self, path, cfg, group='multi-all', |
|
mode='train', n_demos=100, augment=False): |
|
"""A multi-task dataset.""" |
|
self.root_path = path |
|
self.mode = mode |
|
if group not in self.MULTI_TASKS: |
|
|
|
self.tasks = list(set(group)) |
|
else: |
|
self.tasks = self.MULTI_TASKS[group][mode] |
|
|
|
print("self.tasks:", self.tasks) |
|
        self.attr_train_task = (self.MULTI_TASKS[group].get('attr_train_task')
                                if group in self.MULTI_TASKS else None)
|
|
|
self.cfg = cfg |
|
self.sample_set = {} |
|
self.max_seed = -1 |
|
self.n_episodes = 0 |
|
self.images = self.cfg['dataset']['images'] |
|
self.cache = self.cfg['dataset']['cache'] |
|
self.n_demos = n_demos |
|
self.augment = augment |
|
|
|
self.aug_theta_sigma = self.cfg['dataset']['augment']['theta_sigma'] if 'augment' in self.cfg['dataset'] else 60 |
|
        self.pix_size = PIXEL_SIZE
        self.in_shape = (320, 160, 6)
        self.cam_config = CAMERA_CONFIG
        self.bounds = BOUNDS
|
|
|
self.n_episodes = {} |
|
episodes = {} |
|
|
|
for task in self.tasks: |
|
task_path = os.path.join(self.root_path, f'{task}-{mode}') |
|
action_path = os.path.join(task_path, 'action') |
|
n_episodes = 0 |
|
if os.path.exists(action_path): |
|
for fname in sorted(os.listdir(action_path)): |
|
if '.pkl' in fname: |
|
n_episodes += 1 |
|
self.n_episodes[task] = n_episodes |
|
|
|
if n_episodes == 0: |
|
raise Exception(f"{task}-{mode} has 0 episodes. Remove it from the list in dataset.py") |
|
|
|
|
|
            episodes[task] = np.random.choice(range(n_episodes), min(self.n_demos, n_episodes), False)
|
|
|
if self.n_demos > 0: |
|
self.images = self.cfg['dataset']['images'] |
|
self.cache = False |
|
self.set(episodes) |
|
|
|
self._path = None |
|
self._task = None |
|
|
|
    def __len__(self):
        # Total number of sampled episodes across all tasks.
        total_episodes = 0
        for _, episode_ids in self.sample_set.items():
            total_episodes += len(episode_ids)
        return total_episodes
|
|
|
def __getitem__(self, idx): |
|
|
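        # Cycle through tasks by index, then draw a random episode from that
        # task's sample set (or from all of its episodes if the set is empty).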
|
self._task = self.tasks[idx % len(self.tasks)] |
|
self._path = os.path.join(self.root_path, f'{self._task}') |
|
|
|
|
|
if len(self.sample_set[self._task]) > 0: |
|
episode_id = np.random.choice(self.sample_set[self._task]) |
|
else: |
|
episode_id = np.random.choice(range(self.n_episodes[self._task])) |
|
|
|
res = self.load(episode_id, self.images, self.cache) |
|
        if res is None:
            print("Failed to load episode", episode_id, "for task", self._task, "from", self._path)
            print("Exception:", str(traceback.format_exc()))
            # Retry with a different random index rather than crashing the loader.
            return self[np.random.randint(len(self))]
|
|
|
episode, _ = res |
|
|
|
|
|
is_sequential_task = '-seq' in self._path.split("/")[-1] |
|
|
|
|
|
if len(episode) > 1: |
|
i = np.random.choice(range(len(episode)-1)) |
|
g = i+1 if is_sequential_task else -1 |
|
sample, goal = episode[i], episode[g] |
|
else: |
|
sample, goal = episode[0], episode[0] |
|
|
|
|
|
sample = self.process_sample(sample, augment=self.augment) |
|
goal = self.process_goal(goal, perturb_params=sample['perturb_params']) |
|
|
|
return sample, goal |
|
|
|
    def add(self, seed, episode):
        raise NotImplementedError("Adding episodes is not supported by the multi-task dataset.")
|
|
|
    def load(self, episode_id, images=True, cache=False):
        """Load an episode from the directory of the currently selected task."""
|
self._path = os.path.join(self.root_path, f'{self._task}-{self.mode}') |
|
return super().load(episode_id, images, cache) |
|
|
|
def get_curr_task(self): |
|
return self._task |
|
|
|
|
|
MULTI_TASKS = { |
|
|
|
'multi-gpt-test': { |
|
'train': ['align-box-corner', 'rainbow-stack'], |
|
'val': ['align-box-corner', 'rainbow-stack'], |
|
'test': ['align-box-corner', 'rainbow-stack'] |
|
}, |
|
|
|
|
|
'multi-all': { |
|
'train': [ |
|
'align-box-corner', |
|
'assembling-kits', |
|
'block-insertion', |
|
'manipulating-rope', |
|
'packing-boxes', |
|
'palletizing-boxes', |
|
'place-red-in-green', |
|
'stack-block-pyramid', |
|
'sweeping-piles', |
|
'towers-of-hanoi', |
|
'align-rope', |
|
'assembling-kits-seq-unseen-colors', |
|
'packing-boxes-pairs-unseen-colors', |
|
'packing-shapes', |
|
'packing-unseen-google-objects-seq', |
|
'packing-unseen-google-objects-group', |
|
'put-block-in-bowl-unseen-colors', |
|
'stack-block-pyramid-seq-unseen-colors', |
|
'separating-piles-unseen-colors', |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
'val': [ |
|
'align-box-corner', |
|
'assembling-kits', |
|
'block-insertion', |
|
'manipulating-rope', |
|
'packing-boxes', |
|
'palletizing-boxes', |
|
'place-red-in-green', |
|
'stack-block-pyramid', |
|
'sweeping-piles', |
|
'towers-of-hanoi', |
|
'align-rope', |
|
'assembling-kits-seq-seen-colors', |
|
'assembling-kits-seq-unseen-colors', |
|
'packing-boxes-pairs-seen-colors', |
|
'packing-boxes-pairs-unseen-colors', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-unseen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'packing-unseen-google-objects-group', |
|
'put-block-in-bowl-seen-colors', |
|
'put-block-in-bowl-unseen-colors', |
|
'stack-block-pyramid-seq-seen-colors', |
|
'stack-block-pyramid-seq-unseen-colors', |
|
'separating-piles-seen-colors', |
|
'separating-piles-unseen-colors', |
|
'towers-of-hanoi-seq-seen-colors', |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
'test': [ |
|
'align-box-corner', |
|
'assembling-kits', |
|
'block-insertion', |
|
'manipulating-rope', |
|
'packing-boxes', |
|
'palletizing-boxes', |
|
'place-red-in-green', |
|
'stack-block-pyramid', |
|
'sweeping-piles', |
|
'towers-of-hanoi', |
|
'align-rope', |
|
'assembling-kits-seq-seen-colors', |
|
'assembling-kits-seq-unseen-colors', |
|
'packing-boxes-pairs-seen-colors', |
|
'packing-boxes-pairs-unseen-colors', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-unseen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'packing-unseen-google-objects-group', |
|
'put-block-in-bowl-seen-colors', |
|
'put-block-in-bowl-unseen-colors', |
|
'stack-block-pyramid-seq-seen-colors', |
|
'stack-block-pyramid-seq-unseen-colors', |
|
'separating-piles-seen-colors', |
|
'separating-piles-unseen-colors', |
|
'towers-of-hanoi-seq-seen-colors', |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
}, |
|
|
|
|
|
'multi-demo-conditioned': { |
|
'train': [ |
|
'align-box-corner', |
|
'assembling-kits', |
|
'block-insertion', |
|
'manipulating-rope', |
|
'packing-boxes', |
|
'palletizing-boxes', |
|
'place-red-in-green', |
|
'stack-block-pyramid', |
|
'sweeping-piles', |
|
'towers-of-hanoi', |
|
], |
|
'val': [ |
|
'align-box-corner', |
|
'assembling-kits', |
|
'block-insertion', |
|
'manipulating-rope', |
|
'packing-boxes', |
|
'palletizing-boxes', |
|
'place-red-in-green', |
|
'stack-block-pyramid', |
|
'sweeping-piles', |
|
'towers-of-hanoi', |
|
], |
|
'test': [ |
|
'align-box-corner', |
|
'assembling-kits', |
|
'block-insertion', |
|
'manipulating-rope', |
|
'packing-boxes', |
|
'palletizing-boxes', |
|
'place-red-in-green', |
|
'stack-block-pyramid', |
|
'sweeping-piles', |
|
'towers-of-hanoi', |
|
], |
|
}, |
|
|
|
|
|
'multi-language-conditioned': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-unseen-colors', |
|
'packing-boxes-pairs-unseen-colors', |
|
'packing-shapes', |
|
'packing-unseen-google-objects-seq', |
|
'packing-unseen-google-objects-group', |
|
'put-block-in-bowl-unseen-colors', |
|
'stack-block-pyramid-seq-unseen-colors', |
|
'separating-piles-unseen-colors', |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
'val': [ |
|
'align-rope', |
|
'assembling-kits-seq-seen-colors', |
|
'assembling-kits-seq-unseen-colors', |
|
'packing-boxes-pairs-seen-colors', |
|
'packing-boxes-pairs-unseen-colors', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-unseen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'packing-unseen-google-objects-group', |
|
'put-block-in-bowl-seen-colors', |
|
'put-block-in-bowl-unseen-colors', |
|
'stack-block-pyramid-seq-seen-colors', |
|
'stack-block-pyramid-seq-unseen-colors', |
|
'separating-piles-seen-colors', |
|
'separating-piles-unseen-colors', |
|
'towers-of-hanoi-seq-seen-colors', |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
'test': [ |
|
'align-rope', |
|
'assembling-kits-seq-seen-colors', |
|
'assembling-kits-seq-unseen-colors', |
|
'packing-boxes-pairs-seen-colors', |
|
'packing-boxes-pairs-unseen-colors', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-unseen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'packing-unseen-google-objects-group', |
|
'put-block-in-bowl-seen-colors', |
|
'put-block-in-bowl-unseen-colors', |
|
'stack-block-pyramid-seq-seen-colors', |
|
'stack-block-pyramid-seq-unseen-colors', |
|
'separating-piles-seen-colors', |
|
'separating-piles-unseen-colors', |
|
'towers-of-hanoi-seq-seen-colors', |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
}, |
|
|
|
|
|
|
|
'multi-attr-align-rope': { |
|
'train': [ |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'align-rope', |
|
], |
|
'test': [ |
|
'align-rope', |
|
], |
|
'attr_train_task': None, |
|
}, |
|
|
|
'multi-attr-packing-shapes': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'packing-shapes', |
|
], |
|
'test': [ |
|
'packing-shapes', |
|
], |
|
'attr_train_task': None, |
|
}, |
|
|
|
'multi-attr-assembling-kits-seq-unseen-colors': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-seen-colors', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'assembling-kits-seq-unseen-colors', |
|
], |
|
'test': [ |
|
'assembling-kits-seq-unseen-colors', |
|
], |
|
'attr_train_task': 'assembling-kits-seq-seen-colors', |
|
}, |
|
|
|
'multi-attr-packing-boxes-pairs-unseen-colors': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-seen-colors', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'packing-boxes-pairs-unseen-colors', |
|
], |
|
'test': [ |
|
'packing-boxes-pairs-unseen-colors', |
|
], |
|
'attr_train_task': 'packing-boxes-pairs-seen-colors', |
|
}, |
|
|
|
'multi-attr-packing-unseen-google-objects-seq': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'packing-unseen-google-objects-seq', |
|
], |
|
'test': [ |
|
'packing-unseen-google-objects-seq', |
|
], |
|
'attr_train_task': 'packing-seen-google-objects-group', |
|
}, |
|
|
|
'multi-attr-packing-unseen-google-objects-group': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'packing-unseen-google-objects-group', |
|
], |
|
'test': [ |
|
'packing-unseen-google-objects-group', |
|
], |
|
'attr_train_task': 'packing-seen-google-objects-seq', |
|
}, |
|
|
|
'multi-attr-put-block-in-bowl-unseen-colors': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-seen-colors', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'put-block-in-bowl-unseen-colors', |
|
], |
|
'test': [ |
|
'put-block-in-bowl-unseen-colors', |
|
], |
|
'attr_train_task': 'put-block-in-bowl-seen-colors', |
|
}, |
|
|
|
'multi-attr-stack-block-pyramid-seq-unseen-colors': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-seen-colors', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'stack-block-pyramid-seq-unseen-colors', |
|
], |
|
'test': [ |
|
'stack-block-pyramid-seq-unseen-colors', |
|
], |
|
'attr_train_task': 'stack-block-pyramid-seq-seen-colors', |
|
}, |
|
|
|
'multi-attr-separating-piles-unseen-colors': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-seen-colors', |
|
'towers-of-hanoi-seq-full', |
|
], |
|
'val': [ |
|
'separating-piles-unseen-colors', |
|
], |
|
'test': [ |
|
'separating-piles-unseen-colors', |
|
], |
|
'attr_train_task': 'separating-piles-seen-colors', |
|
}, |
|
|
|
'multi-attr-towers-of-hanoi-seq-unseen-colors': { |
|
'train': [ |
|
'align-rope', |
|
'assembling-kits-seq-full', |
|
'packing-boxes-pairs-full', |
|
'packing-shapes', |
|
'packing-seen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq-seen-colors', |
|
], |
|
'val': [ |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
'test': [ |
|
'towers-of-hanoi-seq-unseen-colors', |
|
], |
|
'attr_train_task': 'towers-of-hanoi-seq-seen-colors', |
|
}, |
|
|
|
} |
|
|
|
|
|
|
|
class RavenMultiTaskDatasetBalance(RavensMultiTaskDataset): |
|
def __init__(self, path, cfg, group='multi-all', |
|
mode='train', n_demos=100, augment=False, balance_weight=0.1): |
|
"""A multi-task dataset for balancing data.""" |
|
self.root_path = path |
|
self.mode = mode |
|
if group not in self.MULTI_TASKS: |
|
|
|
self.tasks = group |
|
else: |
|
self.tasks = self.MULTI_TASKS[group][mode] |
|
|
|
print("self.tasks:", self.tasks) |
|
        self.attr_train_task = (self.MULTI_TASKS[group].get('attr_train_task')
                                if group in self.MULTI_TASKS else None)
|
|
|
self.cfg = cfg |
|
self.sample_set = {} |
|
self.max_seed = -1 |
|
self.n_episodes = 0 |
|
self.images = self.cfg['dataset']['images'] |
|
self.cache = self.cfg['dataset']['cache'] |
|
self.n_demos = n_demos |
|
self.augment = augment |
|
|
|
self.aug_theta_sigma = self.cfg['dataset']['augment']['theta_sigma'] if 'augment' in self.cfg['dataset'] else 60 |
|
        self.pix_size = PIXEL_SIZE
        self.in_shape = (320, 160, 6)
        self.cam_config = CAMERA_CONFIG
        self.bounds = BOUNDS
|
|
|
self.n_episodes = {} |
|
episodes = {} |
|
|
|
for task in self.tasks: |
|
task_path = os.path.join(self.root_path, f'{task}-{mode}') |
|
action_path = os.path.join(task_path, 'action') |
|
n_episodes = 0 |
|
if os.path.exists(action_path): |
|
for fname in sorted(os.listdir(action_path)): |
|
if '.pkl' in fname: |
|
n_episodes += 1 |
|
self.n_episodes[task] = n_episodes |
|
|
|
if n_episodes == 0: |
|
raise Exception(f"{task}-{mode} has 0 episodes. Remove it from the list in dataset.py") |
|
|
|
|
|
if task in self.ORIGINAL_NAMES and self.mode == 'train': |
|
assert self.n_demos < 200 |
|
episodes[task] = np.random.choice(range(n_episodes), min(int(self.n_demos*balance_weight), n_episodes), False) |
|
else: |
|
episodes[task] = np.random.choice(range(n_episodes), min(self.n_demos, n_episodes), False) |
|
|
|
if self.n_demos > 0: |
|
self.images = self.cfg['dataset']['images'] |
|
self.cache = False |
|
self.set(episodes) |
|
|
|
self._path = None |
|
self._task = None |
|
|
|
|
|
|
|
ORIGINAL_NAMES = [ |
|
|
|
'align-box-corner', |
|
'assembling-kits', |
|
'assembling-kits-easy', |
|
'block-insertion', |
|
'block-insertion-easy', |
|
'block-insertion-nofixture', |
|
'block-insertion-sixdof', |
|
'block-insertion-translation', |
|
'manipulating-rope', |
|
'packing-boxes', |
|
'palletizing-boxes', |
|
'place-red-in-green', |
|
'stack-block-pyramid', |
|
'sweeping-piles', |
|
'towers-of-hanoi', |
|
'gen-task', |
|
|
|
'align-rope', |
|
'assembling-kits-seq', |
|
'assembling-kits-seq-seen-colors', |
|
'assembling-kits-seq-unseen-colors', |
|
'assembling-kits-seq-full', |
|
'packing-shapes', |
|
'packing-boxes-pairs', |
|
'packing-boxes-pairs-seen-colors', |
|
'packing-boxes-pairs-unseen-colors', |
|
'packing-boxes-pairs-full', |
|
'packing-seen-google-objects-seq', |
|
'packing-unseen-google-objects-seq', |
|
'packing-seen-google-objects-group', |
|
'packing-unseen-google-objects-group', |
|
'put-block-in-bowl', |
|
'put-block-in-bowl-seen-colors', |
|
'put-block-in-bowl-unseen-colors', |
|
'put-block-in-bowl-full', |
|
'stack-block-pyramid-seq', |
|
'stack-block-pyramid-seq-seen-colors', |
|
'stack-block-pyramid-seq-unseen-colors', |
|
'stack-block-pyramid-seq-full', |
|
'separating-piles', |
|
'separating-piles-seen-colors', |
|
'separating-piles-unseen-colors', |
|
'separating-piles-full', |
|
'towers-of-hanoi-seq', |
|
'towers-of-hanoi-seq-seen-colors', |
|
'towers-of-hanoi-seq-unseen-colors', |
|
'towers-of-hanoi-seq-full', |
|
] |
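

# A minimal, hypothetical usage sketch (kept out of library imports by the
# __main__ guard). It assumes demos were already generated under the path
# below and that `cfg` carries the keys this module reads: dataset.images,
# dataset.cache, dataset.augment.theta_sigma, and train.data_augmentation.
if __name__ == '__main__':
    cfg = {
        'dataset': {'images': True, 'cache': False, 'augment': {'theta_sigma': 60}},
        'train': {'data_augmentation': False},
    }
    # Hypothetical dataset directory previously written by RavensDataset.add().
    data_dir = 'data/stack-block-pyramid-seq-seen-colors-train'
    dataset = RavensDataset(data_dir, cfg, n_demos=10, augment=True)
    # Each item is a (sample, goal) pair of dicts with image, pick/place labels,
    # perturbation parameters, and the language goal.
    sample, goal = dataset[0]
    print(sample['lang_goal'], sample['img'].shape, sample['p0'], sample['p0_theta'])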