GenSim3 / cliport /tests /tasks_test.py
gensim2's picture
unlfs
1cc747d
"""Integration tests for dvnets tasks."""
from absl.testing import absltest
from absl.testing import parameterized
from cliport import tasks
from cliport.environments import environment
ASSETS_PATH = 'cliport/environments/assets/'
class TaskTest(parameterized.TestCase):
def _create_env(self):
assets_root = ASSETS_PATH
env = environment.Environment(assets_root)
env.seed(0)
return env
def _run_oracle_in_env(self, env):
agent = env.task.oracle(env)
obs = env.reset()
info = None
done = False
for _ in range(10):
act = agent.act(obs, info)
obs, _, done, info = env.step(act)
if done:
break
@parameterized.named_parameters((
# demo conditioned
'AlignBoxCorner',
tasks.AlignBoxCorner(),
), (
'AssemblingKits',
tasks.AssemblingKits(),
), (
'AssemblingKitsEasy',
tasks.AssemblingKitsEasy(),
), (
'BlockInsertion',
tasks.BlockInsertion(),
), (
'ManipulatingRope',
tasks.ManipulatingRope(),
), (
'PackingBoxes',
tasks.PackingBoxes(),
), (
'PalletizingBoxes',
tasks.PalletizingBoxes(),
), (
'PlaceRedInGreen',
tasks.PlaceRedInGreen(),
), (
'StackBlockPyramid',
tasks.StackBlockPyramid(),
), (
'SweepingPiles',
tasks.SweepingPiles(),
), (
'TowersOfHanoi',
tasks.TowersOfHanoi(),
# goal conditioned
), (
'AlignRope',
tasks.AlignRope(),
), (
'AssemblingKitsSeqSeenColors',
tasks.AssemblingKitsSeqSeenColors(),
), (
'AssemblingKitsSeqUnseenColors',
tasks.AssemblingKitsSeqUnseenColors(),
), (
'AssemblingKitsSeqFull',
tasks.AssemblingKitsSeqFull(),
), (
'PackingShapes',
tasks.PackingShapes(),
), (
'PackingBoxesPairsSeenColors',
tasks.PackingBoxesPairsSeenColors(),
), (
'PackingBoxesPairsUnseenColors',
tasks.PackingBoxesPairsUnseenColors(),
), (
'PackingBoxesPairsFull',
tasks.PackingBoxesPairsFull(),
), (
'PackingSeenGoogleObjectsSeq',
tasks.PackingSeenGoogleObjectsSeq(),
), (
'PackingUnseenGoogleObjectsSeq',
tasks.PackingUnseenGoogleObjectsSeq(),
), (
'PackingSeenGoogleObjectsGroup',
tasks.PackingSeenGoogleObjectsGroup(),
), (
'PackingUnseenGoogleObjectsGroup',
tasks.PackingUnseenGoogleObjectsGroup(),
), (
'PutBlockInBowlSeenColors',
tasks.PutBlockInBowlSeenColors(),
), (
'PutBlockInBowlUnseenColors',
tasks.PutBlockInBowlUnseenColors(),
), (
'PutBlockInBowlFull',
tasks.PutBlockInBowlFull(),
), (
'StackBlockPyramidSeqSeenColors',
tasks.StackBlockPyramidSeqSeenColors(),
), (
'StackBlockPyramidSeqUnseenColors',
tasks.StackBlockPyramidSeqUnseenColors(),
), (
'StackBlockPyramidSeqFull',
tasks.StackBlockPyramidSeqFull(),
), (
'SeparatingPilesSeenColors',
tasks.SeparatingPilesUnseenColors(),
), (
'SeparatingPilesUnseenColors',
tasks.SeparatingPilesUnseenColors(),
), (
'SeparatingPilesFull',
tasks.SeparatingPilesFull(),
), (
'TowersOfHanoiSeqSeenColors',
tasks.TowersOfHanoiSeqSeenColors(),
), (
'TowersOfHanoiSeqUnseenColors',
tasks.TowersOfHanoiSeqUnseenColors(),
), (
'TowersOfHanoiSeqFull',
tasks.TowersOfHanoiSeqFull(),
))
def test_all_tasks(self, dvnets_task):
env = self._create_env()
env.set_task(dvnets_task)
self._run_oracle_in_env(env)
if __name__ == '__main__':
absltest.main()