Spaces:
Runtime error
Runtime error
hzxwonder
commited on
Commit
·
a32f407
1
Parent(s):
a1c0b89
update gpt
Browse files- README.md +17 -10
- deciders/act.py +1 -1
- deciders/utils.py +2 -2
- environment.yaml +2 -0
- main_reflexion.py +9 -0
README.md
CHANGED
@@ -34,25 +34,32 @@ For `L5` level, we handcraft the few shot examples with domain knowledge in `pro
|
|
34 |
```python
|
35 |
import openai
|
36 |
class gpt:
|
37 |
-
def __init__(self,):
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
```
|
44 |
-
|
45 |
2. Install Requirements
|
46 |
|
47 |
```
|
48 |
-
conda env create --file environment.
|
49 |
```
|
50 |
|
51 |
3. Testing
|
52 |
-
The project can be run using the provided
|
53 |
|
54 |
Here is an example of how to run the script:
|
55 |
|
56 |
```
|
57 |
-
|
|
|
|
|
|
|
|
|
58 |
```
|
|
|
|
34 |
```python
|
35 |
import openai
|
36 |
class gpt:
|
37 |
+
def __init__(self, args):
|
38 |
+
if args.api_type == "azure":
|
39 |
+
openai.api_type = "azure"
|
40 |
+
openai.api_version = "2023-05-15"
|
41 |
+
# Your Azure OpenAI resource's endpoint value.
|
42 |
+
openai.api_base = "https://midivi-main-scu1.openai.azure.com/"
|
43 |
+
openai.api_key = "your azure key"
|
44 |
+
else:
|
45 |
+
openai.api_key = "your openai key"
|
46 |
```
|
|
|
47 |
2. Install Requirements
|
48 |
|
49 |
```
|
50 |
+
conda env create --file environment.yaml
|
51 |
```
|
52 |
|
53 |
3. Testing
|
54 |
+
The project can be run using the provided .sh script in shell/ folder. This script runs a series of commands, each of which initiates a Gym environment and applies different translators to it.
|
55 |
|
56 |
Here is an example of how to run the script:
|
57 |
|
58 |
```
|
59 |
+
sh shell/test_cartpole.sh
|
60 |
+
```
|
61 |
+
Or you can also test this by copying a command from a .sh script
|
62 |
+
```
|
63 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider exe_actor --prompt_level 1 --num_trails 1 --distiller guide_generator
|
64 |
```
|
65 |
+
If you use openai key, please add "--api_type openai" at the end of the command!
|
deciders/act.py
CHANGED
@@ -34,7 +34,7 @@ class NaiveAct(gpt):
|
|
34 |
else:
|
35 |
model = args.gpt_version
|
36 |
self.encoding = tiktoken.encoding_for_model(model)
|
37 |
-
super().__init__()
|
38 |
self.distiller = distiller
|
39 |
self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
|
40 |
self.default_action = 1
|
|
|
34 |
else:
|
35 |
model = args.gpt_version
|
36 |
self.encoding = tiktoken.encoding_for_model(model)
|
37 |
+
super().__init__(args)
|
38 |
self.distiller = distiller
|
39 |
self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
|
40 |
self.default_action = 1
|
deciders/utils.py
CHANGED
@@ -16,8 +16,8 @@ else:
|
|
16 |
|
17 |
Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
|
18 |
|
19 |
-
from .gpt import gpt
|
20 |
-
gpt().__init__()
|
21 |
|
22 |
import timeout_decorator
|
23 |
@timeout_decorator.timeout(30)
|
|
|
16 |
|
17 |
Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
|
18 |
|
19 |
+
# from .gpt import gpt
|
20 |
+
# gpt().__init__()
|
21 |
|
22 |
import timeout_decorator
|
23 |
@timeout_decorator.timeout(30)
|
environment.yaml
CHANGED
@@ -86,6 +86,7 @@ dependencies:
|
|
86 |
- zeromq=4.3.4
|
87 |
- zlib=1.2.13
|
88 |
- pip:
|
|
|
89 |
- absl-py==1.4.0
|
90 |
- aiohttp==3.8.4
|
91 |
- aiosignal==1.3.1
|
@@ -185,3 +186,4 @@ dependencies:
|
|
185 |
- win32-setctime==1.1.0
|
186 |
- yarl==1.9.2
|
187 |
- zipp==3.15.0
|
|
|
|
86 |
- zeromq=4.3.4
|
87 |
- zlib=1.2.13
|
88 |
- pip:
|
89 |
+
- ale-py==0.8.1
|
90 |
- absl-py==1.4.0
|
91 |
- aiohttp==3.8.4
|
92 |
- aiosignal==1.3.1
|
|
|
186 |
- win32-setctime==1.1.0
|
187 |
- yarl==1.9.2
|
188 |
- zipp==3.15.0
|
189 |
+
- git+ssh://[email protected]/hyyh28/atari-representation-learning.git
|
main_reflexion.py
CHANGED
@@ -302,6 +302,15 @@ if __name__ == "__main__":
|
|
302 |
|
303 |
if args.api_type != "azure" and args.api_type != "openai":
|
304 |
raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai' !")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
# Get the specified translator, environment, and ChatGPT model
|
306 |
env_class = envs.REGISTRY[args.env]
|
307 |
init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)
|
|
|
302 |
|
303 |
if args.api_type != "azure" and args.api_type != "openai":
|
304 |
raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai' !")
|
305 |
+
|
306 |
+
# Please note when using "azure", the model name is gpt-35-turbo while using "openai", the model name is "gpt-3.5-turbo"
|
307 |
+
if args.api_type == "azure":
|
308 |
+
if args.gpt_version == "gpt-3.5-turbo":
|
309 |
+
args.gpt_version = 'gpt-35-turbo'
|
310 |
+
elif args.api_type == "openai":
|
311 |
+
if args.gpt_version == "gpt-35-turbo":
|
312 |
+
args.gpt_version = 'gpt-3.5-turbo'
|
313 |
+
|
314 |
# Get the specified translator, environment, and ChatGPT model
|
315 |
env_class = envs.REGISTRY[args.env]
|
316 |
init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)
|