File size: 13,754 Bytes
594c559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import copy
from typing import Dict, List, Optional, Sequence, Tuple

from camel.agents import (
    ChatAgent,
    TaskPlannerAgent,
    TaskSpecifyAgent,
)
from camel.agents.chat_agent import ChatAgentResponse
from camel.messages import ChatMessage, UserChatMessage
from camel.messages import SystemMessage
from camel.typing import ModelType, RoleType, TaskType, PhaseType
from chatdev.utils import log_arguments, log_and_print_online


@log_arguments
class RolePlaying:
    r"""Role playing between two agents.

    On construction, the task prompt is optionally rewritten by a
    :class:`TaskSpecifyAgent` and/or extended by a :class:`TaskPlannerAgent`.
    One system message is then built for each side and used to create the
    assistant and user :class:`ChatAgent` instances.

    Args:
        assistant_role_name (str): The name of the role played by the
            assistant.
        user_role_name (str): The name of the role played by the user.
        critic_role_name (str): The name of the role played by the critic.
            Currently unused, because critic support is disabled.
            (default: :obj:`"critic"`)
        task_prompt (str, optional): A prompt for the task to be performed.
            (default: :obj:`""`)
        assistant_role_prompt (str, optional): Template for the assistant's
            system message. It is formatted with the assistant's meta dict
            (keys ``chatdev_prompt``, ``task`` and any extension keys).
            (default: :obj:`""`)
        user_role_prompt (str, optional): Same as ``assistant_role_prompt``,
            but for the user agent. (default: :obj:`""`)
        user_role_type (RoleType, optional): Currently unused; both system
            messages are created with :obj:`RoleType.DEFAULT`.
            (default: :obj:`None`)
        assistant_role_type (RoleType, optional): Currently unused; see
            ``user_role_type``. (default: :obj:`None`)
        with_task_specify (bool, optional): Whether to use a task specify
            agent. (default: :obj:`True`)
        with_task_planner (bool, optional): Whether to use a task planner
            agent. (default: :obj:`False`)
        with_critic_in_the_loop (bool, optional): Whether to include a critic
            in the loop. Not supported: passing :obj:`True` raises
            :obj:`ValueError`. (default: :obj:`False`)
        critic_criteria (str, optional): Currently unused.
            (default: :obj:`None`)
        model_type (ModelType, optional): The type of backend model to use.
            (default: :obj:`ModelType.GPT_3_5_TURBO`)
        task_type (TaskType, optional): The type of task to perform.
            (default: :obj:`TaskType.AI_SOCIETY`)
        assistant_agent_kwargs (Dict, optional): Additional arguments to pass
            to the assistant agent. (default: :obj:`None`)
        user_agent_kwargs (Dict, optional): Additional arguments to pass to
            the user agent. (default: :obj:`None`)
        task_specify_agent_kwargs (Dict, optional): Additional arguments to
            pass to the task specify agent. (default: :obj:`None`)
        task_planner_agent_kwargs (Dict, optional): Additional arguments to
            pass to the task planner agent. (default: :obj:`None`)
        critic_kwargs (Dict, optional): Currently unused.
            (default: :obj:`None`)
        sys_msg_generator_kwargs (Dict, optional): Currently unused.
            (default: :obj:`None`)
        extend_sys_msg_meta_dicts (List[Dict], optional): Dicts merged
            pairwise into the system message meta dicts, assistant first and
            user second. At least two entries are expected. When omitted and
            ``task_type`` is AI_SOCIETY, MISALIGNMENT or CHATDEV, both default
            to the assistant/user role names. (default: :obj:`None`)
        extend_task_specify_meta_dict (Dict, optional): A dict to extend the
            task specify meta dict with. (default: :obj:`None`)

    Raises:
        ValueError: If ``with_critic_in_the_loop`` is :obj:`True`.
    """

    def __init__(
            self,
            assistant_role_name: str,
            user_role_name: str,
            critic_role_name: str = "critic",
            task_prompt: str = "",
            assistant_role_prompt: str = "",
            user_role_prompt: str = "",
            user_role_type: Optional[RoleType] = None,
            assistant_role_type: Optional[RoleType] = None,
            with_task_specify: bool = True,
            with_task_planner: bool = False,
            with_critic_in_the_loop: bool = False,
            critic_criteria: Optional[str] = None,
            model_type: ModelType = ModelType.GPT_3_5_TURBO,
            task_type: TaskType = TaskType.AI_SOCIETY,
            assistant_agent_kwargs: Optional[Dict] = None,
            user_agent_kwargs: Optional[Dict] = None,
            task_specify_agent_kwargs: Optional[Dict] = None,
            task_planner_agent_kwargs: Optional[Dict] = None,
            critic_kwargs: Optional[Dict] = None,
            sys_msg_generator_kwargs: Optional[Dict] = None,
            extend_sys_msg_meta_dicts: Optional[List[Dict]] = None,
            extend_task_specify_meta_dict: Optional[Dict] = None,
    ) -> None:
        # Fail fast: the critic is not implemented. Reject the option before
        # any of the task specify / task planner / chat agents are built and
        # stepped, instead of after that work has already been done.
        if with_critic_in_the_loop:
            raise ValueError("with_critic_in_the_loop not available")

        self.with_task_specify = with_task_specify
        self.with_task_planner = with_task_planner
        self.with_critic_in_the_loop = with_critic_in_the_loop
        self.model_type = model_type
        self.task_type = task_type

        # Optionally let a TaskSpecifyAgent rewrite the task prompt; the
        # rewritten prompt replaces the original for everything below.
        if with_task_specify:
            task_specify_meta_dict = dict()
            if self.task_type in [TaskType.AI_SOCIETY, TaskType.MISALIGNMENT]:
                task_specify_meta_dict.update(
                    dict(assistant_role=assistant_role_name,
                         user_role=user_role_name))
            if extend_task_specify_meta_dict is not None:
                task_specify_meta_dict.update(extend_task_specify_meta_dict)

            task_specify_agent = TaskSpecifyAgent(
                self.model_type,
                task_type=self.task_type,
                **(task_specify_agent_kwargs or {}),
            )
            self.specified_task_prompt = task_specify_agent.step(
                task_prompt,
                meta_dict=task_specify_meta_dict,
            )
            task_prompt = self.specified_task_prompt
        else:
            self.specified_task_prompt = None

        # Optionally append a plan from a TaskPlannerAgent to the task prompt.
        if with_task_planner:
            task_planner_agent = TaskPlannerAgent(
                self.model_type,
                **(task_planner_agent_kwargs or {}),
            )
            self.planned_task_prompt = task_planner_agent.step(task_prompt)
            task_prompt = f"{task_prompt}\n{self.planned_task_prompt}"
        else:
            self.planned_task_prompt = None

        self.task_prompt = task_prompt

        chatdev_prompt_template = "ChatDev is a software company powered by multiple intelligent agents, such as chief executive officer, chief human resources officer, chief product officer, chief technology officer, etc, with a multi-agent organizational structure and the mission of \"changing the digital world through programming\"."

        # One independent meta dict per agent (index 0 = assistant, 1 = user).
        # ``[d] * 2`` would put the *same* dict object into both system
        # messages, so mutating one side's meta_dict would change the other.
        sys_msg_meta_dicts = [dict(chatdev_prompt=chatdev_prompt_template, task=task_prompt)
                              for _ in range(2)]
        if (extend_sys_msg_meta_dicts is None and self.task_type in [TaskType.AI_SOCIETY, TaskType.MISALIGNMENT,
                                                                     TaskType.CHATDEV]):
            extend_sys_msg_meta_dicts = [dict(assistant_role=assistant_role_name, user_role=user_role_name)
                                         for _ in range(2)]
        if extend_sys_msg_meta_dicts is not None:
            # zip truncates: fewer than two extension dicts leaves fewer than
            # two meta dicts, and the indexing below then raises IndexError.
            sys_msg_meta_dicts = [{**sys_msg_meta_dict, **extend_sys_msg_meta_dict} for
                                  sys_msg_meta_dict, extend_sys_msg_meta_dict in
                                  zip(sys_msg_meta_dicts, extend_sys_msg_meta_dicts)]

        # The role prompts are str.format templates filled from the meta dicts.
        self.assistant_sys_msg = SystemMessage(role_name=assistant_role_name, role_type=RoleType.DEFAULT,
                                               meta_dict=sys_msg_meta_dicts[0],
                                               content=assistant_role_prompt.format(**sys_msg_meta_dicts[0]))
        self.user_sys_msg = SystemMessage(role_name=user_role_name, role_type=RoleType.DEFAULT,
                                          meta_dict=sys_msg_meta_dicts[1],
                                          content=user_role_prompt.format(**sys_msg_meta_dicts[1]))

        self.assistant_agent: ChatAgent = ChatAgent(self.assistant_sys_msg, model_type,
                                                    **(assistant_agent_kwargs or {}), )
        self.user_agent: ChatAgent = ChatAgent(self.user_sys_msg, model_type, **(user_agent_kwargs or {}), )

        # TODO: critic support (human or CriticAgent) is disabled; the option
        # is rejected at the top of __init__, so there is never a critic.
        self.critic = None

    def init_chat(self, phase_type: Optional[PhaseType] = None,
                  placeholders: Optional[Dict] = None, phase_prompt: Optional[str] = None):
        r"""Reset both agents and build the opening message of a phase.

        ``phase_prompt`` is formatted with ``assistant_role`` (the assistant
        agent's role name) plus ``placeholders``. On duplicate keys the value
        in ``placeholders`` wins. The result becomes the content of a
        :class:`UserChatMessage` attributed to the user role. A deep copy of
        that message with ``role="assistant"`` is appended to the user
        agent's history, so the user agent sees it as a message it produced.
        Nothing is sent to the assistant agent here; the returned message is
        presumably passed to :meth:`step` by the caller.

        Args:
            phase_type (PhaseType, optional): Currently unused.
            placeholders (Dict, optional): Extra values for formatting
                ``phase_prompt``. (default: :obj:`None`, treated as ``{}``)
            phase_prompt (str): Format template for the opening message.
                Required in practice: passing :obj:`None` fails on
                ``.format``.

        Returns:
            A tuple ``(None, user_msg)``. The first element is always
            :obj:`None` (no assistant introductory message is generated);
            ``user_msg`` is the :class:`UserChatMessage` built above.
        """
        if placeholders is None:
            placeholders = {}
        self.assistant_agent.reset()
        self.user_agent.reset()

        # refactored ChatDev
        content = phase_prompt.format(
            **({"assistant_role": self.assistant_agent.role_name} | placeholders)
        )
        user_msg = UserChatMessage(
            role_name=self.user_sys_msg.role_name,
            role="user",
            content=content
            # content here will be concatenated with assistant role prompt (because we mock user and send msg to assistant) in the ChatAgent.step
        )
        # Record the opening message in the user agent's own history under
        # the "assistant" role, i.e. as something the user agent said.
        pseudo_msg = copy.deepcopy(user_msg)
        pseudo_msg.role = "assistant"
        self.user_agent.update_messages(pseudo_msg)

        # here we concatenate to store the real message in the log
        log_and_print_online(self.user_agent.role_name,
                             "**[Start Chat]**\n\n[" + self.assistant_agent.system_message.content + "]\n\n" + content)
        return None, user_msg

    def process_messages(
            self,
            messages: Sequence[ChatMessage],
    ) -> ChatMessage:
        r"""Reduce the candidate messages of one agent step to a single message.

        Without a critic, exactly one message is expected and is returned
        unchanged. With ``with_critic_in_the_loop`` set and a critic present,
        the critic picks the message.

        Args:
            messages: Candidate messages produced by one agent step.

        Returns:
            A single `ChatMessage` representing the processed message.

        Raises:
            ValueError: If ``messages`` is empty, or if it holds more than one
                message while ``with_critic_in_the_loop`` is ``False``.
        """
        if len(messages) == 0:
            raise ValueError("No messages to process.")
        if len(messages) > 1 and not self.with_critic_in_the_loop:
            raise ValueError("Got more than one message to process. "
                             f"Num of messages: {len(messages)}.")
        elif self.with_critic_in_the_loop and self.critic is not None:
            processed_msg = self.critic.step(messages)
        else:
            processed_msg = messages[0]

        return processed_msg

    def step(
            self,
            user_msg: ChatMessage,
            assistant_only: bool,
    ) -> Tuple[ChatAgentResponse, ChatAgentResponse]:
        r"""Advance the conversation by one round.

        ``user_msg`` is sent to the assistant agent. Unless
        ``assistant_only`` is true, the assistant's reply is then sent to the
        user agent. Each accepted reply is appended to the history of the
        agent that produced it.

        The round stops early, without stepping the user agent, when the
        assistant terminated or returned no messages, or when
        ``self.assistant_agent.info`` is truthy. In the ``info`` case the
        assistant's message is returned but not appended to its history. The
        user turn stops early in the same way.

        Args:
            user_msg: Message to deliver to the assistant agent.
            assistant_only: When true, return right after the assistant turn.

        Returns:
            ``(assistant_response, user_response)``. When the user agent is
            not stepped, ``user_response`` is ``ChatAgentResponse([], False,
            {})``.
        """
        assert isinstance(user_msg, ChatMessage), "broken user_msg: " + str(user_msg)

        # --- assistant turn ---
        # set_user_role_at_backend is defined on ChatMessage (not visible
        # here); presumably it relabels the message as backend "user" input.
        user_msg_rst = user_msg.set_user_role_at_backend()
        assistant_response = self.assistant_agent.step(user_msg_rst)
        if assistant_response.terminated or assistant_response.msgs is None:
            # NOTE(review): ``msgs`` is wrapped in an extra list here; kept
            # as-is since callers outside this file may rely on that shape.
            return (
                ChatAgentResponse([assistant_response.msgs], assistant_response.terminated, assistant_response.info),
                ChatAgentResponse([], False, {}))
        assistant_msg = self.process_messages(assistant_response.msgs)
        if self.assistant_agent.info:
            return (ChatAgentResponse([assistant_msg], assistant_response.terminated, assistant_response.info),
                    ChatAgentResponse([], False, {}))
        self.assistant_agent.update_messages(assistant_msg)

        if assistant_only:
            return (
                ChatAgentResponse([assistant_msg], assistant_response.terminated, assistant_response.info),
                ChatAgentResponse([], False, {})
            )

        # --- user turn ---
        assistant_msg_rst = assistant_msg.set_user_role_at_backend()
        user_response = self.user_agent.step(assistant_msg_rst)
        if user_response.terminated or user_response.msgs is None:
            # Mirror the assistant early exit above: carry the agent's
            # ``msgs``, not the ChatAgentResponse object itself.
            return (ChatAgentResponse([assistant_msg], assistant_response.terminated, assistant_response.info),
                    ChatAgentResponse([user_response.msgs], user_response.terminated, user_response.info))
        user_msg = self.process_messages(user_response.msgs)
        if self.user_agent.info:
            return (ChatAgentResponse([assistant_msg], assistant_response.terminated, assistant_response.info),
                    ChatAgentResponse([user_msg], user_response.terminated, user_response.info))
        self.user_agent.update_messages(user_msg)

        return (
            ChatAgentResponse([assistant_msg], assistant_response.terminated, assistant_response.info),
            ChatAgentResponse([user_msg], user_response.terminated, user_response.info),
        )