hzxwonder committed on
Commit
4e5e176
·
1 Parent(s): 377ad82
deciders/act.py CHANGED
@@ -6,7 +6,7 @@ from loguru import logger
6
  from .parser import PARSERS
7
  from langchain.output_parsers import PydanticOutputParser
8
  from langchain.output_parsers import OutputFixingParser
9
- from langchain.chat_models import AzureChatOpenAI
10
  from memory.env_history import EnvironmentHistory
11
  import tiktoken
12
  import json
@@ -88,15 +88,18 @@ class NaiveAct(gpt):
88
  else:
89
  num_action = 1
90
 
91
- autofixing_chat = AzureChatOpenAI(
92
- openai_api_type=openai.api_type,
93
- openai_api_version=openai.api_version,
94
- openai_api_base=openai.api_base,
95
- openai_api_key=openai.api_key,
96
- deployment_name=self.args.gpt_version,
97
- temperature=self.temperature,
98
- max_tokens=self.max_tokens
99
- )
 
 
 
100
 
101
  parser = PydanticOutputParser(pydantic_object=PARSERS[num_action])
102
  autofixing_parser = OutputFixingParser.from_llm(
 
6
  from .parser import PARSERS
7
  from langchain.output_parsers import PydanticOutputParser
8
  from langchain.output_parsers import OutputFixingParser
9
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
10
  from memory.env_history import EnvironmentHistory
11
  import tiktoken
12
  import json
 
88
  else:
89
  num_action = 1
90
 
91
+ if self.args.api_type == "azure":
92
+ autofixing_chat = AzureChatOpenAI(
93
+ openai_api_type=openai.api_type,
94
+ openai_api_version=openai.api_version,
95
+ openai_api_base=openai.api_base,
96
+ openai_api_key=openai.api_key,
97
+ deployment_name=self.args.gpt_version,
98
+ temperature=self.temperature,
99
+ max_tokens=self.max_tokens
100
+ )
101
+ elif self.args.api_type == "openai":
102
+ autofixing_chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
103
 
104
  parser = PydanticOutputParser(pydantic_object=PARSERS[num_action])
105
  autofixing_parser = OutputFixingParser.from_llm(
deciders/cot.py CHANGED
@@ -1,6 +1,6 @@
1
  import openai
2
  from .misc import history_to_str
3
- from langchain.chat_models import AzureChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
@@ -31,15 +31,20 @@ class ChainOfThought(NaiveAct):
31
  ):
32
  self.action_description = action_description
33
  self._add_history_before_action(game_description, goal_description, state_description)
34
- chat = AzureChatOpenAI(
35
- openai_api_type=openai.api_type,
36
- openai_api_version=openai.api_version,
37
- openai_api_base=openai.api_base,
38
- openai_api_key=openai.api_key,
39
- deployment_name=self.args.gpt_version,
40
- temperature=self.temperature,
41
- max_tokens=self.max_tokens
42
- )
 
 
 
 
 
43
 
44
  suffix_flag = False
45
  reply_format_description = \
 
1
  import openai
2
  from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
 
31
  ):
32
  self.action_description = action_description
33
  self._add_history_before_action(game_description, goal_description, state_description)
34
+
35
+ if self.args.api_type == "azure":
36
+ chat = AzureChatOpenAI(
37
+ openai_api_type=openai.api_type,
38
+ openai_api_version=openai.api_version,
39
+ openai_api_base=openai.api_base,
40
+ openai_api_key=openai.api_key,
41
+ deployment_name=self.args.gpt_version,
42
+ temperature=self.temperature,
43
+ max_tokens=self.max_tokens
44
+ )
45
+ elif self.args.api_type == "openai":
46
+ chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
47
+
48
 
49
  suffix_flag = False
50
  reply_format_description = \
deciders/exe.py CHANGED
@@ -1,6 +1,6 @@
1
  import openai
2
  from .misc import history_to_str
3
- from langchain.chat_models import AzureChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
@@ -96,15 +96,19 @@ class EXE(NaiveAct):
96
  self.game_description = game_description
97
  self.goal_description = goal_description
98
  self.env_history.add("observation", state_description)
99
- chat = AzureChatOpenAI(
100
- openai_api_type=openai.api_type,
101
- openai_api_version=openai.api_version,
102
- openai_api_base=openai.api_base,
103
- openai_api_key=openai.api_key,
104
- deployment_name=self.args.gpt_version,
105
- temperature=self.temperature,
106
- max_tokens=self.max_tokens,
107
- )
 
 
 
 
108
  # print(self.logger)
109
  reply_format_description = \
110
  "Your response should choose an optimal action from valid action list, and terminated with following format: "
 
1
  import openai
2
  from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
 
96
  self.game_description = game_description
97
  self.goal_description = goal_description
98
  self.env_history.add("observation", state_description)
99
+
100
+ if self.args.api_type == "azure":
101
+ chat = AzureChatOpenAI(
102
+ openai_api_type=openai.api_type,
103
+ openai_api_version=openai.api_version,
104
+ openai_api_base=openai.api_base,
105
+ openai_api_key=openai.api_key,
106
+ deployment_name=self.args.gpt_version,
107
+ temperature=self.temperature,
108
+ max_tokens=self.max_tokens
109
+ )
110
+ elif self.args.api_type == "openai":
111
+ chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
112
  # print(self.logger)
113
  reply_format_description = \
114
  "Your response should choose an optimal action from valid action list, and terminated with following format: "
deciders/reflexion.py CHANGED
@@ -1,6 +1,6 @@
1
  import openai
2
  from .misc import history_to_str
3
- from langchain.chat_models import AzureChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
@@ -53,15 +53,19 @@ class Reflexion(NaiveAct):
53
  self.game_description = game_description
54
  self.goal_description = goal_description
55
  self.env_history.add("observation", state_description)
56
- chat = AzureChatOpenAI(
57
- openai_api_type=openai.api_type,
58
- openai_api_version=openai.api_version,
59
- openai_api_base=openai.api_base,
60
- openai_api_key=openai.api_key,
61
- deployment_name=self.args.gpt_version,
62
- temperature=self.temperature,
63
- max_tokens=self.max_tokens,
64
- )
 
 
 
 
65
  suffix_flag = False
66
  reply_format_description = \
67
  "Your response should choose an optimal action from a valid action list and terminate with the following format: "
 
1
  import openai
2
  from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
 
53
  self.game_description = game_description
54
  self.goal_description = goal_description
55
  self.env_history.add("observation", state_description)
56
+
57
+ if self.args.api_type == "azure":
58
+ chat = AzureChatOpenAI(
59
+ openai_api_type=openai.api_type,
60
+ openai_api_version=openai.api_version,
61
+ openai_api_base=openai.api_base,
62
+ openai_api_key=openai.api_key,
63
+ deployment_name=self.args.gpt_version,
64
+ temperature=self.temperature,
65
+ max_tokens=self.max_tokens
66
+ )
67
+ elif self.args.api_type == "openai":
68
+ chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
69
  suffix_flag = False
70
  reply_format_description = \
71
  "Your response should choose an optimal action from a valid action list and terminate with the following format: "
deciders/self_consistency.py CHANGED
@@ -1,6 +1,6 @@
1
  import openai
2
  from .misc import history_to_str
3
- from langchain.chat_models import AzureChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
@@ -34,15 +34,19 @@ class SelfConsistency(NaiveAct):
34
  # print(self.temperature)
35
  self.action_description = action_description
36
  self._add_history_before_action(game_description, goal_description, state_description)
37
- chat = AzureChatOpenAI(
38
- openai_api_type=openai.api_type,
39
- openai_api_version=openai.api_version,
40
- openai_api_base=openai.api_base,
41
- openai_api_key=openai.api_key,
42
- deployment_name=self.args.gpt_version,
43
- temperature=self.temperature,
44
- max_tokens=self.max_tokens
45
- )
 
 
 
 
46
 
47
  suffix_flag = False
48
  reply_format_description = \
 
1
  import openai
2
  from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
 
34
  # print(self.temperature)
35
  self.action_description = action_description
36
  self._add_history_before_action(game_description, goal_description, state_description)
37
+
38
+ if self.args.api_type == "azure":
39
+ chat = AzureChatOpenAI(
40
+ openai_api_type=openai.api_type,
41
+ openai_api_version=openai.api_version,
42
+ openai_api_base=openai.api_base,
43
+ openai_api_key=openai.api_key,
44
+ deployment_name=self.args.gpt_version,
45
+ temperature=self.temperature,
46
+ max_tokens=self.max_tokens
47
+ )
48
+ elif self.args.api_type == "openai":
49
+ chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
50
 
51
  suffix_flag = False
52
  reply_format_description = \
deciders/selfask.py CHANGED
@@ -1,6 +1,6 @@
1
  import openai
2
  from .misc import history_to_str
3
- from langchain.chat_models import AzureChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
@@ -31,15 +31,19 @@ class SelfAskAct(NaiveAct):
31
  ):
32
  self.action_description = action_description
33
  self._add_history_before_action(game_description, goal_description, state_description)
34
- chat = AzureChatOpenAI(
35
- openai_api_type=openai.api_type,
36
- openai_api_version=openai.api_version,
37
- openai_api_base=openai.api_base,
38
- openai_api_key=openai.api_key,
39
- deployment_name=self.args.gpt_version,
40
- temperature=self.temperature,
41
- max_tokens=self.max_tokens
42
- )
 
 
 
 
43
 
44
  suffix_flag = False
45
  reply_format_description = \
 
1
  import openai
2
  from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
 
31
  ):
32
  self.action_description = action_description
33
  self._add_history_before_action(game_description, goal_description, state_description)
34
+
35
+ if self.args.api_type == "azure":
36
+ chat = AzureChatOpenAI(
37
+ openai_api_type=openai.api_type,
38
+ openai_api_version=openai.api_version,
39
+ openai_api_base=openai.api_base,
40
+ openai_api_key=openai.api_key,
41
+ deployment_name=self.args.gpt_version,
42
+ temperature=self.temperature,
43
+ max_tokens=self.max_tokens
44
+ )
45
+ elif self.args.api_type == "openai":
46
+ chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
47
 
48
  suffix_flag = False
49
  reply_format_description = \
deciders/spp.py CHANGED
@@ -1,6 +1,6 @@
1
  import openai
2
  from .misc import history_to_str
3
- from langchain.chat_models import AzureChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
@@ -30,15 +30,19 @@ class SPP(NaiveAct):
30
  ):
31
  self.action_description = action_description
32
  self._add_history_before_action(game_description, goal_description, state_description)
33
- chat = AzureChatOpenAI(
34
- openai_api_type=openai.api_type,
35
- openai_api_version=openai.api_version,
36
- openai_api_base=openai.api_base,
37
- openai_api_key=openai.api_key,
38
- deployment_name=self.args.gpt_version,
39
- temperature=self.temperature,
40
- max_tokens=self.max_tokens
41
- )
 
 
 
 
42
 
43
  self.fewshot_example = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
44
  self.irr_few_shot_examples = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
 
1
  import openai
2
  from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
4
  from langchain.prompts.chat import (
5
  PromptTemplate,
6
  ChatPromptTemplate,
 
30
  ):
31
  self.action_description = action_description
32
  self._add_history_before_action(game_description, goal_description, state_description)
33
+
34
+ if self.args.api_type == "azure":
35
+ chat = AzureChatOpenAI(
36
+ openai_api_type=openai.api_type,
37
+ openai_api_version=openai.api_version,
38
+ openai_api_base=openai.api_base,
39
+ openai_api_key=openai.api_key,
40
+ deployment_name=self.args.gpt_version,
41
+ temperature=self.temperature,
42
+ max_tokens=self.max_tokens
43
+ )
44
+ elif self.args.api_type == "openai":
45
+ chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key)
46
 
47
  self.fewshot_example = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
48
  self.irr_few_shot_examples = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
deciders/utils.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import sys
3
- import openai
4
  from tenacity import (
5
  retry,
6
  stop_after_attempt, # type: ignore
@@ -24,9 +24,10 @@ import timeout_decorator
24
  def run_chain(chain, *args, **kwargs):
25
  return chain.run(*args, **kwargs)
26
 
27
- @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
28
- def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
29
  response = openai.Completion.create(
 
30
  engine=engine,
31
  prompt=prompt,
32
  temperature=temperature,
@@ -39,7 +40,7 @@ def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float
39
  )
40
  return response.choices[0].text
41
 
42
- @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
43
  def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
44
  assert model != "text-davinci-003"
45
  messages = [
@@ -58,3 +59,4 @@ def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-tur
58
  # request_timeout = 1
59
  )
60
  return response.choices[0]["message"]["content"]
 
 
1
  import os
2
  import sys
3
+ import openai # 0.27.8
4
  from tenacity import (
5
  retry,
6
  stop_after_attempt, # type: ignore
 
24
  def run_chain(chain, *args, **kwargs):
25
  return chain.run(*args, **kwargs)
26
 
27
+ # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
28
+ def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
29
  response = openai.Completion.create(
30
+ model=engine,
31
  engine=engine,
32
  prompt=prompt,
33
  temperature=temperature,
 
40
  )
41
  return response.choices[0].text
42
 
43
+ # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
44
  def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
45
  assert model != "text-davinci-003"
46
  messages = [
 
59
  # request_timeout = 1
60
  )
61
  return response.choices[0]["message"]["content"]
62
+
distillers/guider.py CHANGED
@@ -107,7 +107,7 @@ class Guidance_Generator():
107
  reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
108
  else:
109
  reflection_query = self._generate_summary_query(traj, memory)
110
- reflection = get_completion(reflection_query,engine=self.args.gpt_version)
111
  logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
112
  logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
113
  return reflection
 
107
  reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
108
  else:
109
  reflection_query = self._generate_summary_query(traj, memory)
110
+ reflection = get_completion(reflection_query, engine=self.args.gpt_version)
111
  logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
112
  logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
113
  return reflection
distillers/traj_prompt_summarizer.py CHANGED
@@ -54,7 +54,7 @@ class TrajPromptSummarizer():
54
  reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
55
  else:
56
  reflection_query = self._generate_summary_query(traj, memory)
57
- reflection = get_completion(reflection_query, engine=self.args.gpt_version)
58
  logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
59
  logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
60
  return reflection
 
54
  reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
55
  else:
56
  reflection_query = self._generate_summary_query(traj, memory)
57
+ reflection = get_completion(reflection_query, engine=self.args.gpt_version)
58
  logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
59
  logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
60
  return reflection
main_reflexion.py CHANGED
@@ -292,8 +292,16 @@ if __name__ == "__main__":
292
  default=1,
293
  help="Whether only taking local observations, if is_only_local_obs = 1, only using local obs"
294
  )
 
 
 
 
 
 
295
  args = parser.parse_args()
296
 
 
 
297
  # Get the specified translator, environment, and ChatGPT model
298
  env_class = envs.REGISTRY[args.env]
299
  init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)
 
292
  default=1,
293
  help="Whether only taking local observations, if is_only_local_obs = 1, only using local obs"
294
  )
295
+ parser.add_argument(
296
+ "--api_type",
297
+ type=str,
298
+ default="azure",
299
+ help="choose api type, now support azure and openai"
300
+ )
301
  args = parser.parse_args()
302
 
303
+ if args.api_type != "azure" and args.api_type != "openai":
304
+ raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai' !")
305
  # Get the specified translator, environment, and ChatGPT model
306
  env_class = envs.REGISTRY[args.env]
307
  init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)