File size: 1,326 Bytes
8fc2b4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import numpy as np
import os
import hydra
import random
import re
import openai
import IPython
import time
import pybullet as p
import traceback
from datetime import datetime
from pprint import pprint
import cv2
import re
import random
import json
from gensim.agent import Agent
from gensim.critic import Critic
from gensim.sim_runner import SimulationRunner
from gensim.memory import Memory
from gensim.utils import set_gpt_model, clear_messages
@hydra.main(config_path='../cliport/cfg', config_name='data', version_base="1.2")
def main(cfg):
    """Run the configured number of task-creation / simulation trials.

    Reads from ``cfg``: ``openai_key``, ``output_folder``, ``prompt_folder``,
    ``gpt_model``, ``trials`` and, if present, ``seed``.

    Side effects:
      * sets ``openai.api_key`` and the GPT model via ``set_gpt_model``;
      * writes ``cfg['model_output_dir']`` (a timestamped path under
        ``output_folder``) before any component is constructed, so
        ``Memory``, ``Agent``, ``Critic`` and ``SimulationRunner`` all receive
        the same config object carrying that directory.

    Each trial calls ``task_creation()``, ``simulate_task()`` and
    ``print_current_stats()`` on the runner. ``save_stats()`` is called once
    when the loop ends, including when a trial raises, so partial results
    of a long run are not lost. The exception itself still propagates.
    """
    openai.api_key = cfg['openai_key']

    # Timestamped output directory for this run.
    # NOTE(review): ':' is not a legal path character on Windows; confirm this
    # script only runs on POSIX hosts, or switch to a colon-free format.
    model_time = datetime.now().strftime("%d_%m_%Y_%H:%M:%S")
    cfg['model_output_dir'] = os.path.join(cfg['output_folder'], cfg['prompt_folder'] + "_" + model_time)
    if 'seed' in cfg:
        # Append the seed so runs with different seeds get distinct directories.
        cfg['model_output_dir'] = cfg['model_output_dir'] + f"_{cfg['seed']}"

    set_gpt_model(cfg['gpt_model'])

    # All components share one Memory instance and the same cfg.
    memory = Memory(cfg)
    agent = Agent(cfg, memory)
    critic = Critic(cfg, memory)
    simulation_runner = SimulationRunner(cfg, agent, critic, memory)

    try:
        for _ in range(cfg['trials']):
            simulation_runner.task_creation()
            simulation_runner.simulate_task()
            simulation_runner.print_current_stats()
            # clear_messages()  # disabled in the original; kept for reference
    finally:
        # Persist whatever was gathered even if a trial raised (API error,
        # simulator failure, KeyboardInterrupt) instead of losing the run.
        simulation_runner.save_stats()


if __name__ == '__main__':
    main()
|