hzxwonder committed on
Commit
a32f407
·
1 Parent(s): a1c0b89

update gpt

Browse files
Files changed (5) hide show
  1. README.md +17 -10
  2. deciders/act.py +1 -1
  3. deciders/utils.py +2 -2
  4. environment.yaml +2 -0
  5. main_reflexion.py +9 -0
README.md CHANGED
@@ -34,25 +34,32 @@ For `L5` level, we handcraft the few shot examples with domain knowledge in `pro
34
  ```python
35
  import openai
36
  class gpt:
37
- def __init__(self,):
38
- openai.api_type = "azure"
39
- openai.api_version = "2023-05-15"
40
- # Your Azure OpenAI resource's endpoint value.
41
- openai.api_base = "https://js-partner.openai.azure.com/"
42
- openai.api_key = "your azure openai key"
 
 
 
43
  ```
44
-
45
  2. Install Requirements
46
 
47
  ```
48
- conda env create --file environment.yml
49
  ```
50
 
51
  3. Testing
52
- The project can be run using the provided test.sh script. This script runs a series of commands, each of which initiates a Gym environment and applies different translators to it.
53
 
54
  Here is an example of how to run the script:
55
 
56
  ```
57
- ./test.sh
 
 
 
 
58
  ```
 
 
34
  ```python
35
  import openai
36
  class gpt:
37
+ def __init__(self, args):
38
+ if args.api_type == "azure":
39
+ openai.api_type = "azure"
40
+ openai.api_version = "2023-05-15"
41
+ # Your Azure OpenAI resource's endpoint value.
42
+ openai.api_base = "https://midivi-main-scu1.openai.azure.com/"
43
+ openai.api_key = "your azure key"
44
+ else:
45
+ openai.api_key = "your openai key"
46
  ```
 
47
  2. Install Requirements
48
 
49
  ```
50
+ conda env create --file environment.yaml
51
  ```
52
 
53
  3. Testing
54
+ The project can be run using the provided .sh scripts in the shell/ folder. Each script runs a series of commands, each of which initiates a Gym environment and applies different translators to it.
55
 
56
  Here is an example of how to run the script:
57
 
58
  ```
59
+ sh shell/test_cartpole.sh
60
+ ```
61
+ Alternatively, you can test this by copying a single command from one of the .sh scripts:
62
+ ```
63
+ python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider exe_actor --prompt_level 1 --num_trails 1 --distiller guide_generator
64
  ```
65
+ If you use an OpenAI key instead of an Azure key, add "--api_type openai" to the end of the command.
deciders/act.py CHANGED
@@ -34,7 +34,7 @@ class NaiveAct(gpt):
34
  else:
35
  model = args.gpt_version
36
  self.encoding = tiktoken.encoding_for_model(model)
37
- super().__init__()
38
  self.distiller = distiller
39
  self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
40
  self.default_action = 1
 
34
  else:
35
  model = args.gpt_version
36
  self.encoding = tiktoken.encoding_for_model(model)
37
+ super().__init__(args)
38
  self.distiller = distiller
39
  self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
40
  self.default_action = 1
deciders/utils.py CHANGED
@@ -16,8 +16,8 @@ else:
16
 
17
  Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
18
 
19
- from .gpt import gpt
20
- gpt().__init__()
21
 
22
  import timeout_decorator
23
  @timeout_decorator.timeout(30)
 
16
 
17
  Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
18
 
19
+ # from .gpt import gpt
20
+ # gpt().__init__()
21
 
22
  import timeout_decorator
23
  @timeout_decorator.timeout(30)
environment.yaml CHANGED
@@ -86,6 +86,7 @@ dependencies:
86
  - zeromq=4.3.4
87
  - zlib=1.2.13
88
  - pip:
 
89
  - absl-py==1.4.0
90
  - aiohttp==3.8.4
91
  - aiosignal==1.3.1
@@ -185,3 +186,4 @@ dependencies:
185
  - win32-setctime==1.1.0
186
  - yarl==1.9.2
187
  - zipp==3.15.0
 
 
86
  - zeromq=4.3.4
87
  - zlib=1.2.13
88
  - pip:
89
+ - ale-py==0.8.1
90
  - absl-py==1.4.0
91
  - aiohttp==3.8.4
92
  - aiosignal==1.3.1
 
186
  - win32-setctime==1.1.0
187
  - yarl==1.9.2
188
  - zipp==3.15.0
189
+ - git+ssh://git@github.com/hyyh28/atari-representation-learning.git
main_reflexion.py CHANGED
@@ -302,6 +302,15 @@ if __name__ == "__main__":
302
 
303
  if args.api_type != "azure" and args.api_type != "openai":
304
  raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai' !")
 
 
 
 
 
 
 
 
 
305
  # Get the specified translator, environment, and ChatGPT model
306
  env_class = envs.REGISTRY[args.env]
307
  init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)
 
302
 
303
  if args.api_type != "azure" and args.api_type != "openai":
304
  raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai' !")
305
+
306
+ # Note: with "azure" the model name is "gpt-35-turbo", whereas with "openai" it is "gpt-3.5-turbo"; convert args.gpt_version to the spelling that matches the selected api_type.
307
+ if args.api_type == "azure":
308
+ if args.gpt_version == "gpt-3.5-turbo":
309
+ args.gpt_version = 'gpt-35-turbo'
310
+ elif args.api_type == "openai":
311
+ if args.gpt_version == "gpt-35-turbo":
312
+ args.gpt_version = 'gpt-3.5-turbo'
313
+
314
  # Get the specified translator, environment, and ChatGPT model
315
  env_class = envs.REGISTRY[args.env]
316
  init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)