# Listing metadata carried over from the file viewer: file size 2,216 bytes; id 1cc747d.
# NOTE: os, random, pybullet, primitives, Spatula and utils are not referenced
# by the code visible in this module; they are kept in case other tooling
# relies on them being imported here.
import os
import random

import numpy as np
import pybullet as p

from cliport.tasks import primitives
from cliport.tasks.grippers import Spatula
from cliport.tasks.task import Task
from cliport.utils import utils


class InsertAndStack(Task):
    """Insert the ell into the fixture and then stack blocks on top of it."""

    def __init__(self):
        """Set the step budget, metric name and language strings for the task."""
        super().__init__()
        self.max_steps = 10
        self.metric = 'pose'
        self.lang_template = "insert the ell into the fixture and then stack blocks on top of it"
        self.task_completed_desc = "done insert-and-stack."

    def reset(self, env):
        """Populate ``env`` with the task objects and register a single goal.

        Objects are added in this order: one ell, one fixture, three blocks,
        one fixed zone. Each pose comes from ``self.get_random_pose``, so the
        order of calls below determines which random draws each object gets.

        Args:
            env: Environment object; this method calls ``env.add_object`` on it
                and passes it to ``self.get_random_pose``.
        """
        super().reset(env)

        # NOTE(review): `color_random_bright` is presumed to be provided by the
        # Task base class; it is not defined in this file -- confirm.

        # Add ell.
        ell_size = (0.1, 0.1, 0.1)
        ell_pose = self.get_random_pose(env, ell_size)
        ell_id = env.add_object('insertion/ell.urdf', ell_pose)
        self.color_random_bright(ell_id)

        # Add fixture.
        fixture_size = (0.12, 0.12, 0.1)
        fixture_pose = self.get_random_pose(env, fixture_size)
        fixture_id = env.add_object('insertion/fixture.urdf', fixture_pose)
        self.color_random_bright(fixture_id)

        # Add blocks, each at its own random pose and with its own colour.
        n_blocks = 3
        block_size = (0.04, 0.04, 0.04)
        block_urdf = 'stacking/block.urdf'
        blocks = []
        for _ in range(n_blocks):
            block_pose = self.get_random_pose(env, block_size)
            block_id = env.add_object(block_urdf, block_pose)
            self.color_random_bright(block_id)
            blocks.append(block_id)

        # Add the zone as a fixed (static) object.
        # NOTE(review): the zone is added to the scene but is not referenced
        # by the goal registered below.
        zone_size = (0.1, 0.1, 0.1)
        zone_pose = self.get_random_pose(env, zone_size)
        env.add_object('zone/zone.urdf', zone_pose, 'fixed')

        # Register one goal covering the ell and all blocks; every object is
        # given the fixture pose as its target pose.
        # NOTE(review): `matches` has a single column while `targ_poses` has
        # len(objs) entries -- confirm this is what Task.add_goal expects.
        objs = [ell_id] + blocks
        n_objs = len(objs)
        goal_poses = [fixture_pose] * n_objs
        self.add_goal(
            objs=objs,
            matches=np.ones((n_objs, 1)),
            targ_poses=goal_poses,
            replace=False,
            rotations=True,
            metric='pose',
            params=None,
            step_max_reward=1,
            language_goal=self.lang_template,
        )