Spaces:
Runtime error
Runtime error
File size: 14,369 Bytes
594c559 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
import importlib
import json
import os
import shutil
from datetime import datetime
import logging
import time
from camel.agents import RolePlaying
from camel.configs import ChatGPTConfig
from camel.typing import TaskType, ModelType
from chatdev.chat_env import ChatEnv, ChatEnvConfig
from chatdev.statistics import get_info
from chatdev.utils import log_and_print_online, now
def check_bool(s):
    """Interpret a config flag as a boolean.

    Config files may store flags either as strings ("True"/"False", any case)
    or as real JSON booleans; both are accepted.

    Args:
        s: the flag value, a ``bool`` or a string.

    Returns:
        ``s`` itself when it is already a ``bool``; otherwise ``True`` only when
        the string, ignoring case and surrounding whitespace, equals "true".
    """
    # JSON `true`/`false` load as bool, which has no .lower()
    if isinstance(s, bool):
        return s
    return s.strip().lower() == "true"
class ChatChain:
    """Run a ChatDev software-production chain described by JSON config files.

    ChatChainConfig.json supplies the ordered phase list (``chain``), the
    agents to recruit (``recruitments``) and global switches; PhaseConfig.json
    supplies the prompts of every SimplePhase; RoleConfig.json supplies the
    prompt of every role. Each phase receives the current ``ChatEnv`` and
    returns one, which is stored back on ``self.chat_env``.
    """

    def __init__(self,
                 config_path: str = None,
                 config_phase_path: str = None,
                 config_role_path: str = None,
                 task_prompt: str = None,
                 project_name: str = None,
                 org_name: str = None,
                 model_type: ModelType = ModelType.GPT_3_5_TURBO) -> None:
        """
        Args:
            config_path: path to the ChatChainConfig.json
            config_phase_path: path to the PhaseConfig.json
            config_role_path: path to the RoleConfig.json
            task_prompt: the user input prompt for software
            project_name: the user input name for software
            org_name: the organization name of the human user
            model_type: model passed to every phase and to the prompt-improvement agent
        """
        # load config files
        self.config_path = config_path
        self.config_phase_path = config_phase_path
        self.config_role_path = config_role_path
        self.project_name = project_name
        self.org_name = org_name
        self.model_type = model_type
        with open(self.config_path, 'r', encoding="utf8") as file:
            self.config = json.load(file)
        with open(self.config_phase_path, 'r', encoding="utf8") as file:
            self.config_phase = json.load(file)
        with open(self.config_role_path, 'r', encoding="utf8") as file:
            self.config_role = json.load(file)

        # init chatchain config and recruitments
        self.chain = self.config["chain"]
        self.recruitments = self.config["recruitments"]

        # default max chat turn, used when a SimplePhase sets max_turn_step <= 0
        self.chat_turn_limit_default = 10

        # init ChatEnv
        self.chat_env_config = ChatEnvConfig(clear_structure=check_bool(self.config["clear_structure"]),
                                             brainstorming=check_bool(self.config["brainstorming"]),
                                             gui_design=check_bool(self.config["gui_design"]),
                                             git_management=check_bool(self.config["git_management"]))
        self.chat_env = ChatEnv(self.chat_env_config)

        # the user input prompt will be self-improved (if set "self_improve": "True" in ChatChainConfig.json);
        # the self-improvement is done in self.pre_processing
        self.task_prompt_raw = task_prompt
        self.task_prompt = ""

        # init role prompts: each role's prompt entries are joined with newlines
        self.role_prompts = dict()
        for role in self.config_role:
            self.role_prompts[role] = "\n".join(self.config_role[role])

        # init log
        self.start_time, self.log_filepath = self.get_logfilepath()

        # init SimplePhase instances
        # import all used phases in PhaseConfig.json from chatdev.phase
        # note that in PhaseConfig.json there only exist SimplePhases
        # ComposedPhases are defined in ChatChainConfig.json and will be imported in self.execute_step
        self.compose_phase_module = importlib.import_module("chatdev.composed_phase")
        self.phase_module = importlib.import_module("chatdev.phase")
        self.phases = dict()
        for phase in self.config_phase:
            assistant_role_name = self.config_phase[phase]['assistant_role_name']
            user_role_name = self.config_phase[phase]['user_role_name']
            phase_prompt = "\n\n".join(self.config_phase[phase]['phase_prompt'])
            phase_class = getattr(self.phase_module, phase)
            self.phases[phase] = phase_class(assistant_role_name=assistant_role_name,
                                             user_role_name=user_role_name,
                                             phase_prompt=phase_prompt,
                                             role_prompts=self.role_prompts,
                                             phase_name=phase,
                                             model_type=self.model_type,
                                             log_filepath=self.log_filepath)

    @staticmethod
    def _warehouse_dir():
        """Return the ``WareHouse`` directory located in the parent of this module's directory."""
        package_dir = os.path.dirname(__file__)
        root = os.path.dirname(package_dir)
        return os.path.join(root, "WareHouse")

    def _run_name(self, start_time):
        """Return the ``<project>_<org>_<start_time>`` name shared by the log file and software directory."""
        return "_".join([self.project_name, self.org_name, start_time])

    def make_recruitment(self):
        """
        recruit all employees listed under "recruitments" in ChatChainConfig.json
        Returns: None
        """
        for employee in self.recruitments:
            self.chat_env.recruit(agent_name=employee)

    def execute_step(self, phase_item: dict):
        """
        execute single phase in the chain
        Args:
            phase_item: single phase configuration in the ChatChainConfig.json
        Returns: None (self.chat_env is replaced by the env the phase returns)
        Raises:
            RuntimeError: the phase or its phaseType is not implemented
        """
        phase = phase_item['phase']
        phase_type = phase_item['phaseType']
        # For SimplePhase, just look it up from self.phases and conduct the "Phase.execute" method
        if phase_type == "SimplePhase":
            max_turn_step = phase_item['max_turn_step']
            need_reflect = check_bool(phase_item['need_reflect'])
            if phase not in self.phases:
                raise RuntimeError(f"Phase '{phase}' is not yet implemented in chatdev.phase")
            # a non-positive max_turn_step means "use the default turn limit"
            chat_turn_limit = self.chat_turn_limit_default if max_turn_step <= 0 else max_turn_step
            self.chat_env = self.phases[phase].execute(self.chat_env, chat_turn_limit, need_reflect)
        # For ComposedPhase, we create instance here then conduct the "ComposedPhase.execute" method
        elif phase_type == "ComposedPhase":
            cycle_num = phase_item['cycleNum']
            composition = phase_item['Composition']
            # default of None so a missing class reaches the explicit error below
            # instead of escaping as a bare AttributeError
            compose_phase_class = getattr(self.compose_phase_module, phase, None)
            if compose_phase_class is None:
                raise RuntimeError(f"Phase '{phase}' is not yet implemented in chatdev.composed_phase")
            compose_phase_instance = compose_phase_class(phase_name=phase,
                                                         cycle_num=cycle_num,
                                                         composition=composition,
                                                         config_phase=self.config_phase,
                                                         config_role=self.config_role,
                                                         model_type=self.model_type,
                                                         log_filepath=self.log_filepath)
            self.chat_env = compose_phase_instance.execute(self.chat_env)
        else:
            raise RuntimeError(f"PhaseType '{phase_type}' is not yet implemented.")

    def execute_chain(self):
        """
        execute the whole chain based on ChatChainConfig.json, one phase at a time in order
        Returns: None
        """
        for phase_item in self.chain:
            self.execute_step(phase_item)

    def get_logfilepath(self):
        """
        get the log path (directly under WareHouse/)
        Returns:
            start_time: time for starting making the software
            log_filepath: path to the log
        """
        start_time = now()
        log_filepath = os.path.join(self._warehouse_dir(), "{}.log".format(self._run_name(start_time)))
        return start_time, log_filepath

    def pre_processing(self):
        """
        remove useless files, prepare the software directory, log some global
        config settings and set the (optionally self-improved) task prompt
        Returns: None
        """
        if self.chat_env.config.clear_structure:
            directory = self._warehouse_dir()
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                # logs with error trials are left in WareHouse/
                if os.path.isfile(file_path) and not filename.endswith((".py", ".log")):
                    os.remove(file_path)
                    print("{} Removed.".format(file_path))

        software_path = os.path.join(self._warehouse_dir(), self._run_name(self.start_time))
        self.chat_env.set_directory(software_path)

        # copy config files to software path
        shutil.copy(self.config_path, software_path)
        shutil.copy(self.config_phase_path, software_path)
        shutil.copy(self.config_role_path, software_path)

        # write task prompt to software path; explicit encoding so non-ASCII prompts
        # do not depend on the platform default
        with open(os.path.join(software_path, self.project_name + ".prompt"), "w", encoding="utf8") as f:
            f.write(self.task_prompt_raw)

        preprocess_msg = "**[Preprocessing]**\n\n"
        chat_gpt_config = ChatGPTConfig()
        preprocess_msg += "**ChatDev Starts** ({})\n\n".format(self.start_time)
        preprocess_msg += "**Timestamp**: {}\n\n".format(self.start_time)
        preprocess_msg += "**config_path**: {}\n\n".format(self.config_path)
        preprocess_msg += "**config_phase_path**: {}\n\n".format(self.config_phase_path)
        preprocess_msg += "**config_role_path**: {}\n\n".format(self.config_role_path)
        preprocess_msg += "**task_prompt**: {}\n\n".format(self.task_prompt_raw)
        preprocess_msg += "**project_name**: {}\n\n".format(self.project_name)
        preprocess_msg += "**Log File**: {}\n\n".format(self.log_filepath)
        preprocess_msg += "**ChatDevConfig**:\n {}\n\n".format(self.chat_env.config.__str__())
        preprocess_msg += "**ChatGPTConfig**:\n {}\n\n".format(chat_gpt_config)
        log_and_print_online(preprocess_msg)

        # init task prompt; keep self.task_prompt in sync with the env copy
        if check_bool(self.config['self_improve']):
            self.task_prompt = self.self_task_improve(self.task_prompt_raw)
        else:
            self.task_prompt = self.task_prompt_raw
        self.chat_env.env_dict['task_prompt'] = self.task_prompt

    def post_processing(self):
        """
        write meta info, log a summary (software info and duration), remove
        __pycache__ folders when clear_structure is set, and move the log file
        into the software directory
        Returns: None
        """
        self.chat_env.write_meta()

        post_info = "**[Post Info]**\n\n"
        now_time = now()
        # now() timestamps are parsed with this format
        time_format = "%Y%m%d%H%M%S"
        datetime1 = datetime.strptime(self.start_time, time_format)
        datetime2 = datetime.strptime(now_time, time_format)
        duration = (datetime2 - datetime1).total_seconds()
        post_info += "Software Info: {}".format(
            get_info(self.chat_env.env_dict['directory'], self.log_filepath) + "\n\n🕑**duration**={:.2f}s\n\n".format(duration))
        post_info += "ChatDev Starts ({})".format(self.start_time) + "\n\n"
        post_info += "ChatDev Ends ({})".format(now_time) + "\n\n"

        if self.chat_env.config.clear_structure:
            directory = self.chat_env.env_dict['directory']
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isdir(file_path) and file_path.endswith("__pycache__"):
                    shutil.rmtree(file_path, ignore_errors=True)
                    post_info += "{} Removed.".format(file_path) + "\n\n"

        log_and_print_online(post_info)

        # logging.shutdown() flushes and closes every handler so the log file can be moved;
        # NOTE(review): the 1s sleep presumably lets the OS release the handle — kept as-is
        logging.shutdown()
        time.sleep(1)

        software_dir = os.path.join(self._warehouse_dir(), self._run_name(self.start_time))
        shutil.move(self.log_filepath, os.path.join(software_dir, os.path.basename(self.log_filepath)))

    def self_task_improve(self, task_prompt):
        """
        ask agent to improve the user query prompt
        Args:
            task_prompt: original user query prompt
        Returns:
            revised_task_prompt: revised prompt from the prompt engineer agent
            (the text after the last "<INFO>" marker, lower-cased and stripped)
        """
        self_task_improve_prompt = """I will give you a short description of a software design requirement,
please rewrite it into a detailed prompt that can make large language model know how to make this software better based this prompt,
the prompt should ensure LLMs build a software that can be run correctly, which is the most import part you need to consider.
remember that the revised prompt should not contain more than 200 words,
here is the short description:\"{}\".
If the revised prompt is revised_version_of_the_description,
then you should return a message in a format like \"<INFO> revised_version_of_the_description\", do not return messages in other formats.""".format(
            task_prompt)
        role_play_session = RolePlaying(
            assistant_role_name="Prompt Engineer",
            assistant_role_prompt="You are an professional prompt engineer that can improve user input prompt to make LLM better understand these prompts.",
            user_role_prompt="You are an user that want to use LLM to build software.",
            user_role_name="User",
            task_type=TaskType.CHATDEV,
            task_prompt="Do prompt engineering on user query",
            with_task_specify=False,
            model_type=self.model_type,
        )
        _, input_user_msg = role_play_session.init_chat(None, None, self_task_improve_prompt)
        assistant_response, _ = role_play_session.step(input_user_msg, True)
        # NOTE(review): lower-casing discards capitalisation (e.g. library names);
        # kept for behavioural compatibility — confirm it is intentional
        revised_task_prompt = assistant_response.msg.content.split("<INFO>")[-1].lower().strip()
        log_and_print_online(role_play_session.assistant_agent.role_name, assistant_response.msg.content)
        log_and_print_online(
            "**[Task Prompt Self Improvement]**\n**Original Task Prompt**: {}\n**Improved Task Prompt**: {}".format(
                task_prompt, revised_task_prompt))
        return revised_task_prompt
|