import numpy as np
import os
import IPython
from cliport import tasks
from cliport.dataset import RavensDataset
from cliport.environments.environment import Environment

from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter

import time
import random
import json
import traceback
from gensim.utils import (
    mkdir_if_missing,
    save_text,
    save_stat,
    compute_diversity_score_from_assets,
    add_to_txt,
)
import pybullet as p

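# SimulationRunner drives one GenSim design-and-verify loop: an agent proposes a
# task, its assets, and task code, a critic reviews the proposal, and the runner
# then execs and simulates the generated code to measure syntax / runtime / env
# pass rates. `agent`, `critic`, and `memory` are assumed to be the corresponding
# GenSim components (LLM-backed proposal and review agents plus a chat/run memory).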
class SimulationRunner:
    """The main class that runs the simulation loop."""

    def __init__(self, cfg, agent, critic, memory):
        self.cfg = cfg
        self.agent = agent
        self.critic = critic
        self.memory = memory

        self.syntax_pass_rate = 0
        self.runtime_pass_rate = 0
        self.env_pass_rate = 0
        self.curr_trials = 0

        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.chat_log = memory.chat_log
        self.task_asset_logs = []

        self.generated_task_assets = []
        self.generated_task_programs = []
        self.generated_task_names = []
        self.generated_tasks = []
        self.passed_tasks = []

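    # The *_pass_rate fields above are raw success counters; they are divided by
    # curr_trials when printed or saved, so they are only meaningful after at
    # least one call to simulate_task().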
    def print_current_stats(self):
        """Print the current statistics of the simulation design."""
        print("=========================================================")
        print(f"{self.cfg['prompt_folder']} Trial {self.curr_trials} "
              f"SYNTAX_PASS_RATE: {(self.syntax_pass_rate / self.curr_trials) * 100:.1f}% "
              f"RUNTIME_PASS_RATE: {(self.runtime_pass_rate / self.curr_trials) * 100:.1f}% "
              f"ENV_PASS_RATE: {(self.env_pass_rate / self.curr_trials) * 100:.1f}%")
        print("=========================================================")

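    # save_stats normalizes the counters by curr_trials and additionally computes
    # a diversity score over the asset lists logged for each generated task.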
    def save_stats(self):
        """Save the final simulation statistics."""
        self.diversity_score = compute_diversity_score_from_assets(self.task_asset_logs, self.curr_trials)
        save_stat(self.cfg, self.cfg['model_output_dir'], self.generated_tasks,
                  self.syntax_pass_rate / self.curr_trials,
                  self.runtime_pass_rate / self.curr_trials,
                  self.env_pass_rate / self.curr_trials,
                  self.diversity_score)
        print("Model Folder: ", self.cfg['model_output_dir'])
        print(f"Total {len(self.generated_tasks)} New Tasks:", [task['task-name'] for task in self.generated_tasks])
        try:
            print(f"Added {len(self.passed_tasks)} Tasks:", self.passed_tasks)
        except Exception:
            pass

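    # example_task_creation mirrors task_creation but substitutes a hard-coded
    # example ('BuildWheel') for the full agent/critic pipeline. Both methods are
    # generators that yield (progress string, code, video) tuples, presumably
    # consumed by a demo front-end.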
    def example_task_creation(self):
        """Create the task through interactions of the agent and critic."""
        self.task_creation_pass = True
        mkdir_if_missing(self.cfg['model_output_dir'])

        try:
            start_time = time.time()

            self.generated_task = {'task-name': 'TASK_NAME_TEMPLATE',
                                   'task-description': 'TASK_STRING_TEMPLATE',
                                   'assets-used': ['ASSET_1', 'ASSET_2', Ellipsis]}
            print("generated_task\n", self.generated_task)
            yield "Task Generated ==>", None, None
            self.generated_asset = self.agent.propose_assets()

            print("generated_asset\n", self.generated_asset)
            yield "Task Generated ==> Asset Generated ==> ", None, None
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> ", None, None
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> ", None, None

            self.curr_task_name = self.generated_task_name = 'BuildWheel'

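            # Hard-coded example of the code the agent would normally generate;
            # simulate_task() later execs this string to define the task class.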
            self.generated_code = """
import numpy as np
from cliport.tasks.task import Task
from cliport.utils import utils


class BuildWheel(Task):

    def __init__(self):
        super().__init__()
        self.max_steps = 10
        self.lang_template = "Construct a wheel using blocks and a sphere. First, position eight blocks in a circular layout on the tabletop. Each block should be touching its two neighbors and colored in alternating red and blue. Then place a green sphere in the center of the circular layout, completing the wheel."
        self.task_completed_desc = "done building wheel."
        self.additional_reset()

    def reset(self, env):
        super().reset(env)

        # Add blocks.
        block_size = (0.04, 0.04, 0.04)
        block_urdf = 'block/block.urdf'
        block_colors = [utils.COLORS['red'], utils.COLORS['blue']]
        blocks = []
        for i in range(8):
            block_pose = self.get_random_pose(env, block_size)
            block_id = env.add_object(block_urdf, block_pose, color=block_colors[i % 2])
            blocks.append(block_id)

        # Add sphere.
        sphere_size = (0.04, 0.04, 0.04)
        sphere_urdf = 'sphere/sphere.urdf'
        sphere_color = utils.COLORS['green']
        sphere_pose = ((0.5, 0.0, 0.0), (0, 0, 0, 1))  # fixed pose
        sphere_id = env.add_object(sphere_urdf, sphere_pose, color=sphere_color)

        # Goal: blocks are arranged in a circle and sphere is in the center.
        circle_radius = 0.1
        circle_center = (0, 0, block_size[2] / 2)
        angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
        block_poses = [(circle_center[0] + circle_radius * np.cos(angle),
                        circle_center[1] + circle_radius * np.sin(angle),
                        circle_center[2]) for angle in angles]
        block_poses = [(utils.apply(sphere_pose, pos), sphere_pose[1]) for pos in block_poses]
        self.add_goal(objs=blocks, matches=np.ones((8, 8)), targ_poses=block_poses, replace=False,
                      rotations=True, metric='pose', params=None, step_max_reward=8 / 9)

        # Goal: sphere is in the center of the blocks.
        self.add_goal(objs=[sphere_id], matches=np.ones((1, 1)), targ_poses=[sphere_pose], replace=False,
                      rotations=False, metric='pose', params=None, step_max_reward=1 / 9)

        self.lang_goals.append(self.lang_template)
"""
            print("generated_code\n", self.generated_code)
            print("curr_task_name\n", self.curr_task_name)
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> ", self.generated_code, None

            self.generated_tasks.append(self.generated_task)
            self.generated_task_assets.append(self.generated_asset)
            self.generated_task_programs.append(self.generated_code)
            self.generated_task_names.append(self.generated_task_name)
        except Exception:
            to_print = highlight(f"{str(traceback.format_exc())}", PythonLexer(), TerminalFormatter())
            print("Task Creation Exception:", to_print)
            self.task_creation_pass = False

        print("task creation time {:.3f}".format(time.time() - start_time))

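    # task_creation is the full pipeline: the agent proposes a task, proposes its
    # assets, reviews the available API, has the critic review likely errors, and
    # finally implements the task code. Each stage yields a progress string so the
    # caller can surface intermediate status.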
    def task_creation(self):
        """Create the task through interactions of the agent and critic."""
        self.task_creation_pass = True
        mkdir_if_missing(self.cfg['model_output_dir'])

        try:
            start_time = time.time()
            self.generated_task = self.agent.propose_task(self.generated_task_names)

            print("generated_task\n", self.generated_task)
            yield "Task Generated ==>", None, None

            self.generated_asset = self.agent.propose_assets()

            print("generated_asset\n", self.generated_asset)
            yield "Task Generated ==> Asset Generated ==> ", None, None

            self.agent.api_review()

            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> ", None, None
            self.critic.error_review(self.generated_task)

            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> ", None, None
            self.generated_code, self.curr_task_name = self.agent.implement_task()
            self.task_asset_logs.append(self.generated_task["assets-used"])
            self.generated_task_name = self.generated_task["task-name"]
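
            # The agent has now produced the task code and the class name to
            # instantiate; log the assets used and surface the code to the caller.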
            print("generated_code\n", self.generated_code)
            print("curr_task_name\n", self.curr_task_name)
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> ", self.generated_code, None

            self.generated_tasks.append(self.generated_task)
            self.generated_task_assets.append(self.generated_asset)
            self.generated_task_programs.append(self.generated_code)
            self.generated_task_names.append(self.generated_task_name)
        except Exception:
            to_print = highlight(f"{str(traceback.format_exc())}", PythonLexer(), TerminalFormatter())
            print("Task Creation Exception:", to_print)
            self.task_creation_pass = False

        print("task creation time {:.3f}".format(time.time() - start_time))

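    # setup_env builds the PyBullet Environment and instantiates the generated
    # task class by name; eval(self.curr_task_name) assumes the class was already
    # exec'd into globals() by simulate_task().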
    def setup_env(self):
        """Build the new task."""
        env = Environment(
            self.cfg['assets_root'],
            disp=self.cfg['disp'],
            shared_memory=self.cfg['shared_memory'],
            hz=480,
            record_cfg=self.cfg['record']
        )

        task = eval(self.curr_task_name)()
        task.mode = self.cfg['mode']
        record = self.cfg['record']['save_video']
        save_data = self.cfg['save_data']

        expert = task.oracle(env)
        self.cfg['task'] = self.generated_task["task-name"]
        data_path = os.path.join(self.cfg['data_dir'], "{}-{}".format(self.generated_task["task-name"], task.mode))
        dataset = RavensDataset(data_path, self.cfg, n_demos=0, augment=False)
        print(f"Saving to: {data_path}")
        print(f"Mode: {task.mode}")

        if record:
            env.start_rec(f'{dataset.n_episodes+1:06d}')

        return task, dataset, env, expert

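    # run_one_episode rolls out the scripted oracle for up to task.max_steps,
    # appending (obs, act, reward, info) transitions to the episode list.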
    def run_one_episode(self, dataset, expert, env, task, episode, seed):
        """Run the new task for one episode."""
        add_to_txt(
            self.chat_log, f"================= TRIAL: {self.curr_trials}", with_print=True)
        record = self.cfg['record']['save_video']
        np.random.seed(seed)
        random.seed(seed)
        print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, self.cfg['n'], seed))
        env.set_task(task)
        obs = env.reset()

        info = env.info
        reward = 0
        total_reward = 0

        for _ in range(task.max_steps):
            act = expert.act(obs, info)
            episode.append((obs, act, reward, info))
            lang_goal = info['lang_goal']
            obs, reward, done, info = env.step(act)
            total_reward += reward
            print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
            if done:
                break

        episode.append((obs, None, reward, info))
        return total_reward

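    # simulate_task is the verification stage: it counts a trial, checks that the
    # generated code execs and builds an environment (syntax pass), and then rolls
    # out the oracle on the new task (runtime pass).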
    def simulate_task(self):
        """Simulate the created task and save demonstrations."""
        total_cnt = 0.
        reset_success_cnt = 0.
        env_success_cnt = 0.
        seed = 123
        self.curr_trials += 1

        if p.isConnected():
            p.disconnect()

        if not self.task_creation_pass:
            print("task creation failure => count as syntax exceptions.")
            return
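
        # Syntax check: exec the generated code so the new task class is defined
        # in globals(), then try to construct the environment around it.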
        try:
            exec(self.generated_code, globals())
            task, dataset, env, expert = self.setup_env()
            self.syntax_pass_rate += 1

        except Exception:
            to_print = highlight(f"{str(traceback.format_exc())}", PythonLexer(), TerminalFormatter())
            save_text(self.cfg['model_output_dir'], self.generated_task_name + '_error', str(traceback.format_exc()))
            print("========================================================")
            print("Syntax Exception:", to_print)
            return
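
        # Runtime check: roll out the oracle inline (mirroring run_one_episode) but
        # stream progress with `yield from env.step(act)`; env.step is assumed to be
        # a generator here that also caches cur_obs / cur_reward / cur_done / cur_info.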
        try:
            env.generated_code = self.generated_code
            episode = []

            # run the new task for one episode
            add_to_txt(
                self.chat_log, f"================= TRIAL: {self.curr_trials}", with_print=True)
            np.random.seed(seed)
            random.seed(seed)
            print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, self.cfg['n'], seed))
            env.set_task(task)
            obs = env.reset()

            info = env.info
            reward = 0
            total_reward = 0
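
            # Accumulate reward over at most task.max_steps oracle actions,
            # stopping early once the environment reports done.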
            start_time = time.time()
            print("start sim")
            for i in range(task.max_steps):
                act = expert.act(obs, info)
                episode.append((obs, act, reward, info))
                lang_goal = info['lang_goal']
                env.generated_code = self.generated_code
                yield from env.step(act)

                obs, reward, done, info = env.cur_obs, env.cur_reward, env.cur_done, env.cur_info
                total_reward += reward
                print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')

                if done:
                    break

            end_time = time.time()
            print("end sim, time used = ", end_time - start_time)
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> Simulation Running completed", self.generated_code, env.video_path
            episode.append((obs, None, reward, info))

            print("Runtime Test Pass!")
        except Exception:
            to_print = highlight(f"{str(traceback.format_exc())}", PythonLexer(), TerminalFormatter())
            save_text(self.cfg['model_output_dir'], self.generated_task_name + '_error', str(traceback.format_exc()))
            print("========================================================")
            print("Runtime Exception:", to_print)

        self.memory.save_run(self.generated_task)