hzxwonder committed · Commit 4e5e176 · 1 parent: 377ad82 · update
- deciders/act.py +13 -10
- deciders/cot.py +15 -10
- deciders/exe.py +14 -10
- deciders/reflexion.py +14 -10
- deciders/self_consistency.py +14 -10
- deciders/selfask.py +14 -10
- deciders/spp.py +14 -10
- deciders/utils.py +6 -4
- distillers/guider.py +1 -1
- distillers/traj_prompt_summarizer.py +1 -1
- main_reflexion.py +8 -0
deciders/act.py
CHANGED
@@ -6,7 +6,7 @@ from loguru import logger
 from .parser import PARSERS
 from langchain.output_parsers import PydanticOutputParser
 from langchain.output_parsers import OutputFixingParser
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from memory.env_history import EnvironmentHistory
 import tiktoken
 import json
@@ -88,15 +88,18 @@ class NaiveAct(gpt):
         else:
             num_action = 1

-
-
-
-
-
-
-
-
-
+        if self.args.api_type == "azure":
+            autofixing_chat = AzureChatOpenAI(
+                openai_api_type=openai.api_type,
+                openai_api_version=openai.api_version,
+                openai_api_base=openai.api_base,
+                openai_api_key=openai.api_key,
+                deployment_name=self.args.gpt_version,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+        elif self.args.api_type == "openai":
+            autofixing_chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)

         parser = PydanticOutputParser(pydantic_object=PARSERS[num_action])
         autofixing_parser = OutputFixingParser.from_llm(
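Every decider below receives this same azure/openai branch verbatim. A minimal sketch of how that dispatch could be factored into one shared helper (make_chat is a hypothetical name, not part of this commit; like the deciders, it assumes credentials are already set on the openai module):

import openai
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI

def make_chat(api_type, gpt_version, temperature, max_tokens):
    # "azure" targets a named deployment on an Azure resource;
    # "openai" talks to the public endpoint and only needs the key.
    if api_type == "azure":
        return AzureChatOpenAI(
            openai_api_type=openai.api_type,
            openai_api_version=openai.api_version,
            openai_api_base=openai.api_base,
            openai_api_key=openai.api_key,
            deployment_name=gpt_version,
            temperature=temperature,
            max_tokens=max_tokens,
        )
    elif api_type == "openai":
        return ChatOpenAI(temperature=temperature, openai_api_key=openai.api_key)
    raise ValueError(f"unsupported api_type: {api_type}")

Each decider could then build its model with make_chat(self.args.api_type, self.args.gpt_version, self.temperature, self.max_tokens) instead of repeating the branch.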
deciders/cot.py
CHANGED
@@ -1,6 +1,6 @@
 import openai
 from .misc import history_to_str
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.prompts.chat import (
     PromptTemplate,
     ChatPromptTemplate,
@@ -31,15 +31,20 @@ class ChainOfThought(NaiveAct):
     ):
         self.action_description = action_description
         self._add_history_before_action(game_description, goal_description, state_description)
-
-
-
-
-
-
-
-
-
+
+        if self.args.api_type == "azure":
+            chat = AzureChatOpenAI(
+                openai_api_type=openai.api_type,
+                openai_api_version=openai.api_version,
+                openai_api_base=openai.api_base,
+                openai_api_key=openai.api_key,
+                deployment_name=self.args.gpt_version,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+        elif self.args.api_type == "openai":
+            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
+

         suffix_flag = False
         reply_format_description = \
deciders/exe.py
CHANGED
@@ -1,6 +1,6 @@
 import openai
 from .misc import history_to_str
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.prompts.chat import (
     PromptTemplate,
     ChatPromptTemplate,
@@ -96,15 +96,19 @@ class EXE(NaiveAct):
         self.game_description = game_description
         self.goal_description = goal_description
         self.env_history.add("observation", state_description)
-
-
-
-
-
-
-
-
-
+
+        if self.args.api_type == "azure":
+            chat = AzureChatOpenAI(
+                openai_api_type=openai.api_type,
+                openai_api_version=openai.api_version,
+                openai_api_base=openai.api_base,
+                openai_api_key=openai.api_key,
+                deployment_name=self.args.gpt_version,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+        elif self.args.api_type == "openai":
+            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
         # print(self.logger)
         reply_format_description = \
             "Your response should choose an optimal action from valid action list, and terminated with following format: "
deciders/reflexion.py
CHANGED
@@ -1,6 +1,6 @@
 import openai
 from .misc import history_to_str
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.prompts.chat import (
     PromptTemplate,
     ChatPromptTemplate,
@@ -53,15 +53,19 @@ class Reflexion(NaiveAct):
         self.game_description = game_description
         self.goal_description = goal_description
         self.env_history.add("observation", state_description)
-
-
-
-
-
-
-
-
-
+
+        if self.args.api_type == "azure":
+            chat = AzureChatOpenAI(
+                openai_api_type=openai.api_type,
+                openai_api_version=openai.api_version,
+                openai_api_base=openai.api_base,
+                openai_api_key=openai.api_key,
+                deployment_name=self.args.gpt_version,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+        elif self.args.api_type == "openai":
+            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
         suffix_flag = False
         reply_format_description = \
             "Your response should choose an optimal action from a valid action list and terminate with the following format: "
deciders/self_consistency.py
CHANGED
@@ -1,6 +1,6 @@
 import openai
 from .misc import history_to_str
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.prompts.chat import (
     PromptTemplate,
     ChatPromptTemplate,
@@ -34,15 +34,19 @@ class SelfConsistency(NaiveAct):
         # print(self.temperature)
         self.action_description = action_description
         self._add_history_before_action(game_description, goal_description, state_description)
-
-
-
-
-
-
-
-
-
+
+        if self.args.api_type == "azure":
+            chat = AzureChatOpenAI(
+                openai_api_type=openai.api_type,
+                openai_api_version=openai.api_version,
+                openai_api_base=openai.api_base,
+                openai_api_key=openai.api_key,
+                deployment_name=self.args.gpt_version,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+        elif self.args.api_type == "openai":
+            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)

         suffix_flag = False
         reply_format_description = \
deciders/selfask.py
CHANGED
@@ -1,6 +1,6 @@
 import openai
 from .misc import history_to_str
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.prompts.chat import (
     PromptTemplate,
     ChatPromptTemplate,
@@ -31,15 +31,19 @@ class SelfAskAct(NaiveAct):
     ):
         self.action_description = action_description
         self._add_history_before_action(game_description, goal_description, state_description)
-
-
-
-
-
-
-
-
-
+
+        if self.args.api_type == "azure":
+            chat = AzureChatOpenAI(
+                openai_api_type=openai.api_type,
+                openai_api_version=openai.api_version,
+                openai_api_base=openai.api_base,
+                openai_api_key=openai.api_key,
+                deployment_name=self.args.gpt_version,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+        elif self.args.api_type == "openai":
+            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)

         suffix_flag = False
         reply_format_description = \
deciders/spp.py
CHANGED
@@ -1,6 +1,6 @@
 import openai
 from .misc import history_to_str
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.prompts.chat import (
     PromptTemplate,
     ChatPromptTemplate,
@@ -30,15 +30,19 @@ class SPP(NaiveAct):
     ):
         self.action_description = action_description
         self._add_history_before_action(game_description, goal_description, state_description)
-
-
-
-
-
-
-
-
-
+
+        if self.args.api_type == "azure":
+            chat = AzureChatOpenAI(
+                openai_api_type=openai.api_type,
+                openai_api_version=openai.api_version,
+                openai_api_base=openai.api_base,
+                openai_api_key=openai.api_key,
+                deployment_name=self.args.gpt_version,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+        elif self.args.api_type == "openai":
+            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)

         self.fewshot_example = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
         self.irr_few_shot_examples = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
deciders/utils.py
CHANGED
@@ -1,6 +1,6 @@
 import os
 import sys
-import openai
+import openai # 0.27.8
 from tenacity import (
     retry,
     stop_after_attempt, # type: ignore
@@ -24,9 +24,10 @@ import timeout_decorator
 def run_chain(chain, *args, **kwargs):
     return chain.run(*args, **kwargs)

-@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
+# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
+def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
     response = openai.Completion.create(
+        model=engine,
         engine=engine,
         prompt=prompt,
         temperature=temperature,
@@ -39,7 +40,7 @@ def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float
     )
     return response.choices[0].text

-@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
+# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
 def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
     assert model != "text-davinci-003"
     messages = [
@@ -58,3 +59,4 @@ def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-tur
         # request_timeout = 1
     )
     return response.choices[0]["message"]["content"]
+
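The notable detail in get_completion is that the call now passes model=engine alongside engine=engine. In the openai 0.27.x client, Azure resolves the target deployment from engine while the public endpoint expects model, so supplying both presumably keeps one call site usable for either api_type. Restated outside the diff as a minimal sketch (defaults copied from the function above; behavior on the public endpoint depends on the client still accepting the engine parameter):

import openai  # 0.27.8, per the version comment in this commit

def complete(prompt: str, engine: str = "gpt-35-turbo",
             temperature: float = 0.0, max_tokens: int = 256) -> str:
    # engine= names the Azure deployment; model= is what the public
    # OpenAI endpoint reads. Passing both covers either back end.
    response = openai.Completion.create(
        model=engine,
        engine=engine,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return response.choices[0].text

Note that the tenacity @retry decorators are commented out by this commit, so transient API errors now surface immediately instead of being retried with exponential backoff.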
distillers/guider.py
CHANGED
@@ -107,7 +107,7 @@ class Guidance_Generator():
             reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
         else:
             reflection_query = self._generate_summary_query(traj, memory)
-        reflection = get_completion(reflection_query,engine=self.args.gpt_version)
+        reflection = get_completion(reflection_query, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
         logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
         return reflection
distillers/traj_prompt_summarizer.py
CHANGED
@@ -54,7 +54,7 @@ class TrajPromptSummarizer():
             reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
         else:
             reflection_query = self._generate_summary_query(traj, memory)
-        reflection = get_completion(reflection_query,engine=self.args.gpt_version)
+        reflection = get_completion(reflection_query, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
         logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
         return reflection
main_reflexion.py
CHANGED
@@ -292,8 +292,16 @@ if __name__ == "__main__":
         default=1,
         help="Whether only taking local observations, if is_only_local_obs = 1, only using local obs"
     )
+    parser.add_argument(
+        "--api_type",
+        type=str,
+        default="azure",
+        help="choose api type, now support azure and openai"
+    )
     args = parser.parse_args()

+    if args.api_type != "azure" and args.api_type != "openai":
+        raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai' !")
     # Get the specified translator, environment, and ChatGPT model
     env_class = envs.REGISTRY[args.env]
     init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)
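With the flag wired in, the back end becomes a startup choice, and any other value fails fast with the ValueError above. A hypothetical invocation (all other arguments elided):

python main_reflexion.py --api_type azure   # default; uses the Azure deployment settings
python main_reflexion.py --api_type openai  # public OpenAI endpoint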