ewanlee commited on
Commit
265d55c
·
1 Parent(s): 1118c95

first commit

Browse files
.gitignore CHANGED
@@ -156,7 +156,7 @@ dmypy.json
156
  # Cython debug symbols
157
  cython_debug/
158
  images/
159
- gpt.py
160
  test.ipynb
161
  results
162
  wandb/
@@ -189,4 +189,7 @@ test_
189
  *.ipynb
190
 
191
  # gradio
192
- flagged
 
 
 
 
156
  # Cython debug symbols
157
  cython_debug/
158
  images/
159
+ # gpt.py
160
  test.ipynb
161
  results
162
  wandb/
 
189
  *.ipynb
190
 
191
  # gradio
192
+ flagged
193
+
194
+ # hf
195
+ policy.pth
app.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import envs
2
+ import deciders
3
+ import distillers
4
+ import prompts as task_prompts
5
+ import datetime
6
+ import time
7
+ from envs.translator import InitSummarizer, CurrSummarizer, FutureSummarizer, Translator
8
+ import gym
9
+ import pandas as pd
10
+ import random
11
+ import datetime
12
+ from loguru import logger
13
+ from argparse import Namespace
14
+ import gradio as gr
15
+ import subprocess
16
+ import openai
17
+ import os
18
+ import shutil
19
+ import subprocess
20
+ from pathlib import Path
21
+ from urllib.request import urlretrieve
22
+
23
+
24
+ def set_seed(seed):
25
+ random.seed(seed)
26
+
27
+ def main_progress(
28
+ api_type, openai_key, env_name, decider_name,
29
+ prompt_level, num_trails, seed
30
+ ):
31
+ init_summarizer = env_name.split("-")[0] + '_init_translator'
32
+ curr_summarizer = env_name.split("-")[0] + '_basic_translator'
33
+ if "Represented" not in init_summarizer:
34
+ init_summarizer = init_summarizer.lower()
35
+ curr_summarizer = curr_summarizer.lower()
36
+ args = Namespace(
37
+ env_name=env_name,
38
+ init_summarizer=init_summarizer,
39
+ curr_summarizer=curr_summarizer,
40
+ decider=decider_name,
41
+ prompt_level=prompt_level,
42
+ num_trails=num_trails,
43
+ seed=seed,
44
+ future_summarizer=None,
45
+ env="base_env",
46
+ gpt_version="gpt-3.5-turbo",
47
+ render="rgb_array",
48
+ max_episode_len=200,
49
+ max_query_tokens=5000,
50
+ max_tokens=2000,
51
+ distiller="traj_distiller",
52
+ prompt_path=None,
53
+ use_short_mem=1,
54
+ short_mem_num=10,
55
+ is_only_local_obs=1,
56
+ api_type=api_type,
57
+ )
58
+
59
+ if args.api_type != "azure" and args.api_type != "openai":
60
+ raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai' !")
61
+
62
+ # Please note when using "azure", the model name is gpt-35-turbo while using "openai", the model name is "gpt-3.5-turbo"
63
+ if args.api_type == "azure":
64
+ if args.gpt_version == "gpt-3.5-turbo":
65
+ args.gpt_version = 'gpt-35-turbo'
66
+ elif args.api_type == "openai":
67
+ if args.gpt_version == "gpt-35-turbo":
68
+ args.gpt_version = 'gpt-3.5-turbo'
69
+
70
+ # Get the specified translator, environment, and ChatGPT model
71
+ env_class = envs.REGISTRY[args.env]
72
+ init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)
73
+ curr_summarizer = CurrSummarizer(envs.REGISTRY[args.curr_summarizer])
74
+
75
+ if args.future_summarizer:
76
+ future_summarizer = FutureSummarizer(
77
+ envs.REGISTRY[args.future_summarizer],
78
+ envs.REGISTRY["cart_policies"],
79
+ future_horizon=args.future_horizon,
80
+ )
81
+ else:
82
+ future_summarizer = None
83
+
84
+ decider_class = deciders.REGISTRY[args.decider]
85
+ distiller_class = distillers.REGISTRY[args.distiller]
86
+ sampling_env = envs.REGISTRY["sampling_wrapper"](gym.make(args.env_name))
87
+ if args.prompt_level == 5:
88
+ prompts_class = task_prompts.REGISTRY[(args.env_name,args.decider)]()
89
+ else:
90
+ prompts_class = task_prompts.REGISTRY[(args.decider)]()
91
+ translator = Translator(
92
+ init_summarizer, curr_summarizer, future_summarizer, env=sampling_env
93
+ )
94
+ environment = env_class(
95
+ gym.make(args.env_name, render_mode=args.render), translator
96
+ )
97
+
98
+ logfile = (
99
+ f"llm.log/output-{args.env_name}-{args.decider}-{args.gpt_version}-l{args.prompt_level}"
100
+ f"-{datetime.datetime.now().timestamp()}.log"
101
+ )
102
+
103
+ logfile_reflexion = (
104
+ f"llm.log/memory-{args.env_name}-{args.decider}-{args.gpt_version}-l{args.prompt_level}"
105
+ f"-{datetime.datetime.now().timestamp()}.log"
106
+ )
107
+ my_distiller = distiller_class(logfile=logfile_reflexion,args=args)
108
+
109
+ args.game_description = environment.game_description
110
+ args.goal_description = environment.goal_description
111
+ args.action_description = environment.action_description
112
+ args.action_desc_dict = environment.action_desc_dict
113
+ args.reward_desc_dict = environment.reward_desc_dict
114
+
115
+ logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
116
+
117
+ decider = decider_class(openai_key, environment.env.action_space, args, prompts_class, my_distiller, temperature=0.0, logger=logger, max_tokens=args.max_tokens)
118
+
119
+ # Evaluate the translator
120
+ utilities = []
121
+ df = pd.read_csv('record_reflexion.csv', sep=',')
122
+ filtered_df = df[(df['env'] == args.env_name) & (df['decider'] == 'expert') & (df['level'] == 1)]
123
+ expert_score = filtered_df['avg_score'].item()
124
+ seeds = [i for i in range(1000)]
125
+ # prompt_file = "prompt.txt"
126
+ # f = open(prompt_file,"w+")
127
+ num_trails = args.num_trails
128
+ if not "Blackjack" in args.env_name:
129
+ curriculums = 1
130
+ else:
131
+ curriculums = 20
132
+ for curriculum in range(curriculums):
133
+ for trail in range(num_trails):
134
+ if "Blackjack" in args.env_name:
135
+ seed = seeds[curriculum*curriculums + num_trails - trail - 1]
136
+ else:
137
+ seed = args.seed
138
+
139
+ # single run
140
+ # Reset the environment
141
+ if not "Blackjack" in args.env_name:
142
+ set_seed(args.seed)
143
+ seed = args.seed
144
+ # Reset the environment
145
+ state_description, env_info = environment.reset(seed=args.seed)
146
+ else:
147
+ set_seed(seed)
148
+ # Reset the environment
149
+ state_description, env_info = environment.reset(seed=seed)
150
+ game_description = environment.get_game_description()
151
+ goal_description = environment.get_goal_description()
152
+ action_description = environment.get_action_description()
153
+
154
+ # Initialize the statistics
155
+ frames = []
156
+ utility = 0
157
+ current_total_tokens = 0
158
+ current_total_cost = 0
159
+ # state_description, prompt, response, action = None, None, None, None
160
+ start_time = datetime.datetime.now()
161
+ # Run the game for a maximum number of steps
162
+ for round in range(args.max_episode_len):
163
+ # Keep asking ChatGPT for an action until it provides a valid one
164
+ error_flag = True
165
+ retry_num = 1
166
+ for error_i in range(retry_num):
167
+ try:
168
+ action, prompt, response, tokens, cost = decider.act(
169
+ state_description,
170
+ action_description,
171
+ env_info,
172
+ game_description,
173
+ goal_description,
174
+ logfile
175
+ )
176
+
177
+ state_description, reward, termination, truncation, env_info = environment.step_llm(
178
+ action
179
+ )
180
+ if "Cliff" in args.env_name or "Frozen" in args.env_name:
181
+ decider.env_history.add('reward', env_info['potential_state'] + environment.reward_desc_dict[reward])
182
+ else:
183
+ decider.env_history.add('reward', f"The player get rewards {reward}.")
184
+
185
+ utility += reward
186
+
187
+ # Update the statistics
188
+ current_total_tokens += tokens
189
+ current_total_cost += cost
190
+ error_flag = False
191
+ break
192
+ except Exception as e:
193
+ print(e)
194
+ raise e
195
+ if error_i < retry_num-1:
196
+ if "Cliff" in args.env_name or "Frozen" in args.env_name:
197
+ decider.env_history.remove_invalid_state()
198
+ decider.env_history.remove_invalid_state()
199
+ if logger:
200
+ logger.debug(f"Error: {e}, Retry! ({error_i+1}/{retry_num})")
201
+ continue
202
+ if error_flag:
203
+ action = decider.default_action
204
+ state_description, reward, termination, truncation, env_info = environment.step_llm(
205
+ action
206
+ )
207
+
208
+ decider.env_history.add('action', decider.default_action)
209
+
210
+ if "Cliff" in args.env_name or "Frozen" in args.env_name:
211
+ # decider.env_history.add('reward', reward)
212
+ decider.env_history.add('reward', env_info['potential_state'] + environment.reward_desc_dict[reward])
213
+ utility += reward
214
+
215
+
216
+ logger.info(f"Seed: {seed}")
217
+ logger.info(f'The optimal action is: {decider.default_action}.')
218
+ logger.info(f"Now it is round {round}.")
219
+ else:
220
+ current_total_tokens += tokens
221
+ current_total_cost += cost
222
+ logger.info(f"Seed: {seed}")
223
+ logger.info(f"current_total_tokens: {current_total_tokens}")
224
+ logger.info(f"current_total_cost: {current_total_cost}")
225
+ logger.info(f"Now it is round {round}.")
226
+
227
+ # return results
228
+ yield environment.render(), state_description, prompt, response, action
229
+
230
+ if termination or truncation:
231
+ if logger:
232
+ logger.info(f"Terminated!")
233
+ break
234
+ time.sleep(5)
235
+ decider.env_history.add(
236
+ 'terminate_state', environment.get_terminate_state(round+1, args.max_episode_len))
237
+ decider.env_history.add("cummulative_reward", str(utility))
238
+ # Record the final reward
239
+ if logger:
240
+ logger.info(f"Cummulative reward: {utility}.")
241
+ end_time = datetime.datetime.now()
242
+ time_diff = end_time - start_time
243
+ logger.info(f"Time consumer: {time_diff.total_seconds()} s")
244
+
245
+ utilities.append(utility)
246
+ # TODO: set env sucess utility threshold
247
+ if trail < num_trails -1:
248
+ if args.decider in ['reflexion']:
249
+ if utility < expert_score:
250
+ decider.update_mem()
251
+ else:
252
+ decider.update_mem()
253
+ decider.clear_mem()
254
+ return utilities
255
+
256
+ # def pause():
257
+ # for i in range(31415926):
258
+ # time.sleep(0.1)
259
+ # yield i
260
+
261
+ if __name__ == "__main__":
262
+
263
+
264
+ # install Atari ROMs
265
+ subprocess.run(['AutoROM', '--accept-license'])
266
+
267
+ # install mujoco
268
+
269
+ # Step 1: Download and set up MuJoCo
270
+ MUJOCO_URL = "https://github.com/google-deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz"
271
+ MUJOCO_FILENAME = "mujoco210-linux-x86_64.tar.gz"
272
+
273
+ # Download MuJoCo
274
+ print("Downloading MuJoCo...")
275
+ urlretrieve(MUJOCO_URL, MUJOCO_FILENAME)
276
+
277
+ # Create and move to ~/.mujoco directory
278
+ mujoco_dir = Path.home() / ".mujoco"
279
+ mujoco_dir.mkdir(exist_ok=True)
280
+ shutil.move(MUJOCO_FILENAME, str(mujoco_dir / MUJOCO_FILENAME))
281
+
282
+ # Extract the file
283
+ print("Extracting MuJoCo...")
284
+ subprocess.run(["tar", "-zxvf", str(mujoco_dir / MUJOCO_FILENAME)], cwd=mujoco_dir)
285
+
286
+ # Edit .bashrc
287
+ bashrc_path = Path.home() / ".bashrc"
288
+ mujoco_path = mujoco_dir / "mujoco210" / "bin"
289
+ export_line = f"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{mujoco_path}\n"
290
+
291
+ with open(bashrc_path, "a") as bashrc_file:
292
+ bashrc_file.write(export_line)
293
+
294
+ # Set LD_LIBRARY_PATH for the current process
295
+ ld_lib_path = os.environ.get("LD_LIBRARY_PATH", "")
296
+ new_ld_lib_path = f"{ld_lib_path}{mujoco_path}"
297
+ os.environ["LD_LIBRARY_PATH"] = new_ld_lib_path
298
+
299
+ # Step 2: Install gym[mujoco]
300
+ print("Installing gym[MuJoCo]...")
301
+ subprocess.run(["pip", "install", "gym[mujoco]"])
302
+
303
+ # # Set render
304
+ os.environ["MUJOCO_GL"] = "egl"
305
+ # os.environ["DISPLAY"] = ":0"
306
+ # print(f'LD_LIBRARY_PATH: {os.environ["LD_LIBRARY_PATH"]}')
307
+ # assert os.path.exists(str(mujoco_path))
308
+ # subprocess.run("cp -r /home/user/.mujoco/mujoco210/bin/* /usr/lib/", shell=True)
309
+ # import mujoco_py
310
+ # flag = 'gpu' in str(mujoco_py.cymj).split('/')[-1]
311
+ # print(f'flag: {flag}')
312
+ # if not flag:
313
+ # ld_lib_path = os.environ.get("LD_LIBRARY_PATH", "")
314
+ # new_ld_lib_path = f"{ld_lib_path}:/usr/lib/nvidia-000"
315
+ # os.environ["LD_LIBRARY_PATH"] = new_ld_lib_path
316
+ # subprocess.run(["sudo", "mkdir", "-p", "/usr/lib/nvidia-000"])
317
+ # assert 'gpu' in str(mujoco_py.cymj).split('/')[-1]
318
+
319
+
320
+ custom_css = """
321
+ #render {
322
+ flex-grow: 1;
323
+ }
324
+ #input_text .tabs {
325
+ display: flex;
326
+ flex-direction: column;
327
+ flex-grow: 1;
328
+ }
329
+ #input_text .tabitem[style="display: block;"] {
330
+ flex-grow: 1;
331
+ display: flex !important;
332
+ }
333
+ #input_text .gap {
334
+ flex-grow: 1;
335
+ }
336
+ #input_text .form {
337
+ flex-grow: 1 !important;
338
+ }
339
+ #input_text .form > :last-child{
340
+ flex-grow: 1;
341
+ }
342
+ """
343
+
344
+ with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
345
+ with gr.Row():
346
+ api_type = gr.Dropdown(["azure", "openai"], label="API Type", scale=1)
347
+ openai_key = gr.Textbox(label="OpenAI API Key", type="password", scale=3)
348
+ with gr.Row():
349
+ env_name = gr.Dropdown(
350
+ ["CartPole-v0",
351
+ "LunarLander-v2",
352
+ "Acrobot-v1",
353
+ "MountainCar-v0",
354
+ "Blackjack-v1",
355
+ "Taxi-v3",
356
+ "CliffWalking-v0",
357
+ "FrozenLake-v1",
358
+ "MountainCarContinuous-v0",
359
+ "Ant-v4",
360
+ "RepresentedBoxing-v0",
361
+ "RepresentedPong-v0",
362
+ "RepresentedMsPacman-v0",
363
+ "RepresentedMontezumaRevenge-v0"],
364
+ label="Environment Name")
365
+ decider_name = gr.Dropdown(
366
+ ["naive_actor",
367
+ "cot_actor",
368
+ "spp_actor",
369
+ "reflexion_actor"],
370
+ label="Decider")
371
+ # prompt_level = gr.Dropdown([1, 2, 3, 4, 5], label="Prompt Level")
372
+ # TODO: support more prompt levels
373
+ prompt_level = gr.Dropdown([1, 3], label="Prompt Level")
374
+ with gr.Row():
375
+ num_trails = gr.Slider(1, 100, 1, label="Number of Trails", scale=2)
376
+ seed = gr.Slider(1, 1000, 1, label="Seed", scale=2)
377
+ run = gr.Button("Run", scale=1)
378
+ # pause_ = gr.Button("Pause")
379
+ # resume = gr.Button("Resume")
380
+ stop = gr.Button("Stop", scale=1)
381
+ with gr.Row():
382
+ with gr.Column():
383
+ render = gr.Image(label="render", elem_id="render")
384
+ with gr.Column(elem_id="input_text"):
385
+ state = gr.Textbox(label="translated state")
386
+ prompt = gr.Textbox(label="prompt", max_lines=20)
387
+ with gr.Row():
388
+ response = gr.Textbox(label="response")
389
+ action = gr.Textbox(label="parsed action")
390
+ run_event = run.click(
391
+ fn=main_progress,
392
+ inputs=[
393
+ api_type, openai_key, env_name,
394
+ decider_name, prompt_level, num_trails, seed],
395
+ outputs=[render, state, prompt, response, action])
396
+ stop.click(fn=None, inputs=None, outputs=None, cancels=[run_event])
397
+ # pause_event = pause_.click(fn=pause, inputs=None, outputs=None)
398
+ # resume.click(fn=None, inputs=None, outputs=None, cancels=[pause_event])
399
+
400
+ demo.launch()
deciders/act.py CHANGED
@@ -26,7 +26,7 @@ class RandomAct():
26
  return action, '', '', '', 0, 0
27
 
28
  class NaiveAct(gpt):
29
- def __init__(self, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
30
  self.action_space = action_space
31
  self.temperature = temperature
32
  self.action_desc_dict = args.action_desc_dict
@@ -39,7 +39,7 @@ class NaiveAct(gpt):
39
  else:
40
  model = args.gpt_version
41
  self.encoding = tiktoken.encoding_for_model(model)
42
- super().__init__(args)
43
  self.distiller = distiller
44
  self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
45
  if isinstance(self.action_space, Discrete):
 
26
  return action, '', '', '', 0, 0
27
 
28
  class NaiveAct(gpt):
29
+ def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
30
  self.action_space = action_space
31
  self.temperature = temperature
32
  self.action_desc_dict = args.action_desc_dict
 
39
  else:
40
  model = args.gpt_version
41
  self.encoding = tiktoken.encoding_for_model(model)
42
+ super().__init__(args, openai_key)
43
  self.distiller = distiller
44
  self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
45
  if isinstance(self.action_space, Discrete):
deciders/cot.py CHANGED
@@ -17,8 +17,8 @@ from .utils import run_chain
17
 
18
 
19
  class ChainOfThought(NaiveAct):
20
- def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
- super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
22
 
23
  def act(
24
  self,
 
17
 
18
 
19
  class ChainOfThought(NaiveAct):
20
+ def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
+ super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens,logger)
22
 
23
  def act(
24
  self,
deciders/exe.py CHANGED
@@ -20,8 +20,8 @@ from loguru import logger
20
 
21
 
22
  class EXE(NaiveAct):
23
- def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
24
- super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
25
  self.pre_memory = []
26
  self.post_memory = []
27
  self.is_first = True
 
20
 
21
 
22
  class EXE(NaiveAct):
23
+ def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
24
+ super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
25
  self.pre_memory = []
26
  self.post_memory = []
27
  self.is_first = True
deciders/gpt.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ class gpt:
3
+ def __init__(self, args, api_key=None):
4
+ if args.api_type == "azure":
5
+ openai.api_type = "azure"
6
+ openai.api_version = "2023-05-15"
7
+ # Your Azure OpenAI resource's endpoint value.
8
+ openai.api_base = "https://midivi-main-scu1.openai.azure.com/"
9
+ openai.api_key = api_key
10
+ else:
11
+ openai.api_key = api_key
deciders/reflexion.py CHANGED
@@ -19,8 +19,8 @@ from .utils import run_chain
19
 
20
 
21
  class Reflexion(NaiveAct):
22
- def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
23
- super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
24
 
25
  def num_tokens_from_string(self,string: str) -> int:
26
  """Returns the number of tokens in a text string."""
 
19
 
20
 
21
  class Reflexion(NaiveAct):
22
+ def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
23
+ super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
24
 
25
  def num_tokens_from_string(self,string: str) -> int:
26
  """Returns the number of tokens in a text string."""
deciders/self_consistency.py CHANGED
@@ -17,9 +17,9 @@ from .utils import run_chain
17
 
18
 
19
  class SelfConsistency(NaiveAct):
20
- def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
  temperature = 0.7
22
- super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
23
  self.temperature = temperature
24
 
25
  def act(
 
17
 
18
 
19
  class SelfConsistency(NaiveAct):
20
+ def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
  temperature = 0.7
22
+ super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
23
  self.temperature = temperature
24
 
25
  def act(
deciders/selfask.py CHANGED
@@ -17,8 +17,8 @@ from .utils import run_chain
17
 
18
 
19
  class SelfAskAct(NaiveAct):
20
- def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
- super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
22
 
23
  def act(
24
  self,
 
17
 
18
 
19
  class SelfAskAct(NaiveAct):
20
+ def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
+ super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens,logger)
22
 
23
  def act(
24
  self,
deciders/spp.py CHANGED
@@ -16,8 +16,8 @@ from .act import NaiveAct
16
  from .utils import run_chain
17
 
18
  class SPP(NaiveAct):
19
- def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
20
- super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
21
 
22
  def act(
23
  self,
 
16
  from .utils import run_chain
17
 
18
  class SPP(NaiveAct):
19
+ def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
20
+ super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
21
 
22
  def act(
23
  self,
deciders/utils.py CHANGED
@@ -19,8 +19,30 @@ Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
19
  # from .gpt import gpt
20
  # gpt().__init__()
21
 
22
- import timeout_decorator
23
- @timeout_decorator.timeout(30)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def run_chain(chain, *args, **kwargs):
25
  return chain.run(*args, **kwargs)
26
 
@@ -86,5 +108,4 @@ def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo",
86
  temperature=temperature,
87
  # request_timeout = 1
88
  )
89
- return response.choices[0]["message"]["content"]
90
-
 
19
  # from .gpt import gpt
20
  # gpt().__init__()
21
 
22
+ # import timeout_decorator
23
+ # @timeout_decorator.timeout(30)
24
+ # def run_chain(chain, *args, **kwargs):
25
+ # return chain.run(*args, **kwargs)
26
+ import concurrent.futures
27
+
28
+ def timeout_decorator(timeout):
29
+ def decorator(function):
30
+ def wrapper(*args, **kwargs):
31
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
32
+ future = executor.submit(function, *args, **kwargs)
33
+ try:
34
+ return future.result(timeout)
35
+ except concurrent.futures.TimeoutError:
36
+ raise RuntimeError(
37
+ f"Function '{function.__name__}' timed out after {timeout} seconds"
38
+ )
39
+ except Exception as e:
40
+ raise e
41
+ return wrapper
42
+ return decorator
43
+
44
+
45
+ @timeout_decorator(30)
46
  def run_chain(chain, *args, **kwargs):
47
  return chain.run(*args, **kwargs)
48
 
 
108
  temperature=temperature,
109
  # request_timeout = 1
110
  )
111
+ return response.choices[0]["message"]["content"]
 
envs/__init__.py CHANGED
@@ -18,24 +18,25 @@ from .atari import mspacman_policies, mspacman_translator
18
  from .atari import montezumarevenge_policies, montezumarevenge_translator
19
  register_environments()
20
 
 
21
 
22
  REGISTRY = {}
23
  REGISTRY["sampling_wrapper"] = SettableStateEnv
24
  REGISTRY["base_env"] = BaseEnv
25
- REGISTRY["cart_init_translator"] = cartpole_translator.GameDescriber
26
- REGISTRY["cart_basic_translator"] = cartpole_translator.BasicStateSequenceTranslator
27
  REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
28
  REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
29
  REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
30
  REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
31
 
32
- REGISTRY["cart_policies"] = [cartpole_policies.dedicated_1_policy, cartpole_policies.dedicated_2_policy, cartpole_policies.pseudo_random_policy, cartpole_policies.real_random_policy]
33
  REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
34
  REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
35
 
36
- REGISTRY["lunarLander_init_translator"] = LunarLander_translator.GameDescriber
37
- REGISTRY["lunarLander_basic_translator"] = LunarLander_translator.BasicStateSequenceTranslator
38
- REGISTRY["lunarLander_policies"] = [LunarLander_policies.dedicated_1_policy, LunarLander_policies.dedicated_2_policy, LunarLander_policies.dedicated_3_policy,LunarLander_policies.dedicated_4_policy, LunarLander_policies.pseudo_random_policy, LunarLander_policies.real_random_policy]
39
 
40
  REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
41
  REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
@@ -54,9 +55,9 @@ REGISTRY["frozenlake_basic_translator"] = frozenlake_translator.BasicStateSequen
54
  REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
55
 
56
 
57
- REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
58
- REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
59
- REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
60
 
61
 
62
  REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
@@ -138,47 +139,6 @@ REGISTRY["RepresentedMontezumaRevenge_basic_policies"] = [
138
  montezumarevenge_policies.dedicated_18_policy,
139
  ]
140
 
141
- REGISTRY["RepresentedMsPacman_init_translator"] = mspacman_translator.GameDescriber
142
- REGISTRY["RepresentedMsPacman_basic_translator"] = mspacman_translator.BasicStateSequenceTranslator
143
- REGISTRY["RepresentedMsPacman_basic_policies"] = [
144
- mspacman_policies.real_random_policy,
145
- mspacman_policies.pseudo_random_policy,
146
- mspacman_policies.dedicated_1_policy,
147
- mspacman_policies.dedicated_2_policy,
148
- mspacman_policies.dedicated_3_policy,
149
- mspacman_policies.dedicated_4_policy,
150
- mspacman_policies.dedicated_5_policy,
151
- mspacman_policies.dedicated_6_policy,
152
- mspacman_policies.dedicated_7_policy,
153
- mspacman_policies.dedicated_8_policy,
154
- mspacman_policies.dedicated_9_policy,
155
- ]
156
-
157
- REGISTRY["RepresentedMontezumaRevenge_init_translator"] = montezumarevenge_translator.GameDescriber
158
- REGISTRY["RepresentedMontezumaRevenge_basic_translator"] = montezumarevenge_translator.BasicStateSequenceTranslator
159
- REGISTRY["RepresentedMontezumaRevenge_basic_policies"] = [
160
- montezumarevenge_policies.real_random_policy,
161
- montezumarevenge_policies.pseudo_random_policy,
162
- montezumarevenge_policies.dedicated_1_policy,
163
- montezumarevenge_policies.dedicated_2_policy,
164
- montezumarevenge_policies.dedicated_3_policy,
165
- montezumarevenge_policies.dedicated_4_policy,
166
- montezumarevenge_policies.dedicated_5_policy,
167
- montezumarevenge_policies.dedicated_6_policy,
168
- montezumarevenge_policies.dedicated_7_policy,
169
- montezumarevenge_policies.dedicated_8_policy,
170
- montezumarevenge_policies.dedicated_9_policy,
171
- montezumarevenge_policies.dedicated_10_policy,
172
- montezumarevenge_policies.dedicated_11_policy,
173
- montezumarevenge_policies.dedicated_12_policy,
174
- montezumarevenge_policies.dedicated_13_policy,
175
- montezumarevenge_policies.dedicated_14_policy,
176
- montezumarevenge_policies.dedicated_15_policy,
177
- montezumarevenge_policies.dedicated_16_policy,
178
- montezumarevenge_policies.dedicated_17_policy,
179
- montezumarevenge_policies.dedicated_18_policy,
180
- ]
181
-
182
  ## For mujoco env
183
 
184
 
@@ -196,12 +156,12 @@ from .mujoco import walker2d_translator, walker2d_policies
196
 
197
 
198
 
199
- REGISTRY["invertedPendulum_init_translator"] = invertedPendulum_translator.GameDescriber
200
- REGISTRY["invertedPendulum_basic_translator"] = invertedPendulum_translator.BasicStateSequenceTranslator
201
- REGISTRY["invertedPendulum_policies"] = [invertedPendulum_policies.pseudo_random_policy, invertedPendulum_policies.real_random_policy]
202
- REGISTRY["invertedDoublePendulum_init_translator"] = invertedDoublePendulum_translator.GameDescriber
203
- REGISTRY["invertedDoublePendulum_basic_translator"] = invertedDoublePendulum_translator.BasicStateSequenceTranslator
204
- REGISTRY["invertedDoublePendulum_policies"] = [invertedDoublePendulum_policies.pseudo_random_policy, invertedDoublePendulum_policies.real_random_policy]
205
 
206
 
207
  REGISTRY["swimmer_init_translator"] = swimmer_translator.GameDescriber
 
18
  from .atari import montezumarevenge_policies, montezumarevenge_translator
19
  register_environments()
20
 
21
+ from .mujoco import ant_translator, ant_policies
22
 
23
  REGISTRY = {}
24
  REGISTRY["sampling_wrapper"] = SettableStateEnv
25
  REGISTRY["base_env"] = BaseEnv
26
+ REGISTRY["cartpole_init_translator"] = cartpole_translator.GameDescriber
27
+ REGISTRY["cartpole_basic_translator"] = cartpole_translator.BasicStateSequenceTranslator
28
  REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
29
  REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
30
  REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
31
  REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
32
 
33
+ REGISTRY["cartpole_policies"] = [cartpole_policies.dedicated_1_policy, cartpole_policies.dedicated_2_policy, cartpole_policies.pseudo_random_policy, cartpole_policies.real_random_policy]
34
  REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
35
  REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
36
 
37
+ REGISTRY["lunarlander_init_translator"] = LunarLander_translator.GameDescriber
38
+ REGISTRY["lunarlander_basic_translator"] = LunarLander_translator.BasicStateSequenceTranslator
39
+ REGISTRY["lunarlander_policies"] = [LunarLander_policies.dedicated_1_policy, LunarLander_policies.dedicated_2_policy, LunarLander_policies.dedicated_3_policy,LunarLander_policies.dedicated_4_policy, LunarLander_policies.pseudo_random_policy, LunarLander_policies.real_random_policy]
40
 
41
  REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
42
  REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
 
55
  REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
56
 
57
 
58
+ REGISTRY["mountaincarcontinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
59
+ REGISTRY["mountaincarcontinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
60
+ REGISTRY["mountaincarcontinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
61
 
62
 
63
  REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
 
139
  montezumarevenge_policies.dedicated_18_policy,
140
  ]
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  ## For mujoco env
143
 
144
 
 
156
 
157
 
158
 
159
+ REGISTRY["invertedpendulum_init_translator"] = invertedPendulum_translator.GameDescriber
160
+ REGISTRY["invertedpendulum_basic_translator"] = invertedPendulum_translator.BasicStateSequenceTranslator
161
+ REGISTRY["invertedpendulum_policies"] = [invertedPendulum_policies.pseudo_random_policy, invertedPendulum_policies.real_random_policy]
162
+ REGISTRY["inverteddoublependulum_init_translator"] = invertedDoublePendulum_translator.GameDescriber
163
+ REGISTRY["inverteddoublependulum_basic_translator"] = invertedDoublePendulum_translator.BasicStateSequenceTranslator
164
+ REGISTRY["inverteddoublependulum_policies"] = [invertedDoublePendulum_policies.pseudo_random_policy, invertedDoublePendulum_policies.real_random_policy]
165
 
166
 
167
  REGISTRY["swimmer_init_translator"] = swimmer_translator.GameDescriber
packages.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ swig
2
+ libosmesa6-dev
3
+ libgl1-mesa-glx
4
+ libglfw3
5
+ libglew-dev
6
+ patchelf
7
+ libxrender1
8
+ libgl1-mesa-dev
9
+ xpra
10
+ libglfw3-dev
requirements.txt ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # _libgcc_mutex==0.1
2
+ # # _openmp_mutex==5.1
3
+ # asttokens==2.0.5
4
+ # async-timeout==4.0.2
5
+ # backcall==0.2.0
6
+ # # blas==1.0
7
+ # brotlipy==0.7.0
8
+ # # ca-certificates==2023.01.10
9
+ # cached-property==1.5.2
10
+ # cffi==1.15.1
11
+ # chardet==4.0.0
12
+ # comm==0.1.2
13
+ # cryptography==39.0.1
14
+ # # cudatoolkit==11.3.1
15
+ # debugpy==1.5.1
16
+ # decorator==5.1.1
17
+ # executing==0.8.3
18
+ # frozenlist==1.3.3
19
+ # # hdf5==1.10.6
20
+ # idna==3.4
21
+ # importlib_metadata==6.0.0
22
+ # intel-openmp==2023.1.0
23
+ # ipykernel==6.19.2
24
+ # ipython==8.12.0
25
+ # jedi==0.18.1
26
+ # jupyter_client==8.1.0
27
+ # jupyter_core==5.3.0
28
+ # # ld_impl_linux-64==2.38
29
+ # # libffi==3.4.4
30
+ # # libgcc-ng==11.2.0
31
+ # # libgfortran-ng==11.2.0
32
+ # # libgfortran5==11.2.0
33
+ # # libgomp==11.2.0
34
+ # # libllvm14==14.0.6
35
+ # # libprotobuf==3.20.3
36
+ # # libsodium==1.0.18
37
+ # # libstdcxx-ng==11.2.0
38
+ # matplotlib-inline==0.1.6
39
+ # mkl==2023.1.0
40
+ # mkl-service==2.4.0
41
+ # mkl_fft==1.3.6
42
+ # mkl_random==1.2.2
43
+ # # ncurses==6.4
44
+ # nest-asyncio==1.5.6
45
+ # numpy==1.24.3
46
+ # # numpy-base==1.24.3
47
+ # # openssl==3.0.10
48
+ # packaging==23.1
49
+ # parso==0.8.3
50
+ # # pcre==8.45
51
+ # pexpect==4.8.0
52
+ # pickleshare==0.7.5
53
+ # pip==23.1.2
54
+ # platformdirs==2.5.2
55
+ # prompt-toolkit==3.0.36
56
+ # ptyprocess==0.7.0
57
+ # pure_eval==0.2.2
58
+ # pycparser==2.21
59
+ # pygments==2.15.1
60
+ # pyopenssl==23.0.0
61
+ # pysocks==1.7.1
62
+ # # python==3.8.16
63
+ # python-dateutil==2.8.2
64
+ # # python_abi==3.8
65
+ # pyzmq==25.1.0
66
+ # # readline==8.2
67
+ # setuptools==67.8.0
68
+ # six==1.16.0
69
+ # # sqlite==3.41.2
70
+ # stack_data==0.2.0
71
+ # tbb==2021.8.0
72
+ # # tk==8.6.12
73
+ # tornado==6.2
74
+ # traitlets==5.7.1
75
+ # typing_extensions==4.7.1
76
+ # wcwidth==0.2.5
77
+ # wheel==0.38.4
78
+ # # xz==5.4.2
79
+ # # yaml==0.2.5
80
+ # # zeromq==4.3.4
81
+ # # zlib==1.2.13
82
+ ale-py==0.8.1
83
+ absl-py==1.4.0
84
+ aiohttp==3.8.4
85
+ aiosignal==1.3.1
86
+ annotated-types==0.5.0
87
+ anyio==3.7.1
88
+ appdirs==1.4.4
89
+ aquarel==0.0.5
90
+ attrs==23.1.0
91
+ box2d-py==2.3.5
92
+ cachetools==5.3.1
93
+ certifi==2023.5.7
94
+ charset-normalizer==3.1.0
95
+ click==8.1.6
96
+ cloudpickle==2.2.1
97
+ colorama==0.4.6
98
+ contourpy==1.1.0
99
+ cycler==0.11.0
100
+ dataclasses-json==0.5.14
101
+ distro==1.8.0
102
+ docker-pycreds==0.4.0
103
+ exceptiongroup==1.2.0
104
+ filelock==3.12.3
105
+ fonttools==4.40.0
106
+ fsspec==2023.6.0
107
+ gitdb==4.0.10
108
+ gitpython==3.1.32
109
+ google-auth==2.22.0
110
+ google-auth-oauthlib==1.0.0
111
+ greenlet==2.0.2
112
+ grpcio==1.57.0
113
+ gym==0.26.2
114
+ gym-notices==0.0.8
115
+ gym[accept-rom-license]
116
+ h11==0.14.0
117
+ h5py==3.9.0
118
+ httpcore==1.0.2
119
+ httpx==0.25.2
120
+ # huggingface-hub==0.16.4
121
+ importlib-metadata==6.6.0
122
+ importlib-resources==5.12.0
123
+ joblib==1.3.2
124
+ kiwisolver==1.4.4
125
+ langchain==0.0.270
126
+ langsmith==0.0.25
127
+ llvmlite==0.40.1
128
+ logger==1.4
129
+ loguru==0.7.0
130
+ markdown==3.4.4
131
+ markupsafe==2.1.3
132
+ marshmallow==3.20.1
133
+ matplotlib==3.7.1
134
+ multidict==6.0.4
135
+ mypy-extensions==1.0.0
136
+ numba==0.57.1
137
+ numexpr==2.8.5
138
+ oauthlib==3.2.2
139
+ openai==0.27.8
140
+ pandas==2.0.3
141
+ pathtools==0.1.2
142
+ pillow==9.5.0
143
+ protobuf==3.19.6
144
+ psutil==5.9.5
145
+ pyasn1==0.5.0
146
+ pyasn1-modules==0.3.0
147
+ # pydantic==1.10.11
148
+ # pydantic-core==2.6.1
149
+ pygame==2.1.0
150
+ pyparsing==3.0.9
151
+ pytz==2023.3.post1
152
+ pyyaml==6.0.1
153
+ regex==2023.8.8
154
+ requests==2.31.0
155
+ requests-oauthlib==1.3.1
156
+ rsa==4.9
157
+ safetensors==0.3.3
158
+ seaborn==0.13.0
159
+ sentry-sdk==1.28.1
160
+ setproctitle==1.3.2
161
+ smmap==5.0.0
162
+ sniffio==1.3.0
163
+ sqlalchemy==2.0.20
164
+ swig==4.1.1
165
+ tenacity==8.2.3
166
+ tensorboard==2.14.0
167
+ tensorboard-data-server==0.7.1
168
+ threadpoolctl==3.2.0
169
+ tiktoken==0.4.0
170
+ timeout-decorator==0.5.0
171
+ tokenizers==0.13.3
172
+ tqdm==4.65.0
173
+ transformers==4.30.2
174
+ typing-inspect==0.9.0
175
+ tzdata==2023.3
176
+ urllib3==1.26.16
177
+ v==1
178
+ wandb==0.15.5
179
+ werkzeug==2.3.7
180
+ win32-setctime==1.1.0
181
+ yarl==1.9.2
182
+ zipp==3.15.0
183
+ git+https://[email protected]/hyyh28/atari-representation-learning.git
184
+ gradio
185
+ # gradio==4.13.0
186
+ mujoco-py==2.1.2.14
187
+ cython==0.29.37
188
+ ruamel.yaml==0.18.5
yaml2rep.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ruamel.yaml
2
+
3
+ yaml = ruamel.yaml.YAML()
4
+ data = yaml.load(open('environment.yaml'))
5
+
6
+ requirements = []
7
+ for dep in data['dependencies']:
8
+ if isinstance(dep, str):
9
+ package, package_version = dep.split('=')
10
+ requirements.append(package + '==' + package_version)
11
+ elif isinstance(dep, dict):
12
+ for preq in dep.get('pip', []):
13
+ requirements.append(preq)
14
+
15
+ with open('requirements.txt', 'w') as fp:
16
+ for requirement in requirements:
17
+ print(requirement, file=fp)