diff --git a/api/core/agent/__init__.py b/api/core/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..ae086ba8ed66ec8f907c73267e09b0b4e143bb3d --- /dev/null +++ b/api/core/agent/base_agent_runner.py @@ -0,0 +1,539 @@ +import json +import logging +import uuid +from datetime import UTC, datetime +from typing import Optional, Union, cast + +from core.agent.entities import AgentEntity, AgentToolEntity +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import ( + AgentChatAppGenerateEntity, + ModelConfigWithCredentialsEntity, +) +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import file_manager +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.model_runtime.entities.model_entities import ModelFeature +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.utils.extract_thread_messages import extract_thread_messages 
+from core.tools.entities.tool_entities import ( + ToolParameter, + ToolRuntimeVariablePool, +) +from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool +from core.tools.tool.tool import Tool +from core.tools.tool_manager import ToolManager +from extensions.ext_database import db +from factories import file_factory +from models.model import Conversation, Message, MessageAgentThought, MessageFile +from models.tools import ToolConversationVariables + +logger = logging.getLogger(__name__) + + +class BaseAgentRunner(AppRunner): + def __init__( + self, + *, + tenant_id: str, + application_generate_entity: AgentChatAppGenerateEntity, + conversation: Conversation, + app_config: AgentChatAppConfig, + model_config: ModelConfigWithCredentialsEntity, + config: AgentEntity, + queue_manager: AppQueueManager, + message: Message, + user_id: str, + memory: Optional[TokenBufferMemory] = None, + prompt_messages: Optional[list[PromptMessage]] = None, + variables_pool: Optional[ToolRuntimeVariablePool] = None, + db_variables: Optional[ToolConversationVariables] = None, + model_instance: ModelInstance, + ) -> None: + self.tenant_id = tenant_id + self.application_generate_entity = application_generate_entity + self.conversation = conversation + self.app_config = app_config + self.model_config = model_config + self.config = config + self.queue_manager = queue_manager + self.message = message + self.user_id = user_id + self.memory = memory + self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) + self.variables_pool = variables_pool + self.db_variables_pool = db_variables + self.model_instance = model_instance + + # init callback + self.agent_callback = DifyAgentCallbackHandler() + # init dataset tools + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager=queue_manager, + app_id=self.app_config.app_id, + message_id=message.id, + user_id=user_id, + invoke_from=self.application_generate_entity.invoke_from, + ) + 
self.dataset_tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id=tenant_id, + dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [], + retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, + return_resource=app_config.additional_features.show_retrieve_source, + invoke_from=application_generate_entity.invoke_from, + hit_callback=hit_callback, + ) + # get how many agent thoughts have been created + self.agent_thought_count = ( + db.session.query(MessageAgentThought) + .filter( + MessageAgentThought.message_id == self.message.id, + ) + .count() + ) + db.session.close() + + # check if model supports stream tool call + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + features = model_schema.features if model_schema and model_schema.features else [] + self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features + self.files = application_generate_entity.files if ModelFeature.VISION in features else [] + self.query: Optional[str] = "" + self._current_thoughts: list[PromptMessage] = [] + + def _repack_app_generate_entity( + self, app_generate_entity: AgentChatAppGenerateEntity + ) -> AgentChatAppGenerateEntity: + """ + Repack app generate entity + """ + if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: + app_generate_entity.app_config.prompt_template.simple_prompt_template = "" + + return app_generate_entity + + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: + """ + convert tool to prompt message tool + """ + tool_entity = ToolManager.get_agent_tool_runtime( + tenant_id=self.tenant_id, + app_id=self.app_config.app_id, + agent_tool=tool, + invoke_from=self.application_generate_entity.invoke_from, + ) + tool_entity.load_variables(self.variables_pool) + + message_tool = PromptMessageTool( + 
name=tool.tool_name, + description=tool_entity.description.llm if tool_entity.description else "", + parameters={ + "type": "object", + "properties": {}, + "required": [], + }, + ) + + parameters = tool_entity.get_all_runtime_parameters() + for parameter in parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + + parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + continue + enum = [] + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] if parameter.options else [] + + message_tool.parameters["properties"][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or "", + } + + if len(enum) > 0: + message_tool.parameters["properties"][parameter.name]["enum"] = enum + + if parameter.required: + message_tool.parameters["required"].append(parameter.name) + + return message_tool, tool_entity + + def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: + """ + convert dataset retriever tool to prompt message tool + """ + prompt_tool = PromptMessageTool( + name=tool.identity.name if tool.identity else "unknown", + description=tool.description.llm if tool.description else "", + parameters={ + "type": "object", + "properties": {}, + "required": [], + }, + ) + + for parameter in tool.get_runtime_parameters(): + parameter_type = "string" + + prompt_tool.parameters["properties"][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or "", + } + + if parameter.required: + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) + + return prompt_tool + + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: + """ + 
Init tools + """ + tool_instances = {} + prompt_messages_tools = [] + + for tool in self.app_config.agent.tools or [] if self.app_config.agent else []: + try: + prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) + except Exception: + # api tool may be deleted + continue + # save tool entity + tool_instances[tool.tool_name] = tool_entity + # save prompt tool + prompt_messages_tools.append(prompt_tool) + + # convert dataset tools into ModelRuntime Tool format + for dataset_tool in self.dataset_tools: + prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) + # save prompt tool + prompt_messages_tools.append(prompt_tool) + # save tool entity + if dataset_tool.identity is not None: + tool_instances[dataset_tool.identity.name] = dataset_tool + + return tool_instances, prompt_messages_tools + + def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool: + """ + update prompt message tool + """ + # try to get tool runtime parameters + tool_runtime_parameters = tool.get_runtime_parameters() + + for parameter in tool_runtime_parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + + parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + continue + enum = [] + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] if parameter.options else [] + + prompt_tool.parameters["properties"][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or "", + } + + if len(enum) > 0: + prompt_tool.parameters["properties"][parameter.name]["enum"] = enum + + if parameter.required: + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) + + return 
prompt_tool + + def create_agent_thought( + self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] + ) -> MessageAgentThought: + """ + Create agent thought + """ + thought = MessageAgentThought( + message_id=message_id, + message_chain_id=None, + thought="", + tool=tool_name, + tool_labels_str="{}", + tool_meta_str="{}", + tool_input=tool_input, + message=message, + message_token=0, + message_unit_price=0, + message_price_unit=0, + message_files=json.dumps(messages_ids) if messages_ids else "", + answer="", + observation="", + answer_token=0, + answer_unit_price=0, + answer_price_unit=0, + tokens=0, + total_price=0, + position=self.agent_thought_count + 1, + currency="USD", + latency=0, + created_by_role="account", + created_by=self.user_id, + ) + + db.session.add(thought) + db.session.commit() + db.session.refresh(thought) + db.session.close() + + self.agent_thought_count += 1 + + return thought + + def save_agent_thought( + self, + agent_thought: MessageAgentThought, + tool_name: str, + tool_input: Union[str, dict], + thought: str, + observation: Union[str, dict, None], + tool_invoke_meta: Union[str, dict, None], + answer: str, + messages_ids: list[str], + llm_usage: LLMUsage | None = None, + ): + """ + Save agent thought + """ + queried_thought = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + ) + if not queried_thought: + raise ValueError(f"Agent thought {agent_thought.id} not found") + agent_thought = queried_thought + + if thought: + agent_thought.thought = thought + + if tool_name: + agent_thought.tool = tool_name + + if tool_input: + if isinstance(tool_input, dict): + try: + tool_input = json.dumps(tool_input, ensure_ascii=False) + except Exception as e: + tool_input = json.dumps(tool_input) + + agent_thought.tool_input = tool_input + + if observation: + if isinstance(observation, dict): + try: + observation = json.dumps(observation, ensure_ascii=False) + except 
Exception as e: + observation = json.dumps(observation) + + agent_thought.observation = observation + + if answer: + agent_thought.answer = answer + + if messages_ids is not None and len(messages_ids) > 0: + agent_thought.message_files = json.dumps(messages_ids) + + if llm_usage: + agent_thought.message_token = llm_usage.prompt_tokens + agent_thought.message_price_unit = llm_usage.prompt_price_unit + agent_thought.message_unit_price = llm_usage.prompt_unit_price + agent_thought.answer_token = llm_usage.completion_tokens + agent_thought.answer_price_unit = llm_usage.completion_price_unit + agent_thought.answer_unit_price = llm_usage.completion_unit_price + agent_thought.tokens = llm_usage.total_tokens + agent_thought.total_price = llm_usage.total_price + + # check if tool labels is not empty + labels = agent_thought.tool_labels or {} + tools = agent_thought.tool.split(";") if agent_thought.tool else [] + for tool in tools: + if not tool: + continue + if tool not in labels: + tool_label = ToolManager.get_tool_label(tool) + if tool_label: + labels[tool] = tool_label.to_dict() + else: + labels[tool] = {"en_US": tool, "zh_Hans": tool} + + agent_thought.tool_labels_str = json.dumps(labels) + + if tool_invoke_meta is not None: + if isinstance(tool_invoke_meta, dict): + try: + tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) + except Exception as e: + tool_invoke_meta = json.dumps(tool_invoke_meta) + + agent_thought.tool_meta_str = tool_invoke_meta + + db.session.commit() + db.session.close() + + def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): + """ + convert tool variables to db variables + """ + queried_variables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == self.message.conversation_id, + ) + .first() + ) + + if not queried_variables: + return + + db_variables = queried_variables + + db_variables.updated_at = 
datetime.now(UTC).replace(tzinfo=None) + db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) + db.session.commit() + db.session.close() + + def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Organize agent history + """ + result: list[PromptMessage] = [] + # check if there is a system message in the beginning of the conversation + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + result.append(prompt_message) + + messages: list[Message] = ( + db.session.query(Message) + .filter( + Message.conversation_id == self.message.conversation_id, + ) + .order_by(Message.created_at.desc()) + .all() + ) + + messages = list(reversed(extract_thread_messages(messages))) + + for message in messages: + if message.id == self.message.id: + continue + + result.append(self.organize_agent_user_prompt(message)) + agent_thoughts: list[MessageAgentThought] = message.agent_thoughts + if agent_thoughts: + for agent_thought in agent_thoughts: + tools = agent_thought.tool + if tools: + tools = tools.split(";") + tool_calls: list[AssistantPromptMessage.ToolCall] = [] + tool_call_response: list[ToolPromptMessage] = [] + try: + tool_inputs = json.loads(agent_thought.tool_input) + except Exception as e: + tool_inputs = {tool: {} for tool in tools} + try: + tool_responses = json.loads(agent_thought.observation) + except Exception as e: + tool_responses = dict.fromkeys(tools, agent_thought.observation) + + for tool in tools: + # generate a uuid for tool call + tool_call_id = str(uuid.uuid4()) + tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool, + arguments=json.dumps(tool_inputs.get(tool, {})), + ), + ) + ) + tool_call_response.append( + ToolPromptMessage( + content=tool_responses.get(tool, agent_thought.observation), + name=tool, + tool_call_id=tool_call_id, + ) + 
) + + result.extend( + [ + AssistantPromptMessage( + content=agent_thought.thought, + tool_calls=tool_calls, + ), + *tool_call_response, + ] + ) + if not tools: + result.append(AssistantPromptMessage(content=agent_thought.thought)) + else: + if message.answer: + result.append(AssistantPromptMessage(content=message.answer)) + + db.session.close() + + return result + + def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: + files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() + if not files: + return UserPromptMessage(content=message.query) + file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + if not file_extra_config: + return UserPromptMessage(content=message.query) + + image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=self.tenant_id, config=file_extra_config + ) + if not file_objs: + return UserPromptMessage(content=message.query) + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + for file in file_objs: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + return UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe1865daf058c386f43e13e7d91297d2d83890c --- /dev/null +++ b/api/core/agent/cot_agent_runner.py @@ -0,0 +1,437 @@ +import json +from abc import ABC, abstractmethod +from collections.abc import Generator, Mapping +from typing import Any, Optional + +from core.agent.base_agent_runner import BaseAgentRunner +from 
core.agent.entities import AgentScratchpadUnit +from core.agent.output_parser.cot_output_parser import CotAgentOutputParser +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) +from core.ops.ops_trace_manager import TraceQueueManager +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool.tool import Tool +from core.tools.tool_engine import ToolEngine +from models.model import Message + + +class CotAgentRunner(BaseAgentRunner, ABC): + _is_first_iteration = True + _ignore_observation_providers = ["wenxin"] + _historic_prompt_messages: list[PromptMessage] | None = None + _agent_scratchpad: list[AgentScratchpadUnit] | None = None + _instruction: str = "" # FIXME this must be str for now + _query: str | None = None + _prompt_messages_tools: list[PromptMessageTool] = [] + + def run( + self, + message: Message, + query: str, + inputs: Mapping[str, str], + ) -> Generator: + """ + Run Cot agent application + """ + app_generate_entity = self.application_generate_entity + self._repack_app_generate_entity(app_generate_entity) + self._init_react_state(query) + + trace_manager = app_generate_entity.trace_manager + + # check model mode + if "Observation" not in app_generate_entity.model_conf.stop: + if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: + app_generate_entity.model_conf.stop.append("Observation") + + app_config = self.app_config + + # init instruction + inputs = inputs or {} + instruction = 
app_config.prompt_template.simple_prompt_template + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) + + iteration_step = 1 + max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 + + # convert tools into ModelRuntime Tool format + tool_instances, self._prompt_messages_tools = self._init_prompt_tools() + + function_call_state = True + llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} + final_answer = "" + + def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage + else: + llm_usage = final_llm_usage_dict["usage"] + llm_usage.prompt_tokens += usage.prompt_tokens + llm_usage.completion_tokens += usage.completion_tokens + llm_usage.prompt_price += usage.prompt_price + llm_usage.completion_price += usage.completion_price + llm_usage.total_price += usage.total_price + + model_instance = self.model_instance + + while function_call_state and iteration_step <= max_iteration_steps: + # continue to run until there is not any tool call + function_call_state = False + + if iteration_step == max_iteration_steps: + # the last iteration, remove all tools + self._prompt_messages_tools = [] + + message_file_ids: list[str] = [] + + agent_thought = self.create_agent_thought( + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids + ) + + if iteration_step > 1: + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + # recalc llm max tokens + prompt_messages = self._organize_prompt_messages() + self.recalc_llm_max_tokens(self.model_config, prompt_messages) + # invoke model + chunks = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=app_generate_entity.model_conf.parameters, + tools=[], + 
stop=app_generate_entity.model_conf.stop, + stream=True, + user=self.user_id, + callbacks=[], + ) + + if not isinstance(chunks, Generator): + raise ValueError("Expected streaming response from LLM") + + # check llm result + if not chunks: + raise ValueError("failed to invoke llm") + + usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None} + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) + scratchpad = AgentScratchpadUnit( + agent_response="", + thought="", + action_str="", + observation="", + action=None, + ) + + # publish agent thought if it's first iteration + if iteration_step == 1: + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + for chunk in react_chunks: + if isinstance(chunk, AgentScratchpadUnit.Action): + action = chunk + # detect action + if scratchpad.agent_response is not None: + scratchpad.agent_response += json.dumps(chunk.model_dump()) + scratchpad.action_str = json.dumps(chunk.model_dump()) + scratchpad.action = action + else: + if scratchpad.agent_response is not None: + scratchpad.agent_response += chunk + if scratchpad.thought is not None: + scratchpad.thought += chunk + yield LLMResultChunk( + model=self.model_config.model, + prompt_messages=prompt_messages, + system_fingerprint="", + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), + ) + if scratchpad.thought is not None: + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" + if self._agent_scratchpad is not None: + self._agent_scratchpad.append(scratchpad) + + # get llm usage + if "usage" in usage_dict: + if usage_dict["usage"] is not None: + increase_usage(llm_usage, usage_dict["usage"]) + else: + usage_dict["usage"] = LLMUsage.empty_usage() + + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=(scratchpad.action.action_name if scratchpad.action and not 
scratchpad.is_final() else ""), + tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, + tool_invoke_meta={}, + thought=scratchpad.thought or "", + observation="", + answer=scratchpad.agent_response or "", + messages_ids=[], + llm_usage=usage_dict["usage"], + ) + + if not scratchpad.is_final(): + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + if not scratchpad.action: + # failed to extract action, return final answer directly + final_answer = "" + else: + if scratchpad.action.action_name.lower() == "final answer": + # action is final answer, return final answer directly + try: + if isinstance(scratchpad.action.action_input, dict): + final_answer = json.dumps(scratchpad.action.action_input) + elif isinstance(scratchpad.action.action_input, str): + final_answer = scratchpad.action.action_input + else: + final_answer = f"{scratchpad.action.action_input}" + except json.JSONDecodeError: + final_answer = f"{scratchpad.action.action_input}" + else: + function_call_state = True + # action is tool call, invoke tool + tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( + action=scratchpad.action, + tool_instances=tool_instances, + message_file_ids=message_file_ids, + trace_manager=trace_manager, + ) + scratchpad.observation = tool_invoke_response + scratchpad.agent_response = tool_invoke_response + + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=scratchpad.action.action_name, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, + thought=scratchpad.thought or "", + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, + answer=scratchpad.agent_response, + messages_ids=message_file_ids, + llm_usage=usage_dict["usage"], + ) + + self.queue_manager.publish( + 
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + # update prompt tool message + for prompt_tool in self._prompt_messages_tools: + self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + + iteration_step += 1 + + yield LLMResultChunk( + model=model_instance.model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] + ), + system_fingerprint="", + ) + + # save agent thought + self.save_agent_thought( + agent_thought=agent_thought, + tool_name="", + tool_input={}, + tool_invoke_meta={}, + thought=final_answer, + observation={}, + answer=final_answer, + messages_ids=[], + ) + if self.variables_pool is not None and self.db_variables_pool is not None: + self.update_db_variables(self.variables_pool, self.db_variables_pool) + # publish end event + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] or LLMUsage.empty_usage(), + system_fingerprint="", + ) + ), + PublishFrom.APPLICATION_MANAGER, + ) + + def _handle_invoke_action( + self, + action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str], + trace_manager: Optional[TraceQueueManager] = None, + ) -> tuple[str, ToolInvokeMeta]: + """ + handle invoke action + :param action: action + :param tool_instances: tool instances + :param message_file_ids: message file ids + :param trace_manager: trace manager + :return: observation, meta + """ + # action is tool call, invoke tool + tool_call_name = action.action_name + tool_call_args = action.action_input + tool_instance = tool_instances.get(tool_call_name) + + if not tool_instance: + answer = f"there is not a tool named {tool_call_name}" + return answer, ToolInvokeMeta.error_instance(answer) + 
+ if isinstance(tool_call_args, str): + try: + tool_call_args = json.loads(tool_call_args) + except json.JSONDecodeError: + pass + + # invoke tool + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( + tool=tool_instance, + tool_parameters=tool_call_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=self.message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback, + trace_manager=trace_manager, + ) + + # publish files + for message_file_id, save_as in message_files: + if save_as is not None and self.variables_pool: + # FIXME the save_as type is confusing, it should be a string or not + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as)) + + # publish message file + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) + # add message file ids + message_file_ids.append(message_file_id) + + return tool_invoke_response, tool_invoke_meta + + def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: + """ + convert dict to action + """ + return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) + + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: + """ + fill in inputs from external data tools + """ + for key, value in inputs.items(): + try: + instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) + except Exception as e: + continue + + return instruction + + def _init_react_state(self, query) -> None: + """ + init agent scratchpad + """ + self._query = query + self._agent_scratchpad = [] + self._historic_prompt_messages = self._organize_historic_prompt_messages() + + @abstractmethod + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + organize prompt messages + """ + + def _format_assistant_message(self, 
agent_scratchpad: list[AgentScratchpadUnit]) -> str: + """ + format assistant message + """ + message = "" + for scratchpad in agent_scratchpad: + if scratchpad.is_final(): + message += f"Final Answer: {scratchpad.agent_response}" + else: + message += f"Thought: {scratchpad.thought}\n\n" + if scratchpad.action_str: + message += f"Action: {scratchpad.action_str}\n\n" + if scratchpad.observation: + message += f"Observation: {scratchpad.observation}\n\n" + + return message + + def _organize_historic_prompt_messages( + self, current_session_messages: Optional[list[PromptMessage]] = None + ) -> list[PromptMessage]: + """ + organize historic prompt messages + """ + result: list[PromptMessage] = [] + scratchpads: list[AgentScratchpadUnit] = [] + current_scratchpad: AgentScratchpadUnit | None = None + + for message in self.history_prompt_messages: + if isinstance(message, AssistantPromptMessage): + if not current_scratchpad: + if not isinstance(message.content, str | None): + raise NotImplementedError("expected str type") + current_scratchpad = AgentScratchpadUnit( + agent_response=message.content, + thought=message.content or "I am thinking about how to help you", + action_str="", + action=None, + observation=None, + ) + scratchpads.append(current_scratchpad) + if message.tool_calls: + try: + current_scratchpad.action = AgentScratchpadUnit.Action( + action_name=message.tool_calls[0].function.name, + action_input=json.loads(message.tool_calls[0].function.arguments), + ) + current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) + except: + pass + elif isinstance(message, ToolPromptMessage): + if not current_scratchpad: + continue + if isinstance(message.content, str): + current_scratchpad.observation = message.content + else: + raise NotImplementedError("expected str type") + elif isinstance(message, UserPromptMessage): + if scratchpads: + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) + scratchpads = [] + 
current_scratchpad = None + + result.append(message) + + if scratchpads: + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) + + historic_prompts = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=current_session_messages or [], + history_messages=result, + memory=self.memory, + ).get_prompt() + return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..6a96c349b2611c49c9603da48e8c447bc094bb7f --- /dev/null +++ b/api/core/agent/cot_chat_agent_runner.py @@ -0,0 +1,117 @@ +import json + +from core.agent.cot_agent_runner import CotAgentRunner +from core.file import file_manager +from core.model_runtime.entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContent, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.model_runtime.utils.encoders import jsonable_encoder + + +class CotChatAgentRunner(CotAgentRunner): + def _organize_system_prompt(self) -> SystemPromptMessage: + """ + Organize system prompt + """ + if not self.app_config.agent: + raise ValueError("Agent configuration is not set") + + prompt_entity = self.app_config.agent.prompt + if not prompt_entity: + raise ValueError("Agent prompt configuration is not set") + first_prompt = prompt_entity.first_prompt + + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) + + return SystemPromptMessage(content=system_prompt) + + def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Organize user query + """ + if self.files: + 
prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + Organize + """ + # organize system prompt + system_message = self._organize_system_prompt() + + # organize current assistant messages + agent_scratchpad = self._agent_scratchpad + if not agent_scratchpad: + assistant_messages = [] + else: + assistant_message = AssistantPromptMessage(content="") + assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str + for unit in agent_scratchpad: + if unit.is_final(): + assistant_message.content += f"Final Answer: {unit.agent_response}" + else: + assistant_message.content += f"Thought: {unit.thought}\n\n" + if unit.action_str: + assistant_message.content += f"Action: {unit.action_str}\n\n" + if unit.observation: + assistant_message.content += f"Observation: {unit.observation}\n\n" + + assistant_messages = [assistant_message] + + # query messages + query_messages = self._organize_user_query(self._query, []) + + if assistant_messages: + # organize historic prompt messages + historic_messages = self._organize_historic_prompt_messages( + [system_message, *query_messages, *assistant_messages, 
UserPromptMessage(content="continue")] + ) + messages = [ + system_message, + *historic_messages, + *query_messages, + *assistant_messages, + UserPromptMessage(content="continue"), + ] + else: + # organize historic prompt messages + historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages]) + messages = [system_message, *historic_messages, *query_messages] + + # join all messages + return messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4d31e047f5ae58cc42638277fde3d4f998c6fa --- /dev/null +++ b/api/core/agent/cot_completion_agent_runner.py @@ -0,0 +1,88 @@ +import json +from typing import Optional + +from core.agent.cot_agent_runner import CotAgentRunner +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.utils.encoders import jsonable_encoder + + +class CotCompletionAgentRunner(CotAgentRunner): + def _organize_instruction_prompt(self) -> str: + """ + Organize instruction prompt + """ + if self.app_config.agent is None: + raise ValueError("Agent configuration is not set") + prompt_entity = self.app_config.agent.prompt + if prompt_entity is None: + raise ValueError("prompt entity is not set") + first_prompt = prompt_entity.first_prompt + + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) + + return system_prompt + + def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str: + """ + Organize historic prompt + """ + historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages) + 
historic_prompt = "" + + for message in historic_prompt_messages: + if isinstance(message, UserPromptMessage): + historic_prompt += f"Question: {message.content}\n\n" + elif isinstance(message, AssistantPromptMessage): + if isinstance(message.content, str): + historic_prompt += message.content + "\n\n" + elif isinstance(message.content, list): + for content in message.content: + if not isinstance(content, TextPromptMessageContent): + continue + historic_prompt += content.data + + return historic_prompt + + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + Organize prompt messages + """ + # organize system prompt + system_prompt = self._organize_instruction_prompt() + + # organize historic prompt messages + historic_prompt = self._organize_historic_prompt() + + # organize current assistant messages + agent_scratchpad = self._agent_scratchpad + assistant_prompt = "" + for unit in agent_scratchpad or []: + if unit.is_final(): + assistant_prompt += f"Final Answer: {unit.agent_response}" + else: + assistant_prompt += f"Thought: {unit.thought}\n\n" + if unit.action_str: + assistant_prompt += f"Action: {unit.action_str}\n\n" + if unit.observation: + assistant_prompt += f"Observation: {unit.observation}\n\n" + + # query messages + query_prompt = f"Question: {self._query}" + + # join all messages + prompt = ( + system_prompt.replace("{{historic_messages}}", historic_prompt) + .replace("{{agent_scratchpad}}", assistant_prompt) + .replace("{{query}}", query_prompt) + ) + + return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae87dca3f8cbd3675b58db7bb094c6c5b292a13 --- /dev/null +++ b/api/core/agent/entities.py @@ -0,0 +1,82 @@ +from enum import Enum +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + + +class AgentToolEntity(BaseModel): + """ + Agent Tool Entity. 
class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.

    Holds the ReAct prompt templates: the template used for the first
    iteration and the one used for every subsequent iteration.
    """

    # prompt template for the first iteration
    first_prompt: str
    # prompt template for follow-up iterations
    next_iteration: str


class AgentScratchpadUnit(BaseModel):
    """
    One Thought/Action/Observation step of an agent's ReAct scratchpad.
    """

    class Action(BaseModel):
        """A single tool invocation chosen by the agent."""

        # name of the tool to invoke
        action_name: str
        # tool arguments; a parsed dict, or the raw string when unparseable
        action_input: Union[dict, str]

        def to_dict(self) -> dict:
            """Serialize the action to the ReAct JSON-blob layout."""
            return {"action": self.action_name, "action_input": self.action_input}

    # raw assistant response for this step
    agent_response: Optional[str] = None
    # free-text reasoning preceding the action
    thought: Optional[str] = None
    # JSON-serialized form of `action`
    action_str: Optional[str] = None
    # tool output observed after executing the action
    observation: Optional[str] = None
    # parsed action, if one was taken
    action: Optional[Action] = None

    def is_final(self) -> bool:
        """
        Whether this step terminates the loop: either no action was chosen,
        or the chosen action is the "Final Answer" pseudo-tool.
        """
        if self.action is None:
            return True
        name = self.action.action_name.lower()
        return "final" in name and "answer" in name
class FunctionCallAgentRunner(BaseAgentRunner):
    """Agent runner that drives tool use via the model's native function calling."""

    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application.

        Repeatedly invokes the model, executes any requested tool calls and
        feeds the results back, until the model stops calling tools or the
        iteration budget is exhausted.  Yields LLM result chunks as they
        stream and publishes agent-thought / message events on the queue.
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        # +1 so the last pass can still produce an answer with tools removed
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there is not any tool call
        function_call_state = True
        llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            # accumulate token and price usage across iterations
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration: remove all tools to force a final answer
                prompt_messages_tools = []

            message_file_ids: list[str] = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None

            def dump_tool_call_inputs() -> str:
                # NOTE: json.dumps raises TypeError on non-serializable values
                # (never JSONDecodeError, which the original caught); fall back
                # to stringifying unserializable arguments so the run survives
                mapping = {tool_call[1]: tool_call[2] for tool_call in tool_calls}
                try:
                    return json.dumps(mapping, ensure_ascii=False)
                except TypeError:
                    return json.dumps(mapping, ensure_ascii=False, default=str)

            if self.stream_tool_call and isinstance(chunks, Generator):
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        # announce the thought as soon as streaming starts
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id),
                            PublishFrom.APPLICATION_MANAGER,
                        )
                        is_first_chunk = False
                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk) or [])
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        tool_call_inputs = dump_tool_call_inputs()

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += str(chunk.delta.message.content)

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            elif not self.stream_tool_call and isinstance(chunks, LLMResult):
                result = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    tool_call_inputs = dump_tool_call_inputs()

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += str(result.message.content)

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                # emit the blocking result as a single chunk for the caller
                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )
            else:
                raise RuntimeError(f"invalid chunks type: {type(chunks)}")

            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is not a tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(
                            f"there is not a tool named {tool_call_name}"
                        ).to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                    )
                    # publish files
                    for message_file_id, save_as in message_files:
                        if save_as:
                            if self.variables_pool:
                                self.variables_pool.set_file(
                                    tool_name=tool_call_name, value=message_file_id, name=save_as
                                )

                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id),
                            PublishFrom.APPLICATION_MANAGER,
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=str(tool_response["tool_response"]),
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name="",
                    tool_input="",
                    thought="",
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer="",
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        if self.variables_pool and self.db_variables_pool:
            self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(
        self, llm_result_chunk: LLMResultChunk
    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(
        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        Initialize system message: ensure the prompt template (if any) is the
        leading SystemPromptMessage of the message list.
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query: wrap the query (plus any uploaded files) into a
        UserPromptMessage appended to `prompt_messages`.
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            # get image detail config; default to LOW when not configured
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As for now, gpt supports both fc and vision at the first iteration.
        We need to remove the image messages from the prompt messages at the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    # flatten multi-modal content to text with placeholders
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
        """Assemble system prompt, trimmed history, user query and current thoughts."""
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages
Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + def parse_action(json_str): + try: + action = json.loads(json_str, strict=False) + action_name = None + action_input = None + + # cohere always returns a list + if isinstance(action, list) and len(action) == 1: + action = action[0] + + for key, value in action.items(): + if "input" in key.lower(): + action_input = value + else: + action_name = value + + if action_name is not None and action_input is not None: + return AgentScratchpadUnit.Action( + action_name=action_name, + action_input=action_input, + ) + else: + return json_str or "" + except: + return json_str or "" + + def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) + if not code_blocks: + return + for block in code_blocks: + json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) + yield parse_action(json_text) + + code_block_cache = "" + code_block_delimiter_count = 0 + in_code_block = False + json_cache = "" + json_quote_count = 0 + in_json = False + got_json = False + + action_cache = "" + action_str = "action:" + action_idx = 0 + + thought_cache = "" + thought_str = "thought:" + thought_idx = 0 + + last_character = "" + + for response in llm_response: + if response.delta.usage: + usage_dict["usage"] = response.delta.usage + response_content = response.delta.message.content + if not isinstance(response_content, str): + continue + + # stream + index = 0 + while index < len(response_content): + steps = 1 + delta = response_content[index : index + steps] + yield_delta = False + + if delta == "`": + last_character = delta + code_block_cache += delta + code_block_delimiter_count += 1 + else: + if not in_code_block: + if code_block_delimiter_count > 0: + last_character = delta + yield code_block_cache + code_block_cache = "" + else: + last_character = delta + code_block_cache += delta + 
code_block_delimiter_count = 0 + + if not in_code_block and not in_json: + if delta.lower() == action_str[action_idx] and action_idx == 0: + if last_character not in {"\n", " ", ""}: + yield_delta = True + else: + last_character = delta + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = "" + action_idx = 0 + index += steps + continue + elif delta.lower() == action_str[action_idx] and action_idx > 0: + last_character = delta + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = "" + action_idx = 0 + index += steps + continue + else: + if action_cache: + last_character = delta + yield action_cache + action_cache = "" + action_idx = 0 + + if delta.lower() == thought_str[thought_idx] and thought_idx == 0: + if last_character not in {"\n", " ", ""}: + yield_delta = True + else: + last_character = delta + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = "" + thought_idx = 0 + index += steps + continue + elif delta.lower() == thought_str[thought_idx] and thought_idx > 0: + last_character = delta + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = "" + thought_idx = 0 + index += steps + continue + else: + if thought_cache: + last_character = delta + yield thought_cache + thought_cache = "" + thought_idx = 0 + + if yield_delta: + index += steps + last_character = delta + yield delta + continue + + if code_block_delimiter_count == 3: + if in_code_block: + last_character = delta + yield from extra_json_from_code_block(code_block_cache) + code_block_cache = "" + + in_code_block = not in_code_block + code_block_delimiter_count = 0 + + if not in_code_block: + # handle single json + if delta == "{": + json_quote_count += 1 + in_json = True + last_character = delta + json_cache += delta + elif delta == "}": + last_character = delta + json_cache += delta + if json_quote_count > 0: + json_quote_count -= 
1 + if json_quote_count == 0: + in_json = False + got_json = True + index += steps + continue + else: + if in_json: + last_character = delta + json_cache += delta + + if got_json: + got_json = False + last_character = delta + yield parse_action(json_cache) + json_cache = "" + json_quote_count = 0 + in_json = False + + if not in_code_block and not in_json: + last_character = delta + yield delta.replace("`", "") + + index += steps + + if code_block_cache: + yield code_block_cache + + if json_cache: + yield parse_action(json_cache) diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py new file mode 100644 index 0000000000000000000000000000000000000000..ef64fd29fc3a76e00bedb1569f0479cf56c4b0f1 --- /dev/null +++ b/api/core/agent/prompt/template.py @@ -0,0 +1,106 @@ +ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. + +{{instruction}} + +You have access to the following tools: + +{{tools}} + +Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). +Valid "action" values: "Final Answer" or {{tool_names}} + +Provide only ONE action per $JSON_BLOB, as shown: + +``` +{ + "action": $TOOL_NAME, + "action_input": $ACTION_INPUT +} +``` + +Follow this format: + +Question: input question to answer +Thought: consider previous and subsequent steps +Action: +``` +$JSON_BLOB +``` +Observation: action result +... (repeat Thought/Action/Observation N times) +Thought: I know what to respond +Action: +``` +{ + "action": "Final Answer", + "action_input": "Final response to human" +} +``` + +Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
+{{historic_messages}} +Question: {{query}} +{{agent_scratchpad}} +Thought:""" # noqa: E501 + + +ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} +Thought:""" + +ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. + +{{instruction}} + +You have access to the following tools: + +{{tools}} + +Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). +Valid "action" values: "Final Answer" or {{tool_names}} + +Provide only ONE action per $JSON_BLOB, as shown: + +``` +{ + "action": $TOOL_NAME, + "action_input": $ACTION_INPUT +} +``` + +Follow this format: + +Question: input question to answer +Thought: consider previous and subsequent steps +Action: +``` +$JSON_BLOB +``` +Observation: action result +... (repeat Thought/Action/Observation N times) +Thought: I know what to respond +Action: +``` +{ + "action": "Final Answer", + "action_input": "Final response to human" +} +``` + +Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
+""" # noqa: E501 + + +ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" + +REACT_PROMPT_TEMPLATES = { + "english": { + "chat": { + "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES, + }, + "completion": { + "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES, + }, + } +} diff --git a/api/core/app/__init__.py b/api/core/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/__init__.py b/api/core/app/app_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..24d80f9cdd77f7ddcd050296ac732f0799093887 --- /dev/null +++ b/api/core/app/app_config/base_app_config_manager.py @@ -0,0 +1,49 @@ +from collections.abc import Mapping +from typing import Any + +from core.app.app_config.entities import AppAdditionalFeatures +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppMode + + +class BaseAppConfigManager: + 
@classmethod + def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures: + """ + Convert app config to app model config + + :param config_dict: app config + :param app_mode: app mode + """ + config_dict = dict(config_dict.items()) + + additional_features = AppAdditionalFeatures() + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) + + additional_features.file_upload = FileUploadConfigManager.convert( + config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT} + ) + + additional_features.opening_statement, additional_features.suggested_questions = ( + OpeningStatementConfigManager.convert(config=config_dict) + ) + + additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( + config=config_dict + ) + + additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict) + + additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict) + + additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict) + + return additional_features diff --git a/api/core/app/app_config/common/__init__.py b/api/core/app/app_config/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/__init__.py b/api/core/app/app_config/common/sensitive_word_avoidance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..037037e6ca1cf03952e472daeee0f08d9c44283b --- /dev/null +++ 
class SensitiveWordAvoidanceConfigManager:
    """Converts and validates the `sensitive_word_avoidance` (moderation) config section."""

    @classmethod
    def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
        """Return a SensitiveWordAvoidanceEntity when moderation is enabled, else None.

        :param config: full app model config dict
        """
        swa = config.get("sensitive_word_avoidance")
        if not swa or not swa.get("enabled"):
            return None
        return SensitiveWordAvoidanceEntity(type=swa.get("type"), config=swa.get("config"))

    @classmethod
    def validate_and_set_defaults(
        cls, tenant_id, config: dict, only_structure_validate: bool = False
    ) -> tuple[dict, list[str]]:
        """Normalize the `sensitive_word_avoidance` section in place.

        :param tenant_id: workspace id, forwarded to the moderation validator
        :param config: full app model config (mutated in place)
        :param only_structure_validate: skip the ModerationFactory config check when True
        :raises ValueError: when the section is malformed
        :return: (config, list of affected config keys)
        """
        if not config.get("sensitive_word_avoidance"):
            config["sensitive_word_avoidance"] = {"enabled": False}

        swa = config["sensitive_word_avoidance"]
        if not isinstance(swa, dict):
            raise ValueError("sensitive_word_avoidance must be of dict type")

        # Normalize any falsy/missing flag to an explicit False.
        if not swa.get("enabled"):
            swa["enabled"] = False

        if swa["enabled"]:
            if not swa.get("type"):
                raise ValueError("sensitive_word_avoidance.type is required")

            if not only_structure_validate:
                # Delegate semantic validation of the moderation config.
                ModerationFactory.validate_config(
                    name=swa["type"],
                    tenant_id=tenant_id,
                    config=swa["config"],
                )

        return config, ["sensitive_word_avoidance"]
class AgentConfigManager:
    """Builds an AgentEntity from the `agent_mode` section of an app model config."""

    @classmethod
    def convert(cls, config: dict) -> Optional[AgentEntity]:
        """Convert the `agent_mode` config section into an AgentEntity.

        Returns None when agent mode is absent or when the configured strategy
        is one of the legacy router strategies ("router" / "react_router").

        :param config: full app model config dict
        """
        agent_mode = config.get("agent_mode")
        if not agent_mode or "enabled" not in agent_mode:
            return None

        strategy_name = agent_mode.get("strategy", "cot")
        if strategy_name == "function_call":
            strategy = AgentEntity.Strategy.FUNCTION_CALLING
        elif strategy_name in {"cot", "react"}:
            strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
        elif config["model"]["provider"] == "openai":
            # Legacy configs carry no recognised strategy; OpenAI historically
            # defaulted to function calling, everything else to chain-of-thought.
            strategy = AgentEntity.Strategy.FUNCTION_CALLING
        else:
            strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

        tools = []
        for tool in agent_mode.get("tools", []):
            # New-style tool entries carry at least 4 keys; shorter (legacy
            # single-key) entries are handled by other managers.
            if len(tool.keys()) < 4:
                continue
            if not tool.get("enabled"):
                continue
            tools.append(
                AgentToolEntity(
                    provider_type=tool["provider_type"],
                    provider_id=tool["provider_id"],
                    tool_name=tool["tool_name"],
                    tool_parameters=tool.get("tool_parameters", {}),
                )
            )

        # Router strategies (and configs without an explicit strategy) are not
        # converted to an agent entity.
        if "strategy" not in agent_mode or agent_mode["strategy"] in {"react_router", "router"}:
            return None

        prompt_overrides = agent_mode.get("prompt", None) or {}
        # Pick the ReAct template set matching the model mode.
        template_key = "completion" if config.get("model", {}).get("mode", "completion") == "completion" else "chat"
        defaults = REACT_PROMPT_TEMPLATES["english"][template_key]
        prompt_entity = AgentPromptEntity(
            first_prompt=prompt_overrides.get("first_prompt", defaults["prompt"]),
            next_iteration=prompt_overrides.get("next_iteration", defaults["agent_scratchpad"]),
        )

        return AgentEntity(
            provider=config["model"]["provider"],
            model=config["model"]["name"],
            strategy=strategy,
            prompt=prompt_entity,
            tools=tools,
            max_iteration=agent_mode.get("max_iteration", 5),
        )
class DatasetConfigManager:
    """Converts and validates dataset-retrieval settings in an app model config."""

    @classmethod
    def convert(cls, config: dict) -> Optional["DatasetEntity"]:
        """Extract referenced dataset ids and retrieval settings.

        Dataset ids are collected both from the modern `dataset_configs.datasets`
        section and from legacy `agent_mode.tools` entries of the form
        ``{"dataset": {...}}``. Returns None when no dataset is referenced.

        :param config: app model config dict
        """
        dataset_ids = cls._extract_dataset_ids(config)
        if not dataset_ids:
            return None

        # Fall back to multi-dataset retrieval when no explicit config exists.
        # (The previous `if dataset_configs is None: return None` branch was
        # unreachable and has been removed.)
        dataset_configs = config.get("dataset_configs") or {"retrieval_model": "multiple"}
        query_variable = config.get("dataset_query_variable")

        if dataset_configs["retrieval_model"] == "single":
            return DatasetEntity(
                dataset_ids=dataset_ids,
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                        dataset_configs["retrieval_model"]
                    ),
                ),
            )

        return DatasetEntity(
            dataset_ids=dataset_ids,
            retrieve_config=DatasetRetrieveConfigEntity(
                query_variable=query_variable,
                retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                    dataset_configs["retrieval_model"]
                ),
                top_k=dataset_configs.get("top_k", 4),
                score_threshold=dataset_configs.get("score_threshold"),
                reranking_model=dataset_configs.get("reranking_model"),
                weights=dataset_configs.get("weights"),
                reranking_enabled=dataset_configs.get("reranking_enabled", True),
                rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
            ),
        )

    @classmethod
    def _extract_dataset_ids(cls, config: dict) -> list:
        """Collect enabled dataset ids from both the modern and legacy config shapes."""
        dataset_ids = []

        # modern shape: dataset_configs.datasets.datasets = [{"dataset": {...}}, ...]
        if "datasets" in config.get("dataset_configs", {}):
            datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
            for item in datasets.get("datasets", []):
                keys = list(item.keys())
                if not keys or keys[0] != "dataset":
                    continue
                dataset = item["dataset"]
                if not dataset.get("enabled"):
                    continue
                dataset_id = dataset.get("id")
                if dataset_id:
                    dataset_ids.append(dataset_id)

        # legacy shape: agent_mode.tools = [{"dataset": {"enabled": ..., "id": ...}}, ...]
        agent_mode = config.get("agent_mode")
        if agent_mode and agent_mode.get("enabled"):
            for tool in agent_mode.get("tools", []):
                keys = list(tool.keys())
                # old standard: the tool name is the single key
                if len(keys) != 1 or keys[0] != "dataset":
                    continue
                tool_item = tool["dataset"]
                if not tool_item.get("enabled"):
                    continue
                dataset_ids.append(tool_item["id"])

        return dataset_ids

    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, app_mode: "AppMode", config: dict) -> tuple[dict, list[str]]:
        """Validate the dataset feature config and fill defaults in place.

        :param tenant_id: tenant ID
        :param app_mode: app mode
        :param config: app model config args
        :raises ValueError: on malformed config
        :return: (config, list of related config keys)
        """
        # Normalize legacy agent_mode-based dataset config first.
        config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)

        if not config.get("dataset_configs"):
            config["dataset_configs"] = {"retrieval_model": "single"}

        # Type-check BEFORE any attribute access so a malformed (non-dict)
        # value raises a clear ValueError rather than an AttributeError.
        # (Also removes the duplicated isinstance check of the original.)
        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")

        if not config["dataset_configs"].get("datasets"):
            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}

        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
            "datasets", {}
        ).get("datasets")

        if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
            # Completion apps have no implicit query; one must be configured.
            dataset_query_variable = config.get("dataset_query_variable")
            if not dataset_query_variable:
                raise ValueError("Dataset query variable is required when dataset is exist")

        return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]

    @classmethod
    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: "AppMode", config: dict) -> dict:
        """Validate/normalize legacy `agent_mode` dataset tool entries in place.

        :param tenant_id: tenant ID
        :param app_mode: app mode
        :param config: app model config args
        :raises ValueError: on malformed config or unknown dataset ids
        """
        if not config.get("agent_mode"):
            config["agent_mode"] = {"enabled": False, "tools": []}

        if not isinstance(config["agent_mode"], dict):
            raise ValueError("agent_mode must be of object type")

        # enabled flag defaults to False
        if not config["agent_mode"].get("enabled"):
            config["agent_mode"]["enabled"] = False

        if not isinstance(config["agent_mode"]["enabled"], bool):
            raise ValueError("enabled in agent_mode must be of boolean type")

        # tools defaults to an empty list
        if not config["agent_mode"].get("tools"):
            config["agent_mode"]["tools"] = []

        if not isinstance(config["agent_mode"]["tools"], list):
            raise ValueError("tools in agent_mode must be a list of objects")

        # strategy defaults to the router strategy
        if not config["agent_mode"].get("strategy"):
            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value

        has_datasets = False
        if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
            for tool in config["agent_mode"]["tools"]:
                key = list(tool.keys())[0]
                if key != "dataset":
                    continue

                # old style: the tool name is the single key
                tool_item = tool[key]

                if not tool_item.get("enabled"):
                    tool_item["enabled"] = False

                if not isinstance(tool_item["enabled"], bool):
                    raise ValueError("enabled in agent_mode.tools must be of boolean type")

                if "id" not in tool_item:
                    raise ValueError("id is required in dataset")

                try:
                    uuid.UUID(tool_item["id"])
                except ValueError:
                    raise ValueError("id in dataset must be of UUID type") from None

                if not cls.is_dataset_exists(tenant_id, tool_item["id"]):
                    raise ValueError("Dataset ID does not exist, please check your permission.")

                has_datasets = True

        need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]

        if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
            # Only check when mode is completion
            dataset_query_variable = config.get("dataset_query_variable")
            if not dataset_query_variable:
                raise ValueError("Dataset query variable is required when dataset is exist")

        return config

    @classmethod
    def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool:
        """Return True iff the dataset exists and belongs to the given tenant."""
        dataset = DatasetService.get_dataset(dataset_id)
        return bool(dataset) and dataset.tenant_id == tenant_id
""" + Convert app model config dict to entity. + :param app_config: app config + :param skip_check: skip check + :raises ProviderTokenNotInitError: provider token not init error + :return: app orchestration config entity + """ + model_config = app_config.model + + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM + ) + + provider_name = provider_model_bundle.configuration.provider.provider + model_name = model_config.model + + model_type_instance = provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + # check model credentials + model_credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=ModelType.LLM, model=model_config.model + ) + + if model_credentials is None: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_config.model, model_type=ModelType.LLM + ) + + if provider_model is None: + model_name = model_config.model + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = model_config.parameters + stop = [] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] + + # get model mode + model_mode = model_config.mode + if not model_mode: + mode_enum = 
class ModelConfigManager:
    """Converts and validates the `model` section of an app model config."""

    @classmethod
    def convert(cls, config: dict) -> "ModelConfigEntity":
        """Build a ModelConfigEntity from the `model` config section.

        :param config: app model config dict
        :raises ValueError: when the `model` section is missing
        """
        model_config = config.get("model")

        if not model_config:
            raise ValueError("model is required")

        # Default to an empty dict so a missing/None completion_params section
        # does not raise a TypeError on the membership test below.
        completion_params = model_config.get("completion_params") or {}

        # "stop" is carried separately from the other sampling parameters.
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        return ModelConfigEntity(
            provider=model_config["provider"],
            model=model_config["name"],
            mode=model_config.get("mode"),
            parameters=completion_params,
            stop=stop,
        )

    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
        """Validate the `model` section against the tenant's providers and fill defaults.

        :param tenant_id: tenant id
        :param config: app model config args
        :raises ValueError: on a malformed section or unknown provider/model
        :return: (config as dict, list of related config keys)
        """
        if "model" not in config:
            raise ValueError("model is required")

        if not isinstance(config["model"], dict):
            raise ValueError("model must be of object type")

        # model.provider must be one of the registered providers
        provider_entities = model_provider_factory.get_providers()
        model_provider_names = [provider.provider for provider in provider_entities]
        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

        # model.name must be one of the tenant's configured LLMs for that provider
        if "name" not in config["model"]:
            raise ValueError("model.name is required")

        provider_manager = ProviderManager()
        models = provider_manager.get_configurations(tenant_id).get_models(
            provider=config["model"]["provider"], model_type=ModelType.LLM
        )

        if not models:
            raise ValueError("model.name must be in the specified model list")

        model_ids = [m.model for m in models]
        if config["model"]["name"] not in model_ids:
            raise ValueError("model.name must be in the specified model list")

        # derive model.mode from the matched model's properties, defaulting to "completion"
        model_mode = None
        for model in models:
            if model.model == config["model"]["name"]:
                model_mode = model.model_properties.get(ModelPropertyKey.MODE)
                break

        config["model"]["mode"] = model_mode if model_mode else "completion"

        # model.completion_params
        if "completion_params" not in config["model"]:
            raise ValueError("model.completion_params is required")

        config["model"]["completion_params"] = cls.validate_model_completion_params(
            config["model"]["completion_params"]
        )

        return dict(config), ["model"]

    @classmethod
    def validate_model_completion_params(cls, cp: dict) -> dict:
        """Validate completion params, normalizing `stop` to a list of at most 4 items.

        :param cp: the model.completion_params dict (mutated in place)
        :raises ValueError: on malformed params
        """
        if not isinstance(cp, dict):
            raise ValueError("model.completion_params must be of object type")

        # stop defaults to an empty list; at most 4 sequences are allowed
        if "stop" not in cp:
            cp["stop"] = []
        elif not isinstance(cp["stop"], list):
            raise ValueError("stop in model.completion_params must be of list type")

        if len(cp["stop"]) > 4:
            raise ValueError("stop sequences must be less than 4")

        return cp
for message in chat_prompt_config.get("prompt", []): + chat_prompt_messages.append( + AdvancedChatMessageEntity( + **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} + ) + ) + + advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) + + advanced_completion_prompt_template = None + completion_prompt_config = config.get("completion_prompt_config", {}) + if completion_prompt_config: + completion_prompt_template_params = { + "prompt": completion_prompt_config["prompt"]["text"], + } + + if "conversation_histories_role" in completion_prompt_config: + completion_prompt_template_params["role_prefix"] = { + "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], + "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], + } + + advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( + **completion_prompt_template_params + ) + + return PromptTemplateEntity( + prompt_type=prompt_type, + advanced_chat_prompt_template=advanced_chat_prompt_template, + advanced_completion_prompt_template=advanced_completion_prompt_template, + ) + + @classmethod + def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: + """ + Validate pre_prompt and set defaults for prompt feature + depending on the config['model'] + + :param app_mode: app mode + :param config: app model config args + """ + if not config.get("prompt_type"): + config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + + prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] + if config["prompt_type"] not in prompt_type_vals: + raise ValueError(f"prompt_type must be in {prompt_type_vals}") + + # chat_prompt_config + if not config.get("chat_prompt_config"): + config["chat_prompt_config"] = {} + + if not isinstance(config["chat_prompt_config"], dict): + raise ValueError("chat_prompt_config must be of object type") + + 
# completion_prompt_config + if not config.get("completion_prompt_config"): + config["completion_prompt_config"] = {} + + if not isinstance(config["completion_prompt_config"], dict): + raise ValueError("completion_prompt_config must be of object type") + + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config["chat_prompt_config"] and not config["completion_prompt_config"]: + raise ValueError( + "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced" + ) + + model_mode_vals = [mode.value for mode in ModelMode] + if config["model"]["mode"] not in model_mode_vals: + raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") + + if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: + user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] + assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] + + if not user_prefix: + config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human" + + if not assistant_prefix: + config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" + + if config["model"]["mode"] == ModelMode.CHAT.value: + prompt_list = config["chat_prompt_config"]["prompt"] + + if len(prompt_list) > 10: + raise ValueError("prompt messages must be less than 10") + else: + # pre_prompt, for simple mode + if not config.get("pre_prompt"): + config["pre_prompt"] = "" + + if not isinstance(config["pre_prompt"], str): + raise ValueError("pre_prompt must be of string type") + + return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] + + @classmethod + def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: + """ + Validate post_prompt and set defaults for prompt feature + + :param config: app model config args + """ + # 
class BasicVariablesConfigManager:
    """Converts and validates the `user_input_form` / `external_data_tools` config sections."""

    # variable names: no leading digit; CJK, latin alphanumerics, underscore, emoji; 1-100 chars
    _VARIABLE_PATTERN = re.compile(
        r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$"
    )

    @classmethod
    def convert(cls, config: dict) -> "tuple[list[VariableEntity], list[ExternalDataVariableEntity]]":
        """Extract form variables and external data variables from the config.

        :param config: app model config dict
        :return: (variable entities, external data variable entities)
        """
        external_data_variables = []
        variable_entities = []

        # legacy top-level external_data_tools section
        for external_data_tool in config.get("external_data_tools", []):
            if not external_data_tool.get("enabled"):
                continue

            external_data_variables.append(
                ExternalDataVariableEntity(
                    variable=external_data_tool["variable"],
                    type=external_data_tool["type"],
                    config=external_data_tool["config"],
                )
            )

        # modern user_input_form section: each item is {type: {...definition}}
        for variables in config.get("user_input_form", []):
            variable_type = list(variables.keys())[0]
            if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
                variable = variables[variable_type]
                if "config" not in variable:
                    continue

                external_data_variables.append(
                    ExternalDataVariableEntity(
                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                    )
                )
            elif variable_type in {
                VariableEntityType.TEXT_INPUT,
                VariableEntityType.PARAGRAPH,
                VariableEntityType.NUMBER,
                VariableEntityType.SELECT,
            }:
                variable = variables[variable_type]
                variable_entities.append(
                    VariableEntity(
                        type=variable_type,
                        variable=variable.get("variable"),
                        description=variable.get("description") or "",
                        label=variable.get("label"),
                        required=variable.get("required", False),
                        max_length=variable.get("max_length"),
                        options=variable.get("options") or [],
                    )
                )

        return variable_entities, external_data_variables

    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
        """Validate and set defaults for the user input form and external data tools.

        :param tenant_id: workspace id
        :param config: app model config args
        :return: (config, list of related config keys)
        """
        related_config_keys = []
        config, keys = cls.validate_variables_and_set_defaults(config)
        related_config_keys.extend(keys)

        config, keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config)
        related_config_keys.extend(keys)

        return config, related_config_keys

    @classmethod
    def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
        """Validate `user_input_form` items and fill defaults in place.

        :param config: app model config args
        :raises ValueError: on malformed form definitions
        """
        if not config.get("user_input_form"):
            config["user_input_form"] = []

        if not isinstance(config["user_input_form"], list):
            raise ValueError("user_input_form must be a list of objects")

        variables = []
        for item in config["user_input_form"]:
            key = list(item.keys())[0]
            if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
                # message now lists every accepted key (was missing 'number' and 'external_data_tool')
                raise ValueError(
                    "Keys in user_input_form list can only be 'text-input', 'paragraph', 'select', "
                    "'number' or 'external_data_tool'"
                )

            form_item = item[key]
            if "label" not in form_item:
                raise ValueError("label is required in user_input_form")

            if not isinstance(form_item["label"], str):
                raise ValueError("label in user_input_form must be of string type")

            if "variable" not in form_item:
                raise ValueError("variable is required in user_input_form")

            if not isinstance(form_item["variable"], str):
                raise ValueError("variable in user_input_form must be of string type")

            if cls._VARIABLE_PATTERN.match(form_item["variable"]) is None:
                raise ValueError("variable in user_input_form must be a string, and cannot start with a number")

            variables.append(form_item["variable"])

            # required flag defaults to False
            if not form_item.get("required"):
                form_item["required"] = False

            if not isinstance(form_item["required"], bool):
                raise ValueError("required in user_input_form must be of boolean type")

            if key == "select":
                if not form_item.get("options"):
                    form_item["options"] = []

                if not isinstance(form_item["options"], list):
                    raise ValueError("options in user_input_form must be a list of strings")

                if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
                    raise ValueError("default value in user_input_form must be in the options list")

        return config, ["user_input_form"]

    @classmethod
    def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
        """Validate `external_data_tools` entries and fill defaults in place.

        :param tenant_id: workspace id
        :param config: app model config args
        :raises ValueError: on malformed tool entries
        """
        if not config.get("external_data_tools"):
            config["external_data_tools"] = []

        if not isinstance(config["external_data_tools"], list):
            raise ValueError("external_data_tools must be of list type")

        for tool in config["external_data_tools"]:
            if not tool.get("enabled"):
                tool["enabled"] = False
                continue

            if not tool.get("type"):
                raise ValueError("external_data_tools[].type is required")

            # BUGFIX: the original rebound the `config` parameter here
            # (`config = tool["config"]`), so the method returned the last
            # enabled tool's sub-config instead of the app config. Use a
            # dedicated local for the tool's own config instead.
            tool_config = tool["config"]

            ExternalDataToolFactory.validate_config(name=tool["type"], tenant_id=tenant_id, config=tool_config)

        return config, ["external_data_tools"]
+ 'simple', 'advanced' + """ + + SIMPLE = "simple" + ADVANCED = "advanced" + + @classmethod + def value_of(cls, value: str): + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid prompt type value {value}") + + prompt_type: PromptType + simple_prompt_template: Optional[str] = None + advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None + advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None + + +class VariableEntityType(StrEnum): + TEXT_INPUT = "text-input" + SELECT = "select" + PARAGRAPH = "paragraph" + NUMBER = "number" + EXTERNAL_DATA_TOOL = "external_data_tool" + FILE = "file" + FILE_LIST = "file-list" + + +class VariableEntity(BaseModel): + """ + Variable Entity. + """ + + variable: str + label: str + description: str = "" + type: VariableEntityType + required: bool = False + max_length: Optional[int] = None + options: Sequence[str] = Field(default_factory=list) + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_file_extensions: Sequence[str] = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + + @field_validator("description", mode="before") + @classmethod + def convert_none_description(cls, v: Any) -> str: + return v or "" + + @field_validator("options", mode="before") + @classmethod + def convert_none_options(cls, v: Any) -> Sequence[str]: + return v or [] + + +class ExternalDataVariableEntity(BaseModel): + """ + External Data Variable Entity. + """ + + variable: str + type: str + config: dict[str, Any] = {} + + +class DatasetRetrieveConfigEntity(BaseModel): + """ + Dataset Retrieve Config Entity. + """ + + class RetrieveStrategy(Enum): + """ + Dataset Retrieve Strategy. 
+ 'single' or 'multiple' + """ + + SINGLE = "single" + MULTIPLE = "multiple" + + @classmethod + def value_of(cls, value: str): + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid retrieve strategy value {value}") + + query_variable: Optional[str] = None # Only when app mode is completion + + retrieve_strategy: RetrieveStrategy + top_k: Optional[int] = None + score_threshold: Optional[float] = 0.0 + rerank_mode: Optional[str] = "reranking_model" + reranking_model: Optional[dict] = None + weights: Optional[dict] = None + reranking_enabled: Optional[bool] = True + + +class DatasetEntity(BaseModel): + """ + Dataset Config Entity. + """ + + dataset_ids: list[str] + retrieve_config: DatasetRetrieveConfigEntity + + +class SensitiveWordAvoidanceEntity(BaseModel): + """ + Sensitive Word Avoidance Entity. + """ + + type: str + config: dict[str, Any] = {} + + +class TextToSpeechEntity(BaseModel): + """ + Sensitive Word Avoidance Entity. + """ + + enabled: bool + voice: Optional[str] = None + language: Optional[str] = None + + +class TracingConfigEntity(BaseModel): + """ + Tracing Config Entity. + """ + + enabled: bool + tracing_provider: str + + +class AppAdditionalFeatures(BaseModel): + file_upload: Optional[FileUploadConfig] = None + opening_statement: Optional[str] = None + suggested_questions: list[str] = [] + suggested_questions_after_answer: bool = False + show_retrieve_source: bool = False + more_like_this: bool = False + speech_to_text: bool = False + text_to_speech: Optional[TextToSpeechEntity] = None + trace_config: Optional[TracingConfigEntity] = None + + +class AppConfig(BaseModel): + """ + Application Config Entity. 
+ """ + + tenant_id: str + app_id: str + app_mode: AppMode + additional_features: AppAdditionalFeatures + variables: list[VariableEntity] = [] + sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None + + +class EasyUIBasedAppModelConfigFrom(Enum): + """ + App Model Config From. + """ + + ARGS = "args" + APP_LATEST_CONFIG = "app-latest-config" + CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" + + +class EasyUIBasedAppConfig(AppConfig): + """ + Easy UI Based App Config Entity. + """ + + app_model_config_from: EasyUIBasedAppModelConfigFrom + app_model_config_id: str + app_model_config_dict: dict + model: ModelConfigEntity + prompt_template: PromptTemplateEntity + dataset: Optional[DatasetEntity] = None + external_data_variables: list[ExternalDataVariableEntity] = [] + + +class WorkflowUIBasedAppConfig(AppConfig): + """ + Workflow UI Based App Config Entity. + """ + + workflow_id: str diff --git a/api/core/app/app_config/features/__init__.py b/api/core/app/app_config/features/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/file_upload/__init__.py b/api/core/app/app_config/features/file_upload/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc4efc47a53aa5ccba666f08ba61fb650156e61 --- /dev/null +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -0,0 +1,44 @@ +from collections.abc import Mapping +from typing import Any + +from core.file import FileUploadConfig + + +class FileUploadConfigManager: + @classmethod + def convert(cls, config: Mapping[str, Any], is_vision: bool = True): + """ + Convert model config to model config + + :param config: 
model config args + :param is_vision: if True, the feature is vision feature + """ + file_upload_dict = config.get("file_upload") + if file_upload_dict: + if file_upload_dict.get("enabled"): + transform_methods = file_upload_dict.get("allowed_file_upload_methods", []) + data = { + "image_config": { + "number_limits": file_upload_dict["number_limits"], + "transfer_methods": transform_methods, + } + } + + if is_vision: + data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") + + return FileUploadConfig.model_validate(data) + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for file upload feature + + :param config: app model config args + """ + if not config.get("file_upload"): + config["file_upload"] = {} + else: + FileUploadConfig.model_validate(config["file_upload"]) + + return config, ["file_upload"] diff --git a/api/core/app/app_config/features/more_like_this/__init__.py b/api/core/app/app_config/features/more_like_this/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..496e1beeecfa0f41634af482a9adbaa371d9ebf9 --- /dev/null +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -0,0 +1,36 @@ +class MoreLikeThisConfigManager: + @classmethod + def convert(cls, config: dict) -> bool: + """ + Convert model config to model config + + :param config: model config args + """ + more_like_this = False + more_like_this_dict = config.get("more_like_this") + if more_like_this_dict: + if more_like_this_dict.get("enabled"): + more_like_this = True + + return more_like_this + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and 
set defaults for more like this feature + + :param config: app model config args + """ + if not config.get("more_like_this"): + config["more_like_this"] = {"enabled": False} + + if not isinstance(config["more_like_this"], dict): + raise ValueError("more_like_this must be of dict type") + + if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]: + config["more_like_this"]["enabled"] = False + + if not isinstance(config["more_like_this"]["enabled"], bool): + raise ValueError("enabled in more_like_this must be of boolean type") + + return config, ["more_like_this"] diff --git a/api/core/app/app_config/features/opening_statement/__init__.py b/api/core/app/app_config/features/opening_statement/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..92b4185abf0183a9284d56c1202f1e2377faf7a3 --- /dev/null +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -0,0 +1,41 @@ +class OpeningStatementConfigManager: + @classmethod + def convert(cls, config: dict) -> tuple[str, list]: + """ + Convert model config to model config + + :param config: model config args + """ + # opening statement + opening_statement = config.get("opening_statement", "") + + # suggested questions + suggested_questions_list = config.get("suggested_questions", []) + + return opening_statement, suggested_questions_list + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for opening statement feature + + :param config: app model config args + """ + if not config.get("opening_statement"): + config["opening_statement"] = "" + + if not isinstance(config["opening_statement"], str): + raise 
ValueError("opening_statement must be of string type") + + # suggested_questions + if not config.get("suggested_questions"): + config["suggested_questions"] = [] + + if not isinstance(config["suggested_questions"], list): + raise ValueError("suggested_questions must be of list type") + + for question in config["suggested_questions"]: + if not isinstance(question, str): + raise ValueError("Elements in suggested_questions list must be of string type") + + return config, ["opening_statement", "suggested_questions"] diff --git a/api/core/app/app_config/features/retrieval_resource/__init__.py b/api/core/app/app_config/features/retrieval_resource/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d098abac2fa2e7fcd23675d2503be15128cccf76 --- /dev/null +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -0,0 +1,31 @@ +class RetrievalResourceConfigManager: + @classmethod + def convert(cls, config: dict) -> bool: + show_retrieve_source = False + retriever_resource_dict = config.get("retriever_resource") + if retriever_resource_dict: + if retriever_resource_dict.get("enabled"): + show_retrieve_source = True + + return show_retrieve_source + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for retriever resource feature + + :param config: app model config args + """ + if not config.get("retriever_resource"): + config["retriever_resource"] = {"enabled": False} + + if not isinstance(config["retriever_resource"], dict): + raise ValueError("retriever_resource must be of dict type") + + if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: + 
config["retriever_resource"]["enabled"] = False + + if not isinstance(config["retriever_resource"]["enabled"], bool): + raise ValueError("enabled in retriever_resource must be of boolean type") + + return config, ["retriever_resource"] diff --git a/api/core/app/app_config/features/speech_to_text/__init__.py b/api/core/app/app_config/features/speech_to_text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e10ae03e043b786348cba86b36f9584d6f67cb54 --- /dev/null +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -0,0 +1,36 @@ +class SpeechToTextConfigManager: + @classmethod + def convert(cls, config: dict) -> bool: + """ + Convert model config to model config + + :param config: model config args + """ + speech_to_text = False + speech_to_text_dict = config.get("speech_to_text") + if speech_to_text_dict: + if speech_to_text_dict.get("enabled"): + speech_to_text = True + + return speech_to_text + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for speech to text feature + + :param config: app model config args + """ + if not config.get("speech_to_text"): + config["speech_to_text"] = {"enabled": False} + + if not isinstance(config["speech_to_text"], dict): + raise ValueError("speech_to_text must be of dict type") + + if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]: + config["speech_to_text"]["enabled"] = False + + if not isinstance(config["speech_to_text"]["enabled"], bool): + raise ValueError("enabled in speech_to_text must be of boolean type") + + return config, ["speech_to_text"] diff --git 
a/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py b/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac5114d12dd4449b4a3a0f5fce80513d78b4564 --- /dev/null +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -0,0 +1,39 @@ +class SuggestedQuestionsAfterAnswerConfigManager: + @classmethod + def convert(cls, config: dict) -> bool: + """ + Convert model config to model config + + :param config: model config args + """ + suggested_questions_after_answer = False + suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer") + if suggested_questions_after_answer_dict: + if suggested_questions_after_answer_dict.get("enabled"): + suggested_questions_after_answer = True + + return suggested_questions_after_answer + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for suggested questions feature + + :param config: app model config args + """ + if not config.get("suggested_questions_after_answer"): + config["suggested_questions_after_answer"] = {"enabled": False} + + if not isinstance(config["suggested_questions_after_answer"], dict): + raise ValueError("suggested_questions_after_answer must be of dict type") + + if ( + "enabled" not in config["suggested_questions_after_answer"] + or not config["suggested_questions_after_answer"]["enabled"] + ): + config["suggested_questions_after_answer"]["enabled"] = False + + if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): + raise ValueError("enabled in 
suggested_questions_after_answer must be of boolean type") + + return config, ["suggested_questions_after_answer"] diff --git a/api/core/app/app_config/features/text_to_speech/__init__.py b/api/core/app/app_config/features/text_to_speech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7598178527b45c2859e4a5e9a096824d749e42 --- /dev/null +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -0,0 +1,45 @@ +from core.app.app_config.entities import TextToSpeechEntity + + +class TextToSpeechConfigManager: + @classmethod + def convert(cls, config: dict): + """ + Convert model config to model config + + :param config: model config args + """ + text_to_speech = None + text_to_speech_dict = config.get("text_to_speech") + if text_to_speech_dict: + if text_to_speech_dict.get("enabled"): + text_to_speech = TextToSpeechEntity( + enabled=text_to_speech_dict.get("enabled"), + voice=text_to_speech_dict.get("voice"), + language=text_to_speech_dict.get("language"), + ) + + return text_to_speech + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for text to speech feature + + :param config: app model config args + """ + if not config.get("text_to_speech"): + config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""} + + if not isinstance(config["text_to_speech"], dict): + raise ValueError("text_to_speech must be of dict type") + + if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]: + config["text_to_speech"]["enabled"] = False + config["text_to_speech"]["voice"] = "" + config["text_to_speech"]["language"] = "" + + if not isinstance(config["text_to_speech"]["enabled"], 
bool): + raise ValueError("enabled in text_to_speech must be of boolean type") + + return config, ["text_to_speech"] diff --git a/api/core/app/app_config/workflow_ui_based_app/__init__.py b/api/core/app/app_config/workflow_ui_based_app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py b/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..2f1da3808231ddb70adba02ddd743c5e0f188b9d --- /dev/null +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -0,0 +1,22 @@ +from core.app.app_config.entities import VariableEntity +from models.workflow import Workflow + + +class WorkflowVariablesConfigManager: + @classmethod + def convert(cls, workflow: Workflow) -> list[VariableEntity]: + """ + Convert workflow start variables to variables + + :param workflow: workflow instance + """ + variables = [] + + # find start node + user_input_form = workflow.user_input_form() + + # variables + for variable in user_input_form: + variables.append(VariableEntity.model_validate(variable)) + + return variables diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7a57bb36585e8743701481c15a3e679a7675a1c2 --- /dev/null +++ b/api/core/app/apps/README.md @@ -0,0 +1,48 @@ +## Guidelines for Database Connection Management in App Runner and Task Pipeline + +Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy 
for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high-concurrency requests due to multiple long-running tasks.
+
+Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
+
+Examples:
+
+1. Creating a new record:
+
+    ```python
+    app = App(id=1)
+    db.session.add(app)
+    db.session.commit()
+    db.session.refresh(app)  # Retrieve table default values, like created_at, cached in the app object, won't affect after close
+
+    # Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
+
+    db.session.close()
+
+    return app.id
+    ```
+
+2. Fetching a record from the table:
+
+    ```python
+    app = db.session.query(App).filter(App.id == app_id).first()
+
+    created_at = app.created_at
+
+    db.session.close()
+
+    # Handle tasks (include long-running).
+
+    ```
+
+3. Updating a table field:
+
+    ```python
+    app = db.session.query(App).filter(App.id == app_id).first()
+
+    app.updated_at = datetime.now(UTC)  # `time.utcnow()` does not exist; use a timezone-aware datetime
+    db.session.commit()
+    db.session.close()
+
+    return app_id
+    ```
+
diff --git a/api/core/app/apps/__init__.py b/api/core/app/apps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/api/core/app/apps/advanced_chat/__init__.py b/api/core/app/apps/advanced_chat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb606953cd7967ae6b750bc425edeaf59b2c7916
--- /dev/null
+++ b/api/core/app/apps/advanced_chat/app_config_manager.py
@@ -0,0 +1,91 @@
+from core.app.app_config.base_app_config_manager import BaseAppConfigManager
+from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
+from core.app.app_config.entities import WorkflowUIBasedAppConfig
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
+from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
+from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
+from core.app.app_config.features.suggested_questions_after_answer.manager import (
+    SuggestedQuestionsAfterAnswerConfigManager,
+)
+from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
+from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
+from models.model import App, AppMode
+from models.workflow import Workflow
+
+
+class
AdvancedChatAppConfig(WorkflowUIBasedAppConfig): + """ + Advanced Chatbot App Config Entity. + """ + + pass + + +class AdvancedChatAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: + features_dict = workflow.features_dict + + app_mode = AppMode.value_of(app_model.mode) + app_config = AdvancedChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=app_mode, + workflow_id=workflow.id, + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for advanced chat app model config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: if True, only structure validation will be performed + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config + ) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = 
TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..36f71fd47879c920a2381c8d91c9811f30a5a89a --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -0,0 +1,390 @@ +import contextvars +import logging +import threading +import uuid +from collections.abc import Generator, Mapping +from typing import Any, Literal, Optional, Union, overload + +from flask import Flask, current_app +from pydantic import ValidationError + +import contexts +from configs import dify_config +from constants import UUID_NIL +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter +from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom 
+from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse +from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.ops.ops_trace_manager import TraceQueueManager +from core.prompt.utils.get_thread_messages_length import get_thread_messages_length +from extensions.ext_database import db +from factories import file_factory +from models.account import Account +from models.model import App, Conversation, EndUser, Message +from models.workflow import Workflow +from services.errors.message import MessageNotExistsError + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppGenerator(MessageBasedAppGenerator): + _dialogue_count: int + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... + + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): + """ + Generate App response. 
+ + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get("query"): + raise ValueError("query is required") + + query = args["query"] + if not isinstance(query, str): + raise ValueError("query must be a string") + + query = query.replace("\x00", "") + inputs = args["inputs"] + + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} + + # get conversation + conversation = None + conversation_id = args.get("conversation_id") + if conversation_id: + conversation = self._get_conversation_by_user( + app_model=app_model, conversation_id=conversation_id, user=user + ) + + # parse files + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) + + if invoke_from == InvokeFrom.DEBUGGER: + # always enable retriever resource in debugger mode + app_config.additional_features.show_retrieve_source = True + + workflow_run_id = str(uuid.uuid4()) + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + 
), 
+ query=query, 
+ files=list(file_objs), 
+ parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, 
+ user_id=user.id, 
+ stream=streaming, 
+ invoke_from=invoke_from, 
+ extras=extras, 
+ trace_manager=trace_manager, 
+ workflow_run_id=workflow_run_id, 
+ ) 
+ contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) 
+ 
+ return self._generate( 
+ workflow=workflow, 
+ user=user, 
+ invoke_from=invoke_from, 
+ application_generate_entity=application_generate_entity, 
+ conversation=conversation, 
+ stream=streaming, 
+ ) 
+ 
+ def single_iteration_generate( 
+ self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True 
+ ) -> Mapping[str, Any] | Generator[str, None, None]: 
+ """ 
+ Generate App response. 
+ 
+ :param app_model: App 
+ :param workflow: Workflow 
+ :param user: account or end user 
+ :param args: request args 
+ :param node_id: ID of the workflow node to run in single-iteration mode 
+ :param streaming: is stream 
+ """ 
+ if not node_id: 
+ raise ValueError("node_id is required") 
+ 
+ if args.get("inputs") is None: 
+ raise ValueError("inputs is required") 
+ 
+ # convert to app config 
+ app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) 
+ 
+ # init application generate entity 
+ application_generate_entity = AdvancedChatAppGenerateEntity( 
+ task_id=str(uuid.uuid4()), 
+ app_config=app_config, 
+ conversation_id=None, 
+ inputs={}, 
+ query="", 
+ files=[], 
+ user_id=user.id, 
+ stream=streaming, 
+ invoke_from=InvokeFrom.DEBUGGER, 
+ extras={"auto_generate_conversation_name": False}, 
+ single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( 
+ node_id=node_id, inputs=args["inputs"] 
+ ), 
+ ) 
+ contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) 
+ 
+ return self._generate( 
+ workflow=workflow, 
+ user=user, 
+ invoke_from=InvokeFrom.DEBUGGER, 
+ application_generate_entity=application_generate_entity, 
+ conversation=None, 
+ 
stream=streaming, + ) + + def _generate( + self, + *, + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True, + ) -> Mapping[str, Any] | Generator[str, None, None]: + """ + Generate App response. + + :param workflow: Workflow + :param user: account or end user + :param invoke_from: invoke from source + :param application_generate_entity: application generate entity + :param conversation: conversation + :param stream: is stream + """ + is_first_conversation = False + if not conversation: + is_first_conversation = True + + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + + if is_first_conversation: + # update conversation features + conversation.override_model_configs = workflow.features + db.session.commit() + db.session.refresh(conversation) + + # get conversation dialogue count + self._dialogue_count = get_thread_messages_length(conversation.id) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) + + # new thread + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": contextvars.copy_context(), + }, + ) + + worker_thread.start() + + # return response or stream generator + response = self._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + 
queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + ) + + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + context: contextvars.Context, + ) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + for var, val in context.items(): + var.set(val) + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") + + # chatbot app + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + dialogue_count=self._dialogue_count, + ) + + runner.run() + except GenerateTaskStoppedError: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except ValueError as e: + if dify_config.DEBUG: + logger.exception("Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() + + def _handle_advanced_chat_response( + 
self, + *, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param user: account or end user + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + dialogue_count=self._dialogue_count, + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error + raise GenerateTaskStoppedError() + else: + logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}") + raise e diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py new file mode 100644 index 0000000000000000000000000000000000000000..a506447671abfb176fa321bb4c7387476c80a95f --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -0,0 +1,156 @@ +import base64 +import concurrent.futures +import logging +import queue +import re +import threading +from collections.abc import Iterable +from typing import Optional + +from core.app.entities.queue_entities import ( + MessageQueueMessage, + QueueAgentMessageEvent, + QueueLLMChunkEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + WorkflowQueueMessage, +) +from 
core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.model_runtime.entities.model_entities import ModelType + + +class AudioTrunk: + def __init__(self, status: str, audio): + self.audio = audio + self.status = status + + +def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): + if not text_content or text_content.isspace(): + return + return model_instance.invoke_tts( + content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice + ) + + +def _process_future( + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None], + audio_queue: queue.Queue[AudioTrunk], +): + while True: + try: + future = future_queue.get() + if future is None: + break + invoke_result = future.result() + if not invoke_result: + continue + for audio in invoke_result: + audio_base64 = base64.b64encode(bytes(audio)) + audio_queue.put(AudioTrunk("responding", audio=audio_base64)) + except Exception as e: + logging.getLogger(__name__).warning(e) + break + audio_queue.put(AudioTrunk("finish", b"")) + + +class AppGeneratorTTSPublisher: + def __init__(self, tenant_id: str, voice: str): + self.logger = logging.getLogger(__name__) + self.tenant_id = tenant_id + self.msg_text = "" + self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() + self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() + self.match = re.compile(r"[。.!?]") + self.model_manager = ModelManager() + self.model_instance = self.model_manager.get_default_model_instance( + tenant_id=self.tenant_id, model_type=ModelType.TTS + ) + self.voices = self.model_instance.get_tts_voices() + values = [voice.get("value") for voice in self.voices] + self.voice = voice + if not voice or voice not in values: + self.voice = self.voices[0].get("value") + self.MAX_SENTENCE = 2 + self._last_audio_event: Optional[AudioTrunk] = 
None + # FIXME better way to handle this threading.start + threading.Thread(target=self._runtime).start() + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) + + def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): + self._msg_queue.put(message) + + def _runtime(self): + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue() + threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() + while True: + try: + message = self._msg_queue.get() + if message is None: + if self.msg_text and len(self.msg_text.strip()) > 0: + futures_result = self.executor.submit( + _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + ) + future_queue.put(futures_result) + break + elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): + message_content = message.event.chunk.delta.message.content + if not message_content: + continue + if isinstance(message_content, str): + self.msg_text += message_content + elif isinstance(message_content, list): + for content in message_content: + if not isinstance(content, TextPromptMessageContent): + continue + self.msg_text += content.data + elif isinstance(message.event, QueueTextChunkEvent): + self.msg_text += message.event.text + elif isinstance(message.event, QueueNodeSucceededEvent): + if message.event.outputs is None: + continue + self.msg_text += message.event.outputs.get("output", "") + self.last_message = message + sentence_arr, text_tmp = self._extract_sentence(self.msg_text) + if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): + self.MAX_SENTENCE += 1 + text_content = "".join(sentence_arr) + futures_result = self.executor.submit( + _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice + ) + future_queue.put(futures_result) + if text_tmp: + self.msg_text = text_tmp + else: + self.msg_text = "" + + except Exception as e: + self.logger.warning(e) + break + 
future_queue.put(None) + + def check_and_get_audio(self): + try: + if self._last_audio_event and self._last_audio_event.status == "finish": + if self.executor: + self.executor.shutdown(wait=False) + return self._last_audio_event + audio = self._audio_queue.get_nowait() + if audio and audio.status == "finish": + self.executor.shutdown(wait=False) + if audio: + self._last_audio_event = audio + return audio + except queue.Empty: + return None + + def _extract_sentence(self, org_text): + tx = self.match.finditer(org_text) + start = 0 + result = [] + for i in tx: + end = i.regs[0][1] + result.append(org_text[start:end]) + start = end + return result, org_text[start:] diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..6339d79898480094a216cd801abbcd67ef44f4a2 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -0,0 +1,229 @@ +import logging +from collections.abc import Mapping +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAnnotationReplyEvent, + QueueStopEvent, + QueueTextChunkEvent, +) +from core.moderation.base import ModerationError +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.enums import UserFrom +from models.model import App, 
Conversation, EndUser, Message +from models.workflow import ConversationVariable, WorkflowType + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppRunner(WorkflowBasedAppRunner): + """ + AdvancedChat Application Runner + """ + + def __init__( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + dialogue_count: int, + ) -> None: + super().__init__(queue_manager) + + self.application_generate_entity = application_generate_entity + self.conversation = conversation + self.message = message + self._dialogue_count = dialogue_count + + def run(self) -> None: + app_config = self.application_generate_entity.app_config + app_config = cast(AdvancedChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + user_id = None + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id + + workflow_callbacks: list[WorkflowCallback] = [] + if dify_config.DEBUG: + workflow_callbacks.append(WorkflowLoggingCallback()) + + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + query = 
self.application_generate_entity.query + files = self.application_generate_entity.files + + # moderation + if self.handle_input_moderation( + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id, + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity, + ): + return + + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, + ) + with Session(db.engine) as session: + db_conversation_variables = session.scalars(stmt).all() + if not db_conversation_variables: + # Create conversation variables if they don't exist. + db_conversation_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(db_conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in db_conversation_variables] + + session.commit() + + # Create a variable pool. 
+ system_inputs = { + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: self.conversation.id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, + SystemVariableKey.APP_ID: app_config.app_id, + SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, + SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + } + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) + + db.session.close() + + # RUN WORKFLOW + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, + ) + + generator = workflow_entry.run( + callbacks=workflow_callbacks, + ) + + for event in generator: + self._handle_event(workflow_entry, event) + + def handle_input_moderation( + self, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, + ) -> bool: + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=app_generate_entity.app_config.tenant_id, + app_generate_entity=app_generate_entity, + inputs=inputs, + query=query, + message_id=message_id, + ) + except ModerationError 
as e: + self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) + return True + + return False + + def handle_annotation_reply( + self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity + ) -> bool: + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=app_generate_entity.user_id, + invoke_from=app_generate_entity.invoke_from, + ) + + if annotation_reply: + self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)) + + self._complete_with_stream_output( + text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + ) + return True + + return False + + def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: + """ + Direct output + """ + self._publish_event(QueueTextChunkEvent(text=text)) + + self._publish_event(QueueStopEvent(stopped_by=stopped_by)) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..5fbd3e9a94906fb30e92e0561792a6561d2f6122 --- /dev/null +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -0,0 +1,126 @@ +import json +from collections.abc import Generator +from typing import Any, cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppBlockingResponse, + AppStreamResponse, + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, +) + + +class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = ChatbotAppBlockingResponse + + @classmethod + def 
convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) + response = { + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, + } + + return response + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + response = cls.convert_blocking_full_response(blocking_response) + + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) + + return response + + @classmethod + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: + """ + Convert stream full response. 
+ :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) + + @classmethod + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, + } + + if isinstance(sub_stream_response, MessageEndStreamResponse): + sub_stream_response_dict = sub_stream_response.to_dict() + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) + response_chunk.update(sub_stream_response_dict) + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) + else: + 
response_chunk.update(sub_stream_response.to_dict()) + + yield json.dumps(response_chunk) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6aad805034ba9ca7c2e71a928dcc2ed017616b6c --- /dev/null +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -0,0 +1,742 @@ +import json +import logging +import time +from collections.abc import Generator, Mapping +from threading import Thread +from typing import Any, Optional, Union + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME +from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, +) +from core.app.entities.queue_entities import ( + QueueAdvancedChatMessageEndEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueMessageReplaceEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + 
MessageEndStreamResponse, + StreamResponse, + WorkflowTaskState, +) +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.utils.encoders import jsonable_encoder +from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType +from events.message_event import message_was_created +from extensions.ext_database import db +from models import Conversation, EndUser, Message, MessageFile +from models.account import Account +from models.enums import CreatedByRole +from models.workflow import ( + Workflow, + WorkflowRunStatus, +) + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppGenerateTaskPipeline: + """ + AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. 
+ """ + + def __init__( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, + dialogue_count: int, + ) -> None: + self._base_task_pipeline = BasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) + + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT + else: + raise NotImplementedError(f"User type not supported: {type(user)}") + + self._workflow_cycle_manager = WorkflowCycleManage( + application_generate_entity=application_generate_entity, + workflow_system_variables={ + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_session_id, + SystemVariableKey.DIALOGUE_COUNT: dialogue_count, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, + }, + ) + + self._task_state = WorkflowTaskState() + self._message_cycle_manager = MessageCycleManage( + application_generate_entity=application_generate_entity, task_state=self._task_state + ) + + self._application_generate_entity = application_generate_entity + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) + self._conversation_name_generate_thread: 
        # NOTE(review): the statement head lies before this chunk; the annotated
        # attribute (presumably self._conversation_name_generate_thread, which the
        # methods below read/join) is cut at the chunk boundary — fragment kept as-is.
        Thread | None = None
        # Files collected from ANSWER/END node outputs, attached to the saved message.
        self._recorded_files: list[Mapping[str, Any]] = []
        # Set once QueueWorkflowStartedEvent arrives; guards all later workflow events.
        self._workflow_run_id: str = ""

    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
        """
        Process the generate task pipeline.

        Starts the background conversation-name generation thread, then returns
        either a streaming generator or a single blocking response depending on
        the pipeline's stream flag.
        :return: blocking response object or stream response generator
        """
        # start generate conversation name thread
        self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
            conversation_id=self._conversation_id, query=self._application_generate_entity.query
        )

        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

        if self._base_task_pipeline._stream:
            return self._to_stream_response(generator)
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
        """
        Drain the stream and return a single blocking response.

        Re-raises the wrapped error on ErrorStreamResponse; returns on the first
        MessageEndStreamResponse; every other stream item is discarded.
        :raises ValueError: if the stream ends without a message-end event
        """
        for stream_response in generator:
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, MessageEndStreamResponse):
                extras = {}
                if stream_response.metadata:
                    extras["metadata"] = stream_response.metadata

                return ChatbotAppBlockingResponse(
                    task_id=stream_response.task_id,
                    data=ChatbotAppBlockingResponse.Data(
                        id=self._message_id,
                        mode=self._conversation_mode,
                        conversation_id=self._conversation_id,
                        message_id=self._message_id,
                        answer=self._task_state.answer,
                        created_at=self._message_created_at,
                        **extras,
                    ),
                )
            else:
                continue

        raise ValueError("queue listening stopped unexpectedly.")

    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[ChatbotAppStreamResponse, Any, None]:
        """
        To stream response.

        Wraps each raw StreamResponse with the conversation/message envelope.
        :return: generator of ChatbotAppStreamResponse
        """
        for stream_response in generator:
            yield ChatbotAppStreamResponse(
                conversation_id=self._conversation_id,
                message_id=self._message_id,
                created_at=self._message_created_at,
                stream_response=stream_response,
            )

    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
        """
        Poll the TTS publisher for one pending (non-final) audio chunk.

        Returns a MessageAudioStreamResponse, or None when there is no publisher
        or no playable chunk.
        """
        if not publisher:
            return None
        audio_msg = publisher.check_and_get_audio()
        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
        Wrap _process_stream_response, interleaving TTS audio chunks.

        Creates a TTS publisher only when the feature config enables
        text_to_speech with autoPlay; after the main stream ends, drains the
        remaining audio until finish or TTS_AUTO_PLAY_TIMEOUT elapses.
        """
        tts_publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow_features_dict

        if (
            features_dict.get("text_to_speech")
            and features_dict["text_to_speech"].get("enabled")
            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
        ):
            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            # flush any audio ready before each regular response
            while True:
                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
                    break
            yield response

        start_listener_time = time.time()
        # timeout
        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
            try:
                if not tts_publisher:
                    break
                audio_trunk = tts_publisher.check_and_get_audio()
                if audio_trunk is None:
                    # release cpu
                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
                    time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
                    continue
                if audio_trunk.status == "finish":
                    break
                else:
                    # a chunk arrived: reset the idle timeout window
                    start_listener_time = time.time()
                    yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
            except Exception as e:
                # NOTE(review): `e` is unused and the f-string duplicates what
                # logger.exception already records; best-effort drain, so we stop.
                logger.exception(f"Failed to listen audio message, task_id: {task_id}")
                break
        if tts_publisher:
            yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

    def _process_stream_response(
        self,
        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.

        Core event loop: consumes queue events and translates each into zero or
        more StreamResponses. Database work is done in short-lived sessions
        (expire_on_commit=False) committed before yielding.
        :return: generator of StreamResponse
        """
        # init fake graph runtime state
        graph_runtime_state: Optional[GraphRuntimeState] = None

        for queue_message in self._base_task_pipeline._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueuePingEvent):
                # keep-alive
                yield self._base_task_pipeline._ping_stream_response()
            elif isinstance(event, QueueErrorEvent):
                # persist the error, surface it, and stop listening
                with Session(db.engine, expire_on_commit=False) as session:
                    err = self._base_task_pipeline._handle_error(
                        event=event, session=session, message_id=self._message_id
                    )
                    session.commit()
                yield self._base_task_pipeline._error_to_stream_response(err)
                break
            elif isinstance(event, QueueWorkflowStartedEvent):
                # override graph runtime state
                graph_runtime_state = event.graph_runtime_state

                with Session(db.engine, expire_on_commit=False) as session:
                    # init workflow run
                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
                        session=session,
                        workflow_id=self._workflow_id,
                        user_id=self._user_id,
                        created_by_role=self._created_by_role,
                    )
                    self._workflow_run_id = workflow_run.id
                    message = self._get_message(session=session)
                    if not message:
                        raise ValueError(f"Message not found: {self._message_id}")
                    message.workflow_run_id = workflow_run.id
                    workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    session.commit()

                yield workflow_start_resp
            elif isinstance(
                event,
                QueueNodeRetryEvent,
            ):
                # node retried: record the retried execution, then stream it
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
                        session=session, workflow_run_id=self._workflow_run_id
                    )
                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
                        session=session, workflow_run=workflow_run, event=event
                    )
                    node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,
                        workflow_node_execution=workflow_node_execution,
                    )
                    session.commit()

                if node_retry_resp:
                    yield node_retry_resp
            elif isinstance(event, QueueNodeStartedEvent):
                # node started: persist the execution start, then stream it
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
                        session=session, workflow_run_id=self._workflow_run_id
                    )
                    workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
                        session=session, workflow_run=workflow_run, event=event
                    )

                    node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,
                        workflow_node_execution=workflow_node_execution,
                    )
                    session.commit()

                if node_start_resp:
                    yield node_start_resp
            elif isinstance(event, QueueNodeSucceededEvent):
                # Record files if it's an answer node or end node
                if event.node_type in [NodeType.ANSWER, NodeType.END]:
                    self._recorded_files.extend(
                        self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
                    )

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
                        session=session, event=event
                    )

                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,
                        workflow_node_execution=workflow_node_execution,
                    )
                    session.commit()

                if node_finish_resp:
                    yield node_finish_resp
            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
                # node failed / failed inside iteration / raised an exception
                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
                        session=session, event=event
                    )

                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,
                        workflow_node_execution=workflow_node_execution,
                    )
                    session.commit()

                if node_finish_resp:
                    yield node_finish_resp
            elif isinstance(event, QueueParallelBranchRunStartedEvent):
                # read-only translation: no commit needed for branch start
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
                        session=session, workflow_run_id=self._workflow_run_id
                    )
                    parallel_start_resp = (
                        self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
                            session=session,
                            task_id=self._application_generate_entity.task_id,
                            workflow_run=workflow_run,
                            event=event,
                        )
                    )

                yield parallel_start_resp
            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
                        session=session, workflow_run_id=self._workflow_run_id
                    )
                    parallel_finish_resp = (
                        self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
                            session=session,
                            task_id=self._application_generate_entity.task_id,
                            workflow_run=workflow_run,
                            event=event,
                        )
                    )

                yield parallel_finish_resp
            elif isinstance(event, QueueIterationStartEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
                        session=session, workflow_run_id=self._workflow_run_id
                    )
                    iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
                        session=session,
                        task_id=self._application_generate_entity.task_id,
                        workflow_run=workflow_run,
                        event=event,
                    )

                yield iter_start_resp
            elif isinstance(event, QueueIterationNextEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
                        session=session, workflow_run_id=self._workflow_run_id
                    )
                    iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
                        session=session,
                        task_id=self._application_generate_entity.task_id,
                        workflow_run=workflow_run,
                        event=event,
                    )

                yield iter_next_resp
            elif isinstance(event, QueueIterationCompletedEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
                        session=session, workflow_run_id=self._workflow_run_id
                    )
                    iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
                        session=session,
                        task_id=self._application_generate_entity.task_id,
                        workflow_run=workflow_run,
                        event=event,
                    )

                yield iter_finish_resp
            elif isinstance(event, QueueWorkflowSucceededEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

                if not graph_runtime_state:
                    # NOTE(review): message looks copy-pasted; sibling branches say
                    # "graph runtime state not initialized." for this condition.
                    raise ValueError("workflow run not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
                        session=session,
                        workflow_run_id=self._workflow_run_id,
                        start_at=graph_runtime_state.start_at,
                        total_tokens=graph_runtime_state.total_tokens,
                        total_steps=graph_runtime_state.node_run_steps,
                        outputs=event.outputs,
                        conversation_id=self._conversation_id,
                        trace_manager=trace_manager,
                    )

                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    session.commit()

                yield workflow_finish_resp
                self._base_task_pipeline._queue_manager.publish(
                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
                )
            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")
                if not graph_runtime_state:
                    raise ValueError("graph runtime state not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    # NOTE(review): conversation_id=None here, unlike the success
                    # branch which passes self._conversation_id — confirm intended.
                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
                        session=session,
                        workflow_run_id=self._workflow_run_id,
                        start_at=graph_runtime_state.start_at,
                        total_tokens=graph_runtime_state.total_tokens,
                        total_steps=graph_runtime_state.node_run_steps,
                        outputs=event.outputs,
                        exceptions_count=event.exceptions_count,
                        conversation_id=None,
                        trace_manager=trace_manager,
                    )
                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    session.commit()

                yield workflow_finish_resp
                self._base_task_pipeline._queue_manager.publish(
                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
                )
            elif isinstance(event, QueueWorkflowFailedEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")
                if not graph_runtime_state:
                    raise ValueError("graph runtime state not initialized.")

                with Session(db.engine, expire_on_commit=False) as session:
                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                        session=session,
                        workflow_run_id=self._workflow_run_id,
                        start_at=graph_runtime_state.start_at,
                        total_tokens=graph_runtime_state.total_tokens,
                        total_steps=graph_runtime_state.node_run_steps,
                        status=WorkflowRunStatus.FAILED,
                        error=event.error,
                        conversation_id=self._conversation_id,
                        trace_manager=trace_manager,
                        exceptions_count=event.exceptions_count,
                    )
                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
                    err = self._base_task_pipeline._handle_error(
                        event=err_event, session=session, message_id=self._message_id
                    )
                    session.commit()

                yield workflow_finish_resp
                yield self._base_task_pipeline._error_to_stream_response(err)
                break
            elif isinstance(event, QueueStopEvent):
                # only finalize the run if it actually started
                if self._workflow_run_id and graph_runtime_state:
                    with Session(db.engine, expire_on_commit=False) as session:
                        workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                            session=session,
                            workflow_run_id=self._workflow_run_id,
                            start_at=graph_runtime_state.start_at,
                            total_tokens=graph_runtime_state.total_tokens,
                            total_steps=graph_runtime_state.node_run_steps,
                            status=WorkflowRunStatus.STOPPED,
                            error=event.get_stop_reason(),
                            conversation_id=self._conversation_id,
                            trace_manager=trace_manager,
                        )
                        workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                            session=session,
                            task_id=self._application_generate_entity.task_id,
                            workflow_run=workflow_run,
                        )
                        # Save message
                        self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                        session.commit()

                    yield workflow_finish_resp

                yield self._message_end_to_stream_response()
                break
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._message_cycle_manager._handle_retriever_resources(event)

                with Session(db.engine, expire_on_commit=False) as session:
                    message = self._get_message(session=session)
                    message.message_metadata = (
                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                    )
                    session.commit()
            elif isinstance(event, QueueAnnotationReplyEvent):
                self._message_cycle_manager._handle_annotation_reply(event)

                with Session(db.engine, expire_on_commit=False) as session:
                    message = self._get_message(session=session)
                    message.message_metadata = (
                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                    )
                    session.commit()
            elif isinstance(event, QueueTextChunkEvent):
                delta_text = event.text
                if delta_text is None:
                    continue

                # handle output moderation chunk
                should_direct_answer = self._handle_output_moderation_chunk(delta_text)
                if should_direct_answer:
                    continue

                # only publish tts message at text chunk streaming
                if tts_publisher:
                    tts_publisher.publish(queue_message)

                self._task_state.answer += delta_text
                yield self._message_cycle_manager._message_to_stream_response(
                    answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                )
            elif isinstance(event, QueueMessageReplaceEvent):
                # published by moderation
                yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                if not graph_runtime_state:
                    raise ValueError("graph runtime state not initialized.")

                output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
                    self._task_state.answer
                )
                if output_moderation_answer:
                    self._task_state.answer = output_moderation_answer
                    yield self._message_cycle_manager._message_replace_to_stream_response(
                        answer=output_moderation_answer
                    )

                # Save message
                with Session(db.engine, expire_on_commit=False) as session:
                    self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                    session.commit()

                yield self._message_end_to_stream_response()
            else:
                continue

        # publish None when task finished
        if tts_publisher:
            tts_publisher.publish(None)

        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

    def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
        """
        Persist the final answer, latency, metadata, recorded files and LLM
        usage onto the Message row (caller commits the session).
        """
        message = self._get_message(session=session)
        message.answer = self._task_state.answer
        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
        message.message_metadata = (
            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
        )
        message_files = [
            MessageFile(
                message_id=message.id,
                type=file["type"],
                transfer_method=file["transfer_method"],
                url=file["remote_url"],
                belongs_to="assistant",
                upload_file_id=file["related_id"],
                created_by_role=CreatedByRole.ACCOUNT
                if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else CreatedByRole.END_USER,
                created_by=message.from_account_id or message.from_end_user_id or "",
            )
            for file in self._recorded_files
        ]
        session.add_all(message_files)

        if graph_runtime_state and graph_runtime_state.llm_usage:
            usage = graph_runtime_state.llm_usage
            message.message_tokens = usage.prompt_tokens
            message.message_unit_price = usage.prompt_unit_price
            message.message_price_unit =
usage.prompt_price_unit + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.total_price = usage.total_price + message.currency = usage.currency + self._task_state.metadata["usage"] = jsonable_encoder(usage) + else: + self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) + message_was_created.send( + message, + application_generate_entity=self._application_generate_entity, + ) + + def _message_end_to_stream_response(self) -> MessageEndStreamResponse: + """ + Message end to stream response. + :return: + """ + extras = {} + if self._task_state.metadata: + extras["metadata"] = self._task_state.metadata.copy() + + if "annotation_reply" in extras["metadata"]: + del extras["metadata"]["annotation_reply"] + + return MessageEndStreamResponse( + task_id=self._application_generate_entity.task_id, + id=self._message_id, + files=self._recorded_files, + metadata=extras.get("metadata", {}), + ) + + def _handle_output_moderation_chunk(self, text: str) -> bool: + """ + Handle output moderation chunk. 
+ :param text: text + :return: True if output moderation should direct output, otherwise False + """ + if self._base_task_pipeline._output_moderation_handler: + if self._base_task_pipeline._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() + self._base_task_pipeline._queue_manager.publish( + QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE + ) + + self._base_task_pipeline._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE + ) + return True + else: + self._base_task_pipeline._output_moderation_handler.append_new_token(text) + + return False + + def _get_message(self, *, session: Session): + stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(stmt) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + return message diff --git a/api/core/app/apps/agent_chat/__init__.py b/api/core/app/apps/agent_chat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..55b6ee510f228c3235dcae78f02262a377727689 --- /dev/null +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -0,0 +1,232 @@ +import uuid +from collections.abc import Mapping +from typing import Any, Optional + +from core.agent.entities import AgentEntity +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager +from 
core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.entities.agent_entities import PlanningStrategy +from models.model import App, AppMode, AppModelConfig, Conversation + +OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] + + +class AgentChatAppConfig(EasyUIBasedAppConfig): + """ + Agent Chatbot App Config Entity. 
+ """ + + agent: Optional[AgentEntity] = None + + +class AgentChatAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> AgentChatAppConfig: + """ + Convert app model config to agent chat app config + :param app_model: app model + :param app_model_config: app model config + :param conversation: conversation + :param override_config_dict: app model config dict + :return: + """ + if override_config_dict: + config_from = EasyUIBasedAppModelConfigFrom.ARGS + elif conversation: + config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG + else: + config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + else: + config_dict = override_config_dict or {} + + app_mode = AppMode.value_of(app_model.mode) + app_config = AgentChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=app_mode, + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + agent=AgentConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: + """ + Validate for agent chat app model config + + 
:param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.AGENT_CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # agent_mode + config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config + ) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + 
related_config_keys.extend(current_related_config_keys) + + # dataset configs + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + + @classmethod + def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate agent_mode and set defaults for agent feature + + :param tenant_id: tenant ID + :param config: app model config args + """ + if not config.get("agent_mode"): + config["agent_mode"] = {"enabled": False, "tools": []} + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + if config["agent_mode"]["strategy"] not in [ + member.value for member in list(PlanningStrategy.__members__.values()) + ]: + raise ValueError("strategy in agent_mode must be in the specified strategy list") + + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key 
= list(tool.keys())[0] + if key in OLD_TOOLS: + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if key == "dataset": + if "id" not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not DatasetConfigManager.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") + + return config, ["agent_mode"] diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f622263eb084b2640db6b0da77b3ec842c3606 --- /dev/null +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -0,0 +1,256 @@ +import logging +import threading +import uuid +from collections.abc import Generator, Mapping +from typing import Any, Literal, Union, overload + +from flask import Flask, current_app +from pydantic import ValidationError + +from configs import dify_config +from constants import UUID_NIL +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.app.app_config.features.file_upload.manager 
class AgentChatAppGenerator(MessageBasedAppGenerator):
    # Generator for agent-chat apps. `generate` validates the request, builds an
    # AgentChatAppGenerateEntity, starts a worker thread that runs the agent, and
    # returns the (streaming-only) response. The three overloads below exist only
    # to refine return types for type checkers; the final definition is the
    # implementation.

    @overload
    def generate(
        self,
        *,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
    ) -> Generator[str, None, None]: ...

    @overload
    def generate(
        self,
        *,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        *,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool,
    ) -> Mapping[str, Any] | Generator[str, None, None]: ...

    def generate(
        self,
        *,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
    ):
        """
        Generate App response.

        :param app_model: App
        :param user: account or end user
        :param args: request args (query, inputs, files, conversation_id, ...)
        :param invoke_from: invoke from source
        :param streaming: whether to stream; agent chat only supports True
        :raises ValueError: on blocking mode, missing/invalid query, or a
            model-config override outside debug mode
        """
        # agent chat is streaming-only; fail fast on blocking requests
        if not streaming:
            raise ValueError("Agent Chat App does not support blocking mode")

        if not args.get("query"):
            raise ValueError("query is required")

        query = args["query"]
        if not isinstance(query, str):
            raise ValueError("query must be a string")

        # strip NUL bytes, which some DB drivers reject in text columns
        query = query.replace("\x00", "")
        inputs = args["inputs"]

        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}

        # get conversation (only when the client continues an existing one)
        conversation = None
        if args.get("conversation_id"):
            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)

        # get app model config
        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

        # validate override model config (debugger-only feature)
        override_model_config_dict = None
        if args.get("model_config"):
            if invoke_from != InvokeFrom.DEBUGGER:
                raise ValueError("Only in App debug mode can override model config")

            # validate config
            override_model_config_dict = AgentChatAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id,
                config=args["model_config"],
            )

            # always enable retriever resource in debugger mode
            override_model_config_dict["retriever_resource"] = {"enabled": True}

        # parse files; the override config (if any) wins over the stored config
        files = args.get("files") or []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
            file_objs = file_factory.build_from_mappings(
                mappings=files,
                tenant_id=app_model.tenant_id,
                config=file_extra_config,
            )
        else:
            # file upload disabled: silently drop any files the client sent
            file_objs = []

        # convert to app config
        app_config = AgentChatAppConfigManager.get_app_config(
            app_model=app_model,
            app_model_config=app_model_config,
            conversation=conversation,
            override_config_dict=override_model_config_dict,
        )

        # get tracing instance (end users are identified by session_id)
        trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)

        # init application generate entity
        application_generate_entity = AgentChatAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            model_conf=ModelConfigConverter.convert(app_config),
            file_upload_config=file_extra_config,
            conversation_id=conversation.id if conversation else None,
            # existing conversations reuse their stored inputs; new ones
            # validate/convert the user-provided inputs
            inputs=conversation.inputs
            if conversation
            else self._prepare_user_inputs(
                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
            ),
            query=query,
            files=list(file_objs),
            # service-API calls cannot branch from a parent message
            parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
            user_id=user.id,
            stream=streaming,
            invoke_from=invoke_from,
            extras=extras,
            call_depth=0,
            trace_manager=trace_manager,
        )

        # init generate records (creates conversation if needed, plus message row)
        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

        # init queue manager bridging the worker thread and the response stream
        queue_manager = MessageBasedAppQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id,
        )

        # new thread; pass the real Flask app object so the worker can build
        # its own app context
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
            },
        )

        worker_thread.start()

        # return response or stream generator
        response = self._handle_response(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            conversation=conversation,
            message=message,
            user=user,
            stream=streaming,
        )
        # FIXME: Type hinting issue here, ignore it for now, will fix it later
        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: AgentChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation_id: str,
        message_id: str,
    ) -> None:
        """
        Generate worker in a new thread.

        Re-fetches conversation/message in this thread's session, runs the agent
        app runner, and maps failures onto queue error events so the streaming
        side always terminates.

        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation_id: conversation ID
        :param message_id: message ID
        :return:
        """
        with flask_app.app_context():
            try:
                # get conversation and message
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)
                if message is None:
                    raise MessageNotExistsError("Message not exists")

                # chatbot app
                runner = AgentChatAppRunner()
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message,
                )
            except GenerateTaskStoppedError:
                # user-initiated stop: not an error, just end quietly
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                # ValueError is usually a user/config problem; only log noisily in debug
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                # thread-local session must be returned to the pool
                db.session.close()
class AgentChatAppRunner(AppRunner):
    """
    Agent Application Runner

    Runs one agent-chat turn: token budgeting, moderation, annotation reply,
    external data fill-in, hosting moderation, then dispatches to the agent
    strategy runner (function-calling or chain-of-thought).
    """

    def run(
        self,
        application_generate_entity: AgentChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
    ) -> None:
        """
        Run assistant application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        :raises ValueError: when the app, agent entity, model schema,
            conversation or message cannot be resolved
        """
        app_config = application_generate_entity.app_config
        app_config = cast(AgentChatAppConfig, app_config)

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
                model=application_generate_entity.model_conf.model,
            )

            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )

        # moderation
        try:
            # process sensitive_word_avoidance; may rewrite inputs/query
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_config.tenant_id,
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
                message_id=message.id,
            )
        except ModerationError as e:
            # blocked by moderation: answer with the moderation message and stop
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream,
            )
            return

        if query:
            # annotation reply: a stored answer for this query short-circuits the agent
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
            )

            if annotation_reply:
                queue_manager.publish(
                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
                    PublishFrom.APPLICATION_MANAGER,
                )

                self.direct_output(
                    queue_manager=queue_manager,
                    app_generate_entity=application_generate_entity,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream,
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query,
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional), external data, dataset context(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:
            # hosting moderation already emitted the refusal output
            return

        agent_entity = app_config.agent
        if not agent_entity:
            raise ValueError("Agent entity not found")

        # load tool variables
        tool_conversation_variables = self._load_tool_variables(
            conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
        )

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

        # init model instance
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
            model=application_generate_entity.model_conf.model,
        )
        # NOTE(review): prompt messages are organized a third time here with the
        # same arguments as above; looks redundant — confirm whether this call
        # can reuse `prompt_messages`.
        prompt_message, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )

        # change function call strategy based on LLM model: models that natively
        # support tool calls always use function calling
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if not model_schema:
            raise ValueError("Model schema not found")

        if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
            agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

        # re-fetch conversation/message so the runner gets fresh ORM objects,
        # then release this session before the (long-running) agent starts
        conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
        if conversation_result is None:
            raise ValueError("Conversation not found")
        message_result = db.session.query(Message).filter(Message.id == message.id).first()
        if message_result is None:
            raise ValueError("Message not found")
        db.session.close()

        runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
        # start agent runner: CoT runners are split by the model's LLM mode
        if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
            # check LLM mode
            if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
                runner_cls = CotChatAgentRunner
            elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
                runner_cls = CotCompletionAgentRunner
            else:
                raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
        elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
            runner_cls = FunctionCallAgentRunner
        else:
            raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")

        runner = runner_cls(
            tenant_id=app_config.tenant_id,
            application_generate_entity=application_generate_entity,
            conversation=conversation_result,
            app_config=app_config,
            model_config=application_generate_entity.model_conf,
            config=agent_entity,
            queue_manager=queue_manager,
            message=message_result,
            user_id=application_generate_entity.user_id,
            memory=memory,
            prompt_messages=prompt_message,
            variables_pool=tool_variables,
            db_variables=tool_conversation_variables,
            model_instance=model_instance,
        )

        invoke_result = runner.run(
            message=message,
            query=query,
            inputs=inputs,
        )

        # handle invoke result: pipe the agent's output into the queue
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream,
            agent=True,
        )

    def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
        """
        load tool variables from database

        Returns the existing row for this conversation/tenant, or creates and
        commits an empty one so later updates always have a row to write to.
        """
        tool_variables: ToolConversationVariables | None = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == conversation_id,
                ToolConversationVariables.tenant_id == tenant_id,
            )
            .first()
        )

        if tool_variables:
            # save tool variables to session, so that we can update it later
            db.session.add(tool_variables)
        else:
            # create new tool variables
            tool_variables = ToolConversationVariables(
                conversation_id=conversation_id,
                user_id=user_id,
                tenant_id=tenant_id,
                variables_str="[]",
            )
            db.session.add(tool_variables)
            db.session.commit()

        return tool_variables
_convert_db_variables_to_tool_variables( + self, db_variables: ToolConversationVariables + ) -> ToolRuntimeVariablePool: + """ + convert db variables to tool variables + """ + return ToolRuntimeVariablePool( + **{ + "conversation_id": db_variables.conversation_id, + "user_id": db_variables.user_id, + "tenant_id": db_variables.tenant_id, + "pool": db_variables.variables, + } + ) + + def _get_usage_of_all_agent_thoughts( + self, model_config: ModelConfigWithCredentialsEntity, message: Message + ) -> LLMUsage: + """ + Get usage of all agent thoughts + :param model_config: model config + :param message: message + :return: + """ + agent_thoughts = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() + ) + + all_message_tokens = 0 + all_answer_tokens = 0 + for agent_thought in agent_thoughts: + all_message_tokens += agent_thought.message_tokens + all_answer_tokens += agent_thought.answer_tokens + + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + return model_type_instance._calc_response_usage( + model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens + ) diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..ce331d904cc82649dcb2b0537c99249643e188a4 --- /dev/null +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -0,0 +1,121 @@ +import json +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) + + +class 
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    """Converts agent-chat blocking/stream responses into client-facing payloads.

    "Full" variants keep all metadata (debugger / service API); "simple"
    variants strip metadata down via `_get_simple_metadata` (web app).
    """

    _blocking_response_type = ChatbotAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return: flat dict with full metadata
        """
        response = {
            "event": "message",
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
            "message_id": blocking_response.data.message_id,
            "conversation_id": blocking_response.data.conversation_id,
            "mode": blocking_response.data.mode,
            "answer": blocking_response.data.answer,
            "metadata": blocking_response.data.metadata,
            "created_at": blocking_response.data.created_at,
        }

        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
        :return: same as full response, but with simplified metadata
        """
        response = cls.convert_blocking_full_response(blocking_response)

        metadata = response.get("metadata", {})
        response["metadata"] = cls._get_simple_metadata(metadata)

        return response

    @classmethod
    def convert_stream_full_response(  # type: ignore[override]
        cls,
        stream_response: Generator[ChatbotAppStreamResponse, None, None],
    ) -> Generator[str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response
        :return: generator of JSON strings (or the literal "ping")
        """
        for chunk in stream_response:
            chunk = cast(ChatbotAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            # pings are passed through as a bare marker, not JSON
            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield json.dumps(response_chunk)

    @classmethod
    def convert_stream_simple_response(  # type: ignore[override]
        cls,
        stream_response: Generator[ChatbotAppStreamResponse, None, None],
    ) -> Generator[str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response
        :return: generator of JSON strings (or the literal "ping")
        """
        for chunk in stream_response:
            chunk = cast(ChatbotAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.to_dict()
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
            # BUGFIX: these branches must be exclusive. The original used a
            # separate `if` here, so a MessageEndStreamResponse fell into the
            # final `else` and `update`d the chunk with the FULL to_dict(),
            # clobbering the simplified metadata set just above.
            elif isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.to_dict())

            yield json.dumps(response_chunk)
class AppGenerateResponseConverter(ABC):
    # Base class for per-app-mode response converters. Subclasses implement the
    # four abstract conversion hooks; `convert` picks full vs simple based on
    # the invoke source and wraps streaming output as SSE lines.

    # concrete subclasses set this to their blocking response dataclass
    _blocking_response_type: type[AppBlockingResponse]

    @classmethod
    def convert(
        cls,
        response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
        invoke_from: InvokeFrom,
    ) -> Mapping[str, Any] | Generator[str, None, None]:
        # debugger and service API get full payloads; everything else (e.g.
        # web app) gets the simplified ones
        if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_full_response(response)
            else:

                def _generate_full_response() -> Generator[str, Any, None]:
                    # wrap each chunk as an SSE frame; "ping" becomes an SSE
                    # event instead of a data payload
                    for chunk in cls.convert_stream_full_response(response):
                        if chunk == "ping":
                            yield f"event: {chunk}\n\n"
                        else:
                            yield f"data: {chunk}\n\n"

                return _generate_full_response()
        else:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_simple_response(response)
            else:

                def _generate_simple_response() -> Generator[str, Any, None]:
                    # same SSE framing as the full variant
                    for chunk in cls.convert_stream_simple_response(response):
                        if chunk == "ping":
                            yield f"event: {chunk}\n\n"
                        else:
                            yield f"data: {chunk}\n\n"

                return _generate_simple_response()

    @classmethod
    @abstractmethod
    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        """Convert a blocking response to a full (unstripped) payload dict."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        """Convert a blocking response to a simplified payload dict."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def convert_stream_full_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """Convert a stream of responses to full JSON chunk strings."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def convert_stream_simple_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """Convert a stream of responses to simplified JSON chunk strings."""
        raise NotImplementedError

    @classmethod
    def _get_simple_metadata(cls, metadata: dict[str, Any]):
        """
        Get simple metadata.

        Strips retriever resources down to a fixed field set and drops
        annotation_reply/usage entirely. NOTE: mutates and returns `metadata`.
        :param metadata: metadata
        :return:
        """
        # show_retrieve_source
        updated_resources = []
        if "retriever_resources" in metadata:
            for resource in metadata["retriever_resources"]:
                updated_resources.append(
                    {
                        "segment_id": resource.get("segment_id", ""),
                        "position": resource["position"],
                        "document_name": resource["document_name"],
                        "score": resource["score"],
                        "content": resource["content"],
                    }
                )
            metadata["retriever_resources"] = updated_resources

        # show annotation reply
        if "annotation_reply" in metadata:
            del metadata["annotation_reply"]

        # show usage
        if "usage" in metadata:
            del metadata["usage"]

        return metadata

    @classmethod
    def _error_to_stream_response(cls, e: Exception) -> dict:
        """
        Error to stream response.

        Maps known exception types onto {code, message, status}; anything
        unrecognized becomes a generic 500.
        :param e: exception
        :return:
        """
        error_responses = {
            ValueError: {"code": "invalid_param", "status": 400},
            ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
            QuotaExceededError: {
                "code": "provider_quota_exceeded",
                "message": "Your quota for Dify Hosted Model Provider has been exhausted. "
                "Please go to Settings -> Model Provider to complete your own provider credentials.",
                "status": 400,
            },
            ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
            InvokeError: {"code": "completion_request_error", "status": 400},
        }

        # Determine the response based on the type of exception
        # NOTE(review): no `break` — when an exception matches several entries
        # the LAST match in insertion order wins; confirm that is intended.
        data = None
        for k, v in error_responses.items():
            if isinstance(e, k):
                data = v

        if data:
            # prefer the exception's `description` (HTTPException-style), else str(e)
            data.setdefault("message", getattr(e, "description", str(e)))
        else:
            logging.error(e)
            data = {
                "code": "internal_server_error",
                "message": "Internal Server Error, please contact support.",
                "status": 500,
            }

        return data
file_factory.build_from_mapping( + mapping=v, + tenant_id=tenant_id, + config=FileUploadConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), + ) + for k, v in user_inputs.items() + if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE + } + # Convert list of files to File + file_list_inputs = { + k: file_factory.build_from_mappings( + mappings=v, + tenant_id=tenant_id, + config=FileUploadConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), + ) + for k, v in user_inputs.items() + if isinstance(v, list) + # Ensure skip List + and all(isinstance(item, dict) for item in v) + and entity_dictionary[k].type == VariableEntityType.FILE_LIST + } + # Merge all inputs + user_inputs = {**user_inputs, **files_inputs, **file_list_inputs} + + # Check if all files are converted to File + if any(filter(lambda v: isinstance(v, dict), user_inputs.values())): + raise ValueError("Invalid input type") + if any( + filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values())) + ): + raise ValueError("Invalid input type") + + return user_inputs + + def _validate_inputs( + self, + *, + variable_entity: "VariableEntity", + value: Any, + ): + if value is None: + if variable_entity.required: + raise ValueError(f"{variable_entity.variable} is required in input form") + return value + + if variable_entity.type in { + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, + } and not isinstance(value, str): + raise ValueError( + f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string" + ) + + if variable_entity.type 
== VariableEntityType.NUMBER and isinstance(value, str): + # handle empty string case + if not value.strip(): + return None + # may raise ValueError if user_input_value is not a valid number + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + raise ValueError(f"{variable_entity.variable} in input form must be a valid number") + + match variable_entity.type: + case VariableEntityType.SELECT: + if value not in variable_entity.options: + raise ValueError( + f"{variable_entity.variable} in input form must be one of the following: " + f"{variable_entity.options}" + ) + case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH: + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} " + "characters" + ) + case VariableEntityType.FILE: + if not isinstance(value, dict) and not isinstance(value, File): + raise ValueError(f"{variable_entity.variable} in input form must be a file") + case VariableEntityType.FILE_LIST: + # if number of files exceeds the limit, raise ValueError + if not ( + isinstance(value, list) + and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value)) + ): + raise ValueError(f"{variable_entity.variable} in input form must be a list of files") + + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" + ) + + return value + + def _sanitize_value(self, value: Any) -> Any: + if isinstance(value, str): + return value.replace("\x00", "") + return value diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2222a14e77e5fee27e160af8a13d4ae39fe3eb --- /dev/null +++ 
class PublishFrom(Enum):
    # Who is pushing into the queue: the request-handling side or the task pipeline.
    APPLICATION_MANAGER = 1
    TASK_PIPELINE = 2


class AppQueueManager:
    """In-memory event queue bridging a generate worker thread and the
    response listener, with Redis-backed ownership and stop flags."""

    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
        if not user_id:
            raise ValueError("user is required")

        self._task_id = task_id
        self._user_id = user_id
        self._invoke_from = invoke_from

        # Record which user owns this task (30 min TTL) so that only the same
        # user is later allowed to stop it via set_stop_flag().
        user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
        redis_client.setex(
            AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
        )

        q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()

        self._q = q

    def listen(self):
        """
        Listen to queue: yield messages until the None sentinel arrives,
        publishing a stop event on timeout/manual stop and a ping roughly
        every 10 seconds to keep the client connection alive.
        :return: generator of queue messages
        """
        # wait for APP_MAX_EXECUTION_TIME seconds to stop listen
        listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
        start_time = time.time()
        last_ping_time: int | float = 0
        while True:
            try:
                message = self._q.get(timeout=1)
                if message is None:
                    break

                yield message
            except queue.Empty:
                continue
            finally:
                # runs every loop iteration, including after queue.Empty
                elapsed_time = time.time() - start_time
                if elapsed_time >= listen_timeout or self._is_stopped():
                    # publish two messages to make sure the client can receive the stop signal
                    # and stop listening after the stop signal processed
                    self.publish(
                        QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
                    )

                if elapsed_time // 10 > last_ping_time:
                    self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
                    last_ping_time = elapsed_time // 10

    def stop_listen(self) -> None:
        """
        Stop listen to queue by enqueueing the None sentinel.
        :return:
        """
        self._q.put(None)

    def publish_error(self, e, pub_from: PublishFrom) -> None:
        """
        Publish error wrapped in a QueueErrorEvent.
        :param e: error
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueErrorEvent(error=e), pub_from)

    def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue after rejecting SQLAlchemy models in the payload.
        :param event:
        :param pub_from:
        :return:
        """
        self._check_for_sqlalchemy_models(event.model_dump())
        self._publish(event, pub_from)

    @abstractmethod
    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue (implemented by subclasses).
        :param event:
        :param pub_from:
        :return:
        """
        raise NotImplementedError

    @classmethod
    def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
        """
        Set task stop flag, but only when the requesting user owns the task.
        :return:
        """
        result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
        if result is None:
            return

        # only the user who started the task may stop it
        user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
        if result.decode("utf-8") != f"{user_prefix}-{user_id}":
            return

        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
        redis_client.setex(stopped_cache_key, 600, 1)

    def _is_stopped(self) -> bool:
        """
        Check if task is stopped by probing the Redis stop flag.
        :return: True when the stop flag exists
        """
        return redis_client.get(AppQueueManager._generate_stopped_cache_key(self._task_id)) is not None

    @classmethod
    def _generate_task_belong_cache_key(cls, task_id: str) -> str:
        """
        Generate task belong cache key.
        :param task_id: task id
        :return:
        """
        return f"generate_task_belong:{task_id}"

    @classmethod
    def _generate_stopped_cache_key(cls, task_id: str) -> str:
        """
        Generate stopped cache key.
        :param task_id: task id
        :return:
        """
        return f"generate_task_stopped:{task_id}"

    def _check_for_sqlalchemy_models(self, data: Any):
        """Recursively reject SQLAlchemy model instances anywhere inside an
        event payload — they are not safe to hand across threads."""
        if isinstance(data, dict):
            # only the values can hold model instances
            for value in data.values():
                self._check_for_sqlalchemy_models(value)
        elif isinstance(data, list):
            for item in data:
                self._check_for_sqlalchemy_models(item)
        elif isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
            raise TypeError(
                "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
            )


class GenerateTaskStoppedError(Exception):
    """Raised internally when a generate task is stopped before completion."""

    pass
class AppRunner:
    """Shared helpers for EasyUI-based app runners: token budgeting, prompt
    organization, moderation, annotation reply and invoke-result handling."""

    @staticmethod
    def _get_max_tokens(model_config: ModelConfigWithCredentialsEntity) -> int:
        """Read the configured max_tokens value (set directly or through a
        parameter rule's use_template), defaulting to 0 when unset."""
        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or parameter_rule.use_template == "max_tokens":
                max_tokens = (
                    model_config.parameters.get(parameter_rule.name)
                    or model_config.parameters.get(parameter_rule.use_template or "")
                ) or 0
        return max_tokens

    def get_pre_calculate_rest_tokens(
        self,
        app_record: App,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_template_entity: PromptTemplateEntity,
        inputs: Mapping[str, str],
        files: Sequence["File"],
        query: Optional[str] = None,
    ) -> int:
        """
        Get pre calculate rest tokens: the token budget left for memory and
        context after the prompt and max_tokens are accounted for.
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :return: remaining tokens, or -1 when the model declares no context size
        :raises InvokeBadRequestError: when the prompt already exceeds the budget
        """
        model_instance = ModelInstance(
            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens is None:
            return -1

        max_tokens = self._get_max_tokens(model_config)

        # get prompt messages without memory and context
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=model_config,
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
            files=files,
            query=query,
        )

        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise InvokeBadRequestError(
                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
                "or shrink the max token, or switch to a llm with a larger token limit size."
            )

        return rest_tokens

    def recalc_llm_max_tokens(
        self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
    ):
        """Recalculate max_tokens in place when prompt_tokens + max_tokens
        exceeds the model's context token limit (floor of 16)."""
        model_instance = ModelInstance(
            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens is None:
            return -1

        max_tokens = self._get_max_tokens(model_config)

        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_context_tokens:
            max_tokens = max(model_context_tokens - prompt_tokens, 16)

        # write the (possibly shrunk) value back into every matching parameter rule
        for parameter_rule in model_config.model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or parameter_rule.use_template == "max_tokens":
                model_config.parameters[parameter_rule.name] = max_tokens

    def organize_prompt_messages(
        self,
        app_record: App,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_template_entity: PromptTemplateEntity,
        inputs: Mapping[str, str],
        files: Sequence["File"],
        query: Optional[str] = None,
        context: Optional[str] = None,
        memory: Optional[TokenBufferMemory] = None,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Organize prompt messages for either the simple or advanced template.
        :param context: retrieved context text
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :param memory: memory
        :return: (prompt messages, stop words)
        """
        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
            prompt_transform = SimplePromptTransform()
            prompt_messages, stop = prompt_transform.get_prompt(
                app_mode=AppMode.value_of(app_record.mode),
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query or "",
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
            )
        else:
            # advanced templates carry their own memory configuration
            memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

            model_mode = ModelMode.value_of(model_config.mode)
            prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
            if model_mode == ModelMode.COMPLETION:
                advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
                if not advanced_completion_prompt_template:
                    raise InvokeBadRequestError("Advanced completion prompt template is required.")
                prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)

                if advanced_completion_prompt_template.role_prefix:
                    memory_config.role_prefix = MemoryConfig.RolePrefix(
                        user=advanced_completion_prompt_template.role_prefix.user,
                        assistant=advanced_completion_prompt_template.role_prefix.assistant,
                    )
            else:
                if not prompt_template_entity.advanced_chat_prompt_template:
                    raise InvokeBadRequestError("Advanced chat prompt template is required.")
                prompt_template = [
                    ChatModelMessage(text=message.text, role=message.role)
                    for message in prompt_template_entity.advanced_chat_prompt_template.messages
                ]

            prompt_transform = AdvancedPromptTransform()
            prompt_messages = prompt_transform.get_prompt(
                prompt_template=prompt_template,
                inputs=inputs,
                query=query or "",
                files=files,
                context=context,
                memory_config=memory_config,
                memory=memory,
                model_config=model_config,
            )
            stop = model_config.stop

        return prompt_messages, stop

    def direct_output(
        self,
        queue_manager: AppQueueManager,
        app_generate_entity: EasyUIBasedAppGenerateEntity,
        prompt_messages: list,
        text: str,
        stream: bool,
        usage: Optional[LLMUsage] = None,
    ) -> None:
        """
        Direct output: emit a fixed text as if it were the model response,
        token-by-token when streaming, always ending with a message-end event.
        :param queue_manager: application queue manager
        :param app_generate_entity: app generate entity
        :param prompt_messages: prompt messages
        :param text: text
        :param stream: stream
        :param usage: usage
        :return:
        """
        if stream:
            index = 0
            for token in text:
                chunk = LLMResultChunk(
                    model=app_generate_entity.model_conf.model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
                )

                queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
                index += 1
                # throttle slightly so the client renders a smooth stream
                time.sleep(0.01)

        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=app_generate_entity.model_conf.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=text),
                    usage=usage or LLMUsage.empty_usage(),
                ),
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def _handle_invoke_result(
        self,
        invoke_result: Union[LLMResult, Generator[Any, None, None]],
        queue_manager: AppQueueManager,
        stream: bool,
        agent: bool = False,
    ) -> None:
        """
        Handle invoke result, dispatching on stream vs. blocking result.
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param stream: stream
        :param agent: agent
        :return:
        """
        if not stream and isinstance(invoke_result, LLMResult):
            self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
        elif stream and isinstance(invoke_result, Generator):
            self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
        else:
            raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")

    def _handle_invoke_result_direct(
        self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
    ) -> None:
        """
        Handle invoke result direct: publish the complete result as one end event.
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param agent: agent
        :return:
        """
        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=invoke_result,
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def _handle_invoke_result_stream(
        self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
    ) -> None:
        """
        Handle invoke result stream: forward each chunk, accumulate the full
        text/usage, then publish a final message-end event.
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param agent: agent — routes chunks to the agent event type
        :return:
        """
        model: str = ""
        prompt_messages: list[PromptMessage] = []
        text = ""
        usage = None
        for result in invoke_result:
            if not agent:
                queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
            else:
                queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)

            # NOTE(review): assumes delta.message.content is str here — confirm
            # against the model runtime's chunk contract
            text += result.delta.message.content

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages

            if result.delta.usage:
                usage = result.delta.usage

        if not usage:
            usage = LLMUsage.empty_usage()

        llm_result = LLMResult(
            model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
        )

        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=llm_result,
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def moderation_for_inputs(
        self,
        *,
        app_id: str,
        tenant_id: str,
        app_generate_entity: AppGenerateEntity,
        inputs: Mapping[str, Any],
        query: str | None = None,
        message_id: str,
    ) -> tuple[bool, Mapping[str, Any], str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_generate_entity: app generate entity
        :param inputs: inputs
        :param query: query
        :param message_id: message id
        :return: (flagged, filtered inputs, filtered query)
        """
        moderation_feature = InputModeration()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_generate_entity.app_config,
            inputs=dict(inputs),
            query=query or "",
            message_id=message_id,
            trace_manager=app_generate_entity.trace_manager,
        )

    def check_hosting_moderation(
        self,
        application_generate_entity: EasyUIBasedAppGenerateEntity,
        queue_manager: AppQueueManager,
        prompt_messages: list[PromptMessage],
    ) -> bool:
        """
        Check hosting moderation; when flagged, short-circuit with a canned reply.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return: True when the content was flagged
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream,
            )

        return moderation_result

    def fill_in_inputs_from_external_data_tools(
        self,
        tenant_id: str,
        app_id: str,
        external_data_tools: list[ExternalDataVariableEntity],
        inputs: Mapping[str, Any],
        query: str,
    ) -> Mapping[str, Any]:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetch()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
        )

    def query_app_annotations_to_reply(
        self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
    ) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply.
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return: a matching annotation, or None
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
        )
class ChatAppConfig(EasyUIBasedAppConfig):
    """
    Chatbot App Config Entity.
    """

    pass


class ChatAppConfigManager(BaseAppConfigManager):
    @classmethod
    def get_app_config(
        cls,
        app_model: App,
        app_model_config: AppModelConfig,
        conversation: Optional[Conversation] = None,
        override_config_dict: Optional[dict] = None,
    ) -> ChatAppConfig:
        """
        Convert app model config to chat app config.
        :param app_model: app model
        :param app_model_config: app model config
        :param conversation: conversation
        :param override_config_dict: app model config dict
        :return: chat app config
        """
        # decide where the effective config comes from: explicit args beat the
        # conversation-pinned config, which beats the app's latest config
        if override_config_dict:
            config_from = EasyUIBasedAppModelConfigFrom.ARGS
        elif conversation:
            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
        else:
            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG

        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
            app_model_config_dict = app_model_config.to_dict()
            # copy so downstream mutation can't leak into the stored config
            config_dict = app_model_config_dict.copy()
        else:
            if not override_config_dict:
                # specific exception type instead of bare Exception
                raise ValueError("override_config_dict is required when config_from is ARGS")

            config_dict = override_config_dict

        app_mode = AppMode.value_of(app_model.mode)
        app_config = ChatAppConfig(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_mode=app_mode,
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=config_dict,
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
            dataset=DatasetConfigManager.convert(config=config_dict),
            additional_features=cls.convert_features(config_dict, app_mode),
        )

        app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
            config=config_dict
        )

        return app_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict) -> dict:
        """
        Validate for chat app model config: each sub-manager validates its own
        section, sets defaults, and reports which keys it owns; everything else
        is filtered out of the returned config.

        :param tenant_id: tenant id
        :param config: app model config args
        :return: validated config containing only the recognized keys
        """
        app_mode = AppMode.CHAT

        related_config_keys = []

        # model
        config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config)
        related_config_keys.extend(current_related_config_keys)

        # user_input_form
        config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config)
        related_config_keys.extend(current_related_config_keys)

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # prompt
        config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config)
        related_config_keys.extend(current_related_config_keys)

        # dataset_query_variable
        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
            tenant_id, app_mode, config
        )
        related_config_keys.extend(current_related_config_keys)

        # opening_statement
        config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # suggested_questions_after_answer
        config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
            config
        )
        related_config_keys.extend(current_related_config_keys)

        # speech_to_text
        config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # text_to_speech
        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # return retriever resource
        config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id, config
        )
        related_config_keys.extend(current_related_config_keys)

        related_config_keys = list(set(related_config_keys))

        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

        return filtered_config
class ChatAppGenerator(MessageBasedAppGenerator):
    """Entry point for running a basic chat app: validates the request, builds
    the generate entity, runs the app in a worker thread, and returns or
    streams the converted response."""

    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
    ) -> Generator[str, None, None]: ...

    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool,
    ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...

    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
    ):
        """
        Generate App response.

        :param app_model: App
        :param user: account or end user
        :param args: request args
        :param invoke_from: invoke from source
        :param streaming: whether to stream the response
        """
        if not args.get("query"):
            raise ValueError("query is required")

        query = args["query"]
        if not isinstance(query, str):
            raise ValueError("query must be a string")

        # strip NUL bytes that PostgreSQL text columns reject
        query = query.replace("\x00", "")
        # NOTE(review): assumes "inputs" is always present in args — confirm
        # against the API request schema
        inputs = args["inputs"]

        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}

        # get conversation
        conversation = None
        if args.get("conversation_id"):
            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)

        # get app model config
        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

        # validate override model config
        override_model_config_dict = None
        if args.get("model_config"):
            if invoke_from != InvokeFrom.DEBUGGER:
                raise ValueError("Only in App debug mode can override model config")

            # validate config
            override_model_config_dict = ChatAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=args.get("model_config", {})
            )

            # always enable retriever resource in debugger mode
            override_model_config_dict["retriever_resource"] = {"enabled": True}

        # parse files
        files = args.get("files") or []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
            file_objs = file_factory.build_from_mappings(
                mappings=files,
                tenant_id=app_model.tenant_id,
                config=file_extra_config,
            )
        else:
            file_objs = []

        # convert to app config
        app_config = ChatAppConfigManager.get_app_config(
            app_model=app_model,
            app_model_config=app_model_config,
            conversation=conversation,
            override_config_dict=override_model_config_dict,
        )

        # get tracing instance
        trace_manager = TraceQueueManager(app_id=app_model.id)

        # init application generate entity
        application_generate_entity = ChatAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            model_conf=ModelConfigConverter.convert(app_config),
            file_upload_config=file_extra_config,
            conversation_id=conversation.id if conversation else None,
            inputs=conversation.inputs
            if conversation
            else self._prepare_user_inputs(
                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
            ),
            query=query,
            files=list(file_objs),
            parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
            user_id=user.id,
            invoke_from=invoke_from,
            extras=extras,
            trace_manager=trace_manager,
            stream=streaming,
        )

        # init generate records
        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

        # init queue manager
        queue_manager = MessageBasedAppQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id,
        )

        # run the app in a new thread so this request thread can stream results
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
            },
        )

        worker_thread.start()

        # return response or stream generator
        response = self._handle_response(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            conversation=conversation,
            message=message,
            user=user,
            stream=streaming,
        )

        return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: ChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation_id: str,
        message_id: str,
    ) -> None:
        """
        Generate worker in a new thread: runs the chat app and forwards any
        failure to the queue so the listener can report it.
        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation_id: conversation ID
        :param message_id: message ID
        :return:
        """
        with flask_app.app_context():
            try:
                # get conversation and message
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)
                if message is None:
                    raise MessageNotExistsError("Message not exists")

                # chatbot app
                runner = ChatAppRunner()
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message,
                )
            except GenerateTaskStoppedError:
                # deliberate stop — not an error
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                # the worker thread owns its own session; always release it
                db.session.close()
core.app.entities.app_invoke_entities import ( + ChatAppGenerateEntity, +) +from core.app.entities.queue_entities import QueueAnnotationReplyEvent +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.moderation.base import ModerationError +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from extensions.ext_database import db +from models.model import App, Conversation, Message + +logger = logging.getLogger(__name__) + + +class ChatAppRunner(AppRunner): + """ + Chat Application Runner + """ + + def run( + self, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + ) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(ChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # Pre-calculate the number of tokens of the prompt messages, + # and return the rest number of tokens by model context token size limit and max token size limit. + # If the rest number of tokens is not enough, raise exception. 
+ # Include: prompt template, inputs, query(optional), files(optional) + # Not Include: memory, external data, dataset context + self.get_pre_calculate_rest_tokens( + app_record=app_record, + model_config=application_generate_entity.model_conf, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query, + ) + + memory = None + if application_generate_entity.conversation_id: + # get memory of conversation (read-only) + model_instance = ModelInstance( + provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, + model=application_generate_entity.model_conf.model, + ) + + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + # organize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=application_generate_entity.model_conf, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query, + memory=memory, + ) + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query, + message_id=message.id, + ) + except ModerationError as e: + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream, + ) + return + + if query: + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + ) + + if annotation_reply: + queue_manager.publish( + 
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER, + ) + + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=annotation_reply.content, + stream=application_generate_entity.stream, + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_config.external_data_variables + if external_data_tools: + inputs = self.fill_in_inputs_from_external_data_tools( + tenant_id=app_record.tenant_id, + app_id=app_record.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query, + ) + + # get context from datasets + context = None + if app_config.dataset and app_config.dataset.dataset_ids: + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + application_generate_entity.invoke_from, + ) + + dataset_retrieval = DatasetRetrieval(application_generate_entity) + context = dataset_retrieval.retrieve( + app_id=app_record.id, + user_id=application_generate_entity.user_id, + tenant_id=app_record.tenant_id, + model_config=application_generate_entity.model_conf, + config=app_config.dataset, + query=query, + invoke_from=application_generate_entity.invoke_from, + show_retrieve_source=app_config.additional_features.show_retrieve_source, + hit_callback=hit_callback, + memory=memory, + message_id=message.id, + ) + + # reorganize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional), external data, dataset context(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=application_generate_entity.model_conf, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query, + context=context, + memory=memory, + ) + + # check hosting moderation 
+ hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages, + ) + + if hosting_moderation_result: + return + + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) + + # Invoke model + model_instance = ModelInstance( + provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, + model=application_generate_entity.model_conf.model, + ) + + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=application_generate_entity.model_conf.parameters, + stop=stop, + stream=application_generate_entity.stream, + user=application_generate_entity.user_id, + ) + + # handle invoke result + self._handle_invoke_result( + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream + ) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..9024c3a98273d12c326fb4e6beaab4f732eb8aa9 --- /dev/null +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -0,0 +1,121 @@ +import json +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) + + +class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = ChatbotAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> 
dict: # type: ignore[override] + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + response = { + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, + } + + return response + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + response = cls.convert_blocking_full_response(blocking_response) + + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) + + return response + + @classmethod + def convert_stream_full_response( + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] + ) -> Generator[str, None, None]: + """ + Convert stream full response. 
+ :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) + + @classmethod + def convert_stream_simple_response( + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] + ) -> Generator[str, None, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, + } + + if isinstance(sub_stream_response, MessageEndStreamResponse): + sub_stream_response_dict = sub_stream_response.to_dict() + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) + response_chunk.update(sub_stream_response_dict) + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + else: + response_chunk.update(sub_stream_response.to_dict()) + + yield json.dumps(response_chunk) diff --git 
a/api/core/app/apps/completion/__init__.py b/api/core/app/apps/completion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..02e5d475684cdc4cdf6365733854331a46a582fd --- /dev/null +++ b/api/core/app/apps/completion/app_config_manager.py @@ -0,0 +1,121 @@ +from typing import Optional + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import App, AppMode, AppModelConfig + + +class CompletionAppConfig(EasyUIBasedAppConfig): + """ + Completion App Config Entity. 
+ """ + + pass + + +class CompletionAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config( + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + ) -> CompletionAppConfig: + """ + Convert app model config to completion app config + :param app_model: app model + :param app_model_config: app model config + :param override_config_dict: app model config dict + :return: + """ + if override_config_dict: + config_from = EasyUIBasedAppModelConfigFrom.ARGS + else: + config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + else: + config_dict = override_config_dict or {} + + app_mode = AppMode.value_of(app_model.mode) + app_config = CompletionAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=app_mode, + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for completion app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.COMPLETION + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + 
related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # more_like_this + config, current_related_config_keys = MoreLikeThisConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..be50d496d236947466fa42e97271274b97f75884 --- /dev/null +++ b/api/core/app/apps/completion/app_generator.py @@ -0,0 +1,339 @@ +import logging +import threading +import uuid +from collections.abc import Generator, Mapping +from typing 
import Any, Literal, Union, overload + +from flask import Flask, current_app +from pydantic import ValidationError + +from configs import dify_config +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.app.apps.completion.app_runner import CompletionAppRunner +from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom +from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.ops.ops_trace_manager import TraceQueueManager +from extensions.ext_database import db +from factories import file_factory +from models import Account, App, EndUser, Message +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + +logger = logging.getLogger(__name__) + + +class CompletionAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... 
+ + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): + """ + Generate App response. + + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + query = args["query"] + if not isinstance(query, str): + raise ValueError("query must be a string") + + query = query.replace("\x00", "") + inputs = args["inputs"] + + # get conversation + conversation = None + + # get app model config + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) + + # validate override model config + override_model_config_dict = None + if args.get("model_config"): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError("Only in App debug mode can override model config") + + # validate config + override_model_config_dict = CompletionAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) + ) + + # parse files + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict + ) + + # get tracing instance + trace_manager = TraceQueueManager(app_model.id) + + # init application generate entity + 
application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras={}, + trace_manager=trace_manager, + ) + + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) + + # new thread + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) + + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str, + ) -> None: + """ + Generate worker in a new thread. 
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get message + message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError() + + # chatbot app + runner = CompletionAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message=message, + ) + except GenerateTaskStoppedError: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except ValueError as e: + if dify_config.DEBUG: + logger.exception("Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() + + def generate_more_like_this( + self, + app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True, + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: + """ + Generate App response. 
+ + :param app_model: App + :param message_id: message ID + :param user: account or end user + :param invoke_from: invoke from source + :param stream: is stream + """ + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) + + if not message: + raise MessageNotExistsError() + + current_app_model_config = app_model.app_model_config + more_like_this = current_app_model_config.more_like_this_dict + + if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: + raise MoreLikeThisDisabledError() + + app_model_config = message.app_model_config + override_model_config_dict = app_model_config.to_dict() + model_dict = override_model_config_dict["model"] + completion_params = model_dict.get("completion_params") + completion_params["temperature"] = 0.9 + model_dict["completion_params"] = completion_params + override_model_config_dict["model"] = model_dict + + # parse files + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=list(file_objs), + 
user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={}, + ) + + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) + + # new thread + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) + + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..41278b75b42bf4f6fefbaf155a71d596b710937a --- /dev/null +++ b/api/core/app/apps/completion/app_runner.py @@ -0,0 +1,177 @@ +import logging +from typing import cast + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_runner import AppRunner +from core.app.apps.completion.app_config_manager import CompletionAppConfig +from core.app.entities.app_invoke_entities import ( + CompletionAppGenerateEntity, +) +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.model_manager import ModelInstance +from core.moderation.base import ModerationError +from 
core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from extensions.ext_database import db +from models.model import App, Message + +logger = logging.getLogger(__name__) + + +class CompletionAppRunner(AppRunner): + """ + Completion Application Runner + """ + + def run( + self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message + ) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param message: message + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(CompletionAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # Pre-calculate the number of tokens of the prompt messages, + # and return the rest number of tokens by model context token size limit and max token size limit. + # If the rest number of tokens is not enough, raise exception. 
+ # Include: prompt template, inputs, query(optional), files(optional) + # Not Include: memory, external data, dataset context + self.get_pre_calculate_rest_tokens( + app_record=app_record, + model_config=application_generate_entity.model_conf, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query, + ) + + # organize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=application_generate_entity.model_conf, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query, + ) + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query or "", + message_id=message.id, + ) + except ModerationError as e: + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream, + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_config.external_data_variables + if external_data_tools: + inputs = self.fill_in_inputs_from_external_data_tools( + tenant_id=app_record.tenant_id, + app_id=app_record.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query, + ) + + # get context from datasets + context = None + if app_config.dataset and app_config.dataset.dataset_ids: + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + application_generate_entity.invoke_from, + ) + + dataset_config = app_config.dataset + if dataset_config and dataset_config.retrieve_config.query_variable: + 
query = inputs.get(dataset_config.retrieve_config.query_variable, "") + + dataset_retrieval = DatasetRetrieval(application_generate_entity) + context = dataset_retrieval.retrieve( + app_id=app_record.id, + user_id=application_generate_entity.user_id, + tenant_id=app_record.tenant_id, + model_config=application_generate_entity.model_conf, + config=dataset_config, + query=query or "", + invoke_from=application_generate_entity.invoke_from, + show_retrieve_source=app_config.additional_features.show_retrieve_source, + hit_callback=hit_callback, + message_id=message.id, + ) + + # reorganize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional), external data, dataset context(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=application_generate_entity.model_conf, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query, + context=context, + ) + + # check hosting moderation + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages, + ) + + if hosting_moderation_result: + return + + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) + + # Invoke model + model_instance = ModelInstance( + provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, + model=application_generate_entity.model_conf.model, + ) + + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=application_generate_entity.model_conf.parameters, + stop=stop, + stream=application_generate_entity.stream, + user=application_generate_entity.user_id, + ) + + # handle invoke 
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
    """Convert completion-app task-pipeline responses into API-facing payloads.

    "Full" variants keep all metadata; "simple" variants reduce it via
    ``_get_simple_metadata`` (used for end-user facing APIs).
    """

    _blocking_response_type = CompletionAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return: response dict including full metadata
        """
        response = {
            "event": "message",
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
            "message_id": blocking_response.data.message_id,
            "mode": blocking_response.data.mode,
            "answer": blocking_response.data.answer,
            "metadata": blocking_response.data.metadata,
            "created_at": blocking_response.data.created_at,
        }

        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
        :return: response dict with simplified metadata
        """
        response = cls.convert_blocking_full_response(blocking_response)

        metadata = response.get("metadata", {})
        response["metadata"] = cls._get_simple_metadata(metadata)

        return response

    @classmethod
    def convert_stream_full_response(
        cls,
        stream_response: Generator[CompletionAppStreamResponse, None, None],  # type: ignore[override]
    ) -> Generator[str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response
        :return: generator of JSON strings (the literal "ping" for keep-alives)
        """
        for chunk in stream_response:
            chunk = cast(CompletionAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield json.dumps(response_chunk)

    @classmethod
    def convert_stream_simple_response(
        cls,
        stream_response: Generator[CompletionAppStreamResponse, None, None],  # type: ignore[override]
    ) -> Generator[str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response
        :return: generator of JSON strings (the literal "ping" for keep-alives)
        """
        for chunk in stream_response:
            chunk = cast(CompletionAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.to_dict()
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
            # BUG FIX: this was a separate `if`, so the trailing `else` ran for
            # MessageEndStreamResponse too and re-applied the full to_dict(),
            # clobbering the simplified metadata written above. Chain with elif.
            elif isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.to_dict())

            yield json.dumps(response_chunk)
class MessageBasedAppGenerator(BaseAppGenerator):
    """Shared generation logic for message-based apps (chat, completion, agent-chat, advanced chat)."""

    def _handle_response(
        self,
        application_generate_entity: Union[
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            # BUG FIX: this union previously listed AgentChatAppGenerateEntity
            # twice; the second member should be AdvancedChatAppGenerateEntity,
            # consistent with _init_generate_records below.
            AdvancedChatAppGenerateEntity,
        ],
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
        user: Union[Account, EndUser],
        stream: bool = False,
    ) -> Union[
        ChatbotAppBlockingResponse,
        CompletionAppBlockingResponse,
        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
    ]:
        """
        Handle response.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        :param user: user
        :param stream: is stream
        :return: blocking response, or a stream response generator
        """
        # init generate task pipeline
        generate_task_pipeline = EasyUIBasedGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            conversation=conversation,
            message=message,
            stream=stream,
        )

        try:
            return generate_task_pipeline.process()
        except ValueError as e:
            if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise GenerateTaskStoppedError()
            else:
                # lazy %-style args: avoids formatting when the record is filtered out
                logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
                raise e

    def _get_conversation_by_user(
        self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
    ) -> Conversation:
        """Load a live (non-deleted, "normal") conversation owned by *user*.

        :raises ConversationNotExistsError: no matching conversation
        :raises ConversationCompletedError: conversation exists but is no longer "normal"
        """
        conversation_filter = [
            Conversation.id == conversation_id,
            Conversation.app_id == app_model.id,
            Conversation.status == "normal",
            Conversation.is_deleted.is_(False),
        ]

        if isinstance(user, Account):
            conversation_filter.append(Conversation.from_account_id == user.id)
        else:
            # BUG FIX: was `... == user.id if user else None`, which appended None
            # to the filter list when user is falsy and would break and_();
            # user is always provided on this path, so compare directly.
            conversation_filter.append(Conversation.from_end_user_id == user.id)

        conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()

        if not conversation:
            raise ConversationNotExistsError()

        if conversation.status != "normal":
            raise ConversationCompletedError()

        return conversation

    def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
        """Resolve the effective app model config.

        With a conversation, the config pinned on that conversation is used;
        otherwise the app's current config.

        :raises AppModelConfigBrokenError: referenced config is missing
        """
        if conversation:
            app_model_config = (
                db.session.query(AppModelConfig)
                .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
                .first()
            )

            if not app_model_config:
                raise AppModelConfigBrokenError()
        else:
            if app_model.app_model_config_id is None:
                raise AppModelConfigBrokenError()

            app_model_config = app_model.app_model_config

            if not app_model_config:
                raise AppModelConfigBrokenError()

        return app_model_config

    def _init_generate_records(
        self,
        application_generate_entity: Union[
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity,
        ],
        conversation: Optional[Conversation] = None,
    ) -> tuple[Conversation, Message]:
        """
        Initialize generate records
        :param application_generate_entity: application generate entity
        :param conversation: existing conversation to append to, or None to create one
        :return: (conversation, message)
        """
        app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)

        # get from source
        end_user_id = None
        account_id = None
        if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            from_source = "api"
            end_user_id = application_generate_entity.user_id
        else:
            from_source = "console"
            account_id = application_generate_entity.user_id

        if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
            # advanced-chat runs off a workflow, not an easy-UI model config
            app_model_config_id = None
            override_model_configs = None
            model_provider = None
            model_id = None
        else:
            app_model_config_id = app_config.app_model_config_id
            model_provider = application_generate_entity.model_conf.provider
            model_id = application_generate_entity.model_conf.model
            override_model_configs = None
            # persist ad-hoc (debugger-supplied) config overrides with the records
            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
                AppMode.AGENT_CHAT,
                AppMode.CHAT,
                AppMode.COMPLETION,
            }:
                override_model_configs = app_config.app_model_config_dict

        # get conversation introduction
        introduction = self._get_conversation_introduction(application_generate_entity)

        if not conversation:
            conversation = Conversation(
                app_id=app_config.app_id,
                app_model_config_id=app_model_config_id,
                model_provider=model_provider,
                model_id=model_id,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                mode=app_config.app_mode.value,
                name="New conversation",
                inputs=application_generate_entity.inputs,
                introduction=introduction,
                system_instruction="",
                system_instruction_tokens=0,
                status="normal",
                invoke_from=application_generate_entity.invoke_from.value,
                from_source=from_source,
                from_end_user_id=end_user_id,
                from_account_id=account_id,
            )

            db.session.add(conversation)
            db.session.commit()
            db.session.refresh(conversation)
        else:
            # DB columns are naive datetimes; strip tzinfo after taking UTC now
            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
            db.session.commit()

        message = Message(
            app_id=app_config.app_id,
            model_provider=model_provider,
            model_id=model_id,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
            conversation_id=conversation.id,
            inputs=application_generate_entity.inputs,
            query=application_generate_entity.query or "",
            message="",
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
            provider_response_latency=0,
            total_price=0,
            currency="USD",
            invoke_from=application_generate_entity.invoke_from.value,
            from_source=from_source,
            from_end_user_id=end_user_id,
            from_account_id=account_id,
        )

        db.session.add(message)
        db.session.commit()
        db.session.refresh(message)

        for file in application_generate_entity.files:
            message_file = MessageFile(
                message_id=message.id,
                type=file.type,
                transfer_method=file.transfer_method,
                belongs_to="user",
                url=file.remote_url,
                upload_file_id=file.related_id,
                created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
                created_by=account_id or end_user_id or "",
            )
            db.session.add(message_file)
            db.session.commit()

        return conversation, message

    def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
        """
        Get conversation introduction
        :param application_generate_entity: application generate entity
        :return: conversation introduction (template-rendered opening statement, or "")
        """
        app_config = application_generate_entity.app_config
        introduction = app_config.additional_features.opening_statement

        if introduction:
            try:
                inputs = application_generate_entity.inputs
                prompt_template = PromptTemplateParser(template=introduction)
                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
                introduction = prompt_template.format(prompt_inputs)
            except KeyError:
                # best-effort: fall back to the raw template on missing keys
                pass

        return introduction or ""

    def _get_conversation(self, conversation_id: str):
        """
        Get conversation by conversation id
        :param conversation_id: conversation id
        :return: conversation
        :raises ConversationNotExistsError: no conversation with this id
        """
        conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()

        if not conversation:
            raise ConversationNotExistsError()

        return conversation

    def _get_message(self, message_id: str) -> Optional[Message]:
        """
        Get message by message id
        :param message_id: message id
        :return: message, or None when not found
        """
        message = db.session.query(Message).filter(Message.id == message_id).first()

        return message
class MessageBasedAppQueueManager(AppQueueManager):
    """Queue manager for message-based apps; wraps events with message/conversation context."""

    def __init__(
        self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
    ) -> None:
        super().__init__(task_id, user_id, invoke_from)

        self._conversation_id = str(conversation_id)
        self._app_mode = app_mode
        self._message_id = str(message_id)

    def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
        """Wrap *event* into a MessageQueueMessage carrying this manager's context."""
        return MessageQueueMessage(
            task_id=self._task_id,
            message_id=self._message_id,
            conversation_id=self._conversation_id,
            app_mode=self._app_mode,
            event=event,
        )

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue
        :param event: event to publish
        :param pub_from: publisher side
        :return:
        """
        self._q.put(self.construct_queue_message(event))

        # terminal events end the listen loop
        terminal_events = (QueueStopEvent, QueueErrorEvent, QueueMessageEndEvent, QueueAdvancedChatMessageEndEvent)
        if isinstance(event, terminal_events):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise GenerateTaskStoppedError()
class WorkflowAppConfig(WorkflowUIBasedAppConfig):
    """
    Workflow App Config Entity.
    """

    pass


class WorkflowAppConfigManager(BaseAppConfigManager):
    @classmethod
    def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
        """Build the config entity for a workflow app from its workflow's features."""
        features = workflow.features_dict
        mode = AppMode.value_of(app_model.mode)

        return WorkflowAppConfig(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_mode=mode,
            workflow_id=workflow.id,
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features),
            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
            additional_features=cls.convert_features(features, mode),
        )

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
        """
        Validate for workflow app model config

        :param tenant_id: tenant id
        :param config: app model config args
        :param only_structure_validate: only validate the structure of the config
        :return: config filtered down to the validated, related keys
        """
        related_keys: list = []

        # file upload validation
        config, keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
        related_keys.extend(keys)

        # text_to_speech
        config, keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
        related_keys.extend(keys)

        # moderation validation
        config, keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
        )
        related_keys.extend(keys)

        # Filter out extra parameters
        unique_keys = set(related_keys)
        return {key: config.get(key) for key in unique_keys}
class WorkflowAppGenerator(BaseAppGenerator):
    # Orchestrates a workflow-app run: builds the generate entity, spawns a
    # worker thread that executes the workflow, and streams/blocks on results.

    @overload
    def generate(
        self,
        *,
        app_model: App,
        workflow: Workflow,
        user: Account | EndUser,
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Generator[str, None, None]: ...

    @overload
    def generate(
        self,
        *,
        app_model: App,
        workflow: Workflow,
        user: Account | EndUser,
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        *,
        app_model: App,
        workflow: Workflow,
        user: Account | EndUser,
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Mapping[str, Any] | Generator[str, None, None]: ...

    def generate(
        self,
        *,
        app_model: App,
        workflow: Workflow,
        user: Account | EndUser,
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
    ):
        """Generate a workflow app response.

        :param app_model: App
        :param workflow: Workflow to execute
        :param user: account or end user
        :param args: request args; reads "inputs" (required) and "files"
        :param invoke_from: invoke from source
        :param streaming: stream response or return a blocking mapping
        :param call_depth: nested-call depth (e.g. workflow-as-tool)
        :param workflow_thread_pool_id: optional shared thread pool id
        """
        files: Sequence[Mapping[str, Any]] = args.get("files") or []

        # parse files (vision disabled for workflow system files)
        file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
        system_files = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
        )

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow,
        )

        # get tracing instance; EndUser has no account id, use its session id
        trace_manager = TraceQueueManager(
            app_id=app_model.id,
            user_id=user.id if isinstance(user, Account) else user.session_id,
        )

        inputs: Mapping[str, Any] = args["inputs"]
        workflow_run_id = str(uuid.uuid4())
        # init application generate entity
        application_generate_entity = WorkflowAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            file_upload_config=file_extra_config,
            inputs=self._prepare_user_inputs(
                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
            ),
            files=list(system_files),
            user_id=user.id,
            stream=streaming,
            invoke_from=invoke_from,
            call_depth=call_depth,
            trace_manager=trace_manager,
            workflow_run_id=workflow_run_id,
        )
        # tenant id is read downstream via the contextvar, not passed explicitly
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

        return self._generate(
            app_model=app_model,
            workflow=workflow,
            user=user,
            application_generate_entity=application_generate_entity,
            invoke_from=invoke_from,
            streaming=streaming,
            workflow_thread_pool_id=workflow_thread_pool_id,
        )

    def _generate(
        self,
        *,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        application_generate_entity: WorkflowAppGenerateEntity,
        invoke_from: InvokeFrom,
        streaming: bool = True,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Mapping[str, Any] | Generator[str, None, None]:
        """Run the workflow in a worker thread and convert the pipeline output.

        :param application_generate_entity: fully-built generate entity
        :return: converted blocking mapping or stream generator
        """
        # init queue manager (bridge between worker thread and response pipeline)
        queue_manager = WorkflowAppQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            app_mode=app_model.mode,
        )

        # new thread; pass the real Flask app object and a copy of the current
        # contextvars so the worker sees the same context (e.g. tenant_id)
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "context": contextvars.copy_context(),
                "workflow_thread_pool_id": workflow_thread_pool_id,
            },
        )

        worker_thread.start()

        # return response or stream generator (consumes the queue)
        response = self._handle_response(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            stream=streaming,
        )

        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def single_iteration_generate(
        self,
        app_model: App,
        workflow: Workflow,
        node_id: str,
        user: Account,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str, None, None]:
        """Run a single iteration node of the workflow (debugger only).

        :param app_model: App
        :param workflow: Workflow
        :param node_id: id of the iteration node to run
        :param user: account (console user)
        :param args: request args; requires "inputs"
        :param streaming: is stream
        :raises ValueError: when node_id or args["inputs"] is missing
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # init application generate entity; top-level inputs/files are empty,
        # the node inputs travel in single_iteration_run
        application_generate_entity = WorkflowAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={"auto_generate_conversation_name": False},
            single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
                node_id=node_id, inputs=args["inputs"]
            ),
            workflow_run_id=str(uuid.uuid4()),
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

        return self._generate(
            app_model=app_model,
            workflow=workflow,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            streaming=streaming,
        )

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: WorkflowAppGenerateEntity,
        queue_manager: AppQueueManager,
        context: contextvars.Context,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param context: contextvars snapshot copied from the requesting thread
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """
        # replay the captured contextvars into this thread
        for var, val in context.items():
            var.set(val)
        with flask_app.app_context():
            try:
                # workflow app
                runner = WorkflowAppRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )

                runner.run()
            except GenerateTaskStoppedError:
                # deliberate stop: nothing to report
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                # boundary handler: forward any error to the response pipeline
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                db.session.close()

    def _handle_response(
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        stream: bool = False,
    ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
        Handle response.
        :param application_generate_entity: application generate entity
        :param workflow: workflow
        :param queue_manager: queue manager
        :param user: account or end user
        :param stream: is stream
        :return:
        """
        # init generate task pipeline
        generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            stream=stream,
        )

        try:
            return generate_task_pipeline.process()
        except ValueError as e:
            if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise GenerateTaskStoppedError()
            else:
                logger.exception(
                    f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
                )
                raise e
class WorkflowAppRunner(WorkflowBasedAppRunner):
    """
    Workflow Application Runner

    Loads the app and workflow records, builds the variable pool and graph,
    then drives a WorkflowEntry run, forwarding every event to the queue.
    """

    def __init__(
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        queue_manager: AppQueueManager,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        """
        self.application_generate_entity = application_generate_entity
        self.queue_manager = queue_manager
        self.workflow_thread_pool_id = workflow_thread_pool_id

    def run(self) -> None:
        """
        Run application workflow.

        :raises ValueError: when the app or its workflow cannot be loaded
        """
        app_config = self.application_generate_entity.app_config
        app_config = cast(WorkflowAppConfig, app_config)

        # For API/web-app calls, user_id on the entity is an EndUser row id;
        # resolve it to the end user's session id for the variable pool.
        user_id = None
        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = self.application_generate_entity.user_id

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError("Workflow not initialized")

        # release the DB connection before the (potentially long) workflow run
        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = []
        if dify_config.DEBUG:
            workflow_callbacks.append(WorkflowLoggingCallback())

        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
            )
        else:
            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

            # Create a variable pool.
            system_inputs = {
                SystemVariableKey.FILES: files,
                SystemVariableKey.USER_ID: user_id,
                SystemVariableKey.APP_ID: app_config.app_id,
                SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
                SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
            }

            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=inputs,
                environment_variables=workflow.environment_variables,
                conversation_variables=[],
            )

            # init graph
            graph = self._init_graph(graph_config=workflow.graph_dict)

        # RUN WORKFLOW
        workflow_entry = WorkflowEntry(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            workflow_type=WorkflowType.value_of(workflow.type),
            graph=graph,
            graph_config=workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
            thread_pool_id=self.workflow_thread_pool_id,
        )

        generator = workflow_entry.run(callbacks=workflow_callbacks)

        # forward every workflow event to the queue manager
        for event in generator:
            self._handle_event(workflow_entry, event)
WorkflowAppBlockingResponse,
+    WorkflowAppStreamResponse,
+)
+
+
+class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
+    # Binds the base converter to the workflow-app blocking response type.
+    _blocking_response_type = WorkflowAppBlockingResponse
+
+    @classmethod
+    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
+        """
+        Convert blocking full response.
+        :param blocking_response: blocking response
+        :return:
+        """
+        return dict(blocking_response.to_dict())
+
+    @classmethod
+    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
+        """
+        Convert blocking simple response.
+
+        For workflow apps the "simple" blocking shape is identical to the full one.
+        :param blocking_response: blocking response
+        :return:
+        """
+        return cls.convert_blocking_full_response(blocking_response)
+
+    @classmethod
+    def convert_stream_full_response(
+        cls,
+        stream_response: Generator[WorkflowAppStreamResponse, None, None],  # type: ignore[override]
+    ) -> Generator[str, None, None]:
+        """
+        Convert stream full response.
+
+        Yields each chunk as a JSON string tagged with its event name and
+        workflow_run_id; errors are expanded via _error_to_stream_response.
+        :param stream_response: stream response
+        :return:
+        """
+        for chunk in stream_response:
+            chunk = cast(WorkflowAppStreamResponse, chunk)
+            sub_stream_response = chunk.stream_response
+
+            if isinstance(sub_stream_response, PingStreamResponse):
+                # bare "ping" marker (presumably rendered upstream as a keep-alive — confirm against caller)
+                yield "ping"
+                continue
+
+            response_chunk = {
+                "event": sub_stream_response.event.value,
+                "workflow_run_id": chunk.workflow_run_id,
+            }
+
+            if isinstance(sub_stream_response, ErrorStreamResponse):
+                data = cls._error_to_stream_response(sub_stream_response.err)
+                response_chunk.update(data)
+            else:
+                response_chunk.update(sub_stream_response.to_dict())
+            yield json.dumps(response_chunk)
+
+    @classmethod
+    def convert_stream_simple_response(
+        cls,
+        stream_response: Generator[WorkflowAppStreamResponse, None, None],  # type: ignore[override]
+    ) -> Generator[str, None, None]:
+        """
+        Convert stream simple response.
+
+        Same as the full converter except node start/finish chunks are reduced
+        via to_ignore_detail_dict() to omit heavy node detail.
+        :param stream_response: stream response
+        :return:
+        """
+        for chunk in stream_response:
+            chunk = cast(WorkflowAppStreamResponse, chunk)
+            sub_stream_response = chunk.stream_response
+
+            if isinstance(sub_stream_response, PingStreamResponse):
+                yield "ping"
+                continue
+
+            response_chunk = {
+                "event": sub_stream_response.event.value,
+                "workflow_run_id": chunk.workflow_run_id,
+            }
+
+            if isinstance(sub_stream_response, ErrorStreamResponse):
+                data = cls._error_to_stream_response(sub_stream_response.err)
+                response_chunk.update(data)
+            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+            else:
+                response_chunk.update(sub_stream_response.to_dict())
+            yield json.dumps(response_chunk)
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..f89f456916e72e83fc23acc9759bc82504eefdcd
--- /dev/null
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -0,0 +1,576 @@
+import logging
+import time
+from collections.abc import Generator
+from typing import Optional, Union
+
+from sqlalchemy.orm import Session
+
+from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
+from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.entities.app_invoke_entities import (
+    InvokeFrom,
+    WorkflowAppGenerateEntity,
+)
+from core.app.entities.queue_entities import (
+    QueueErrorEvent,
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
+    QueueNodeExceptionEvent,
+    QueueNodeFailedEvent,
+    QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
+    QueueNodeStartedEvent,
+    QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+
QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
+    QueuePingEvent,
+    QueueStopEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFailedEvent,
+    QueueWorkflowPartialSuccessEvent,
+    QueueWorkflowStartedEvent,
+    QueueWorkflowSucceededEvent,
+)
+from core.app.entities.task_entities import (
+    ErrorStreamResponse,
+    MessageAudioEndStreamResponse,
+    MessageAudioStreamResponse,
+    StreamResponse,
+    TextChunkStreamResponse,
+    WorkflowAppBlockingResponse,
+    WorkflowAppStreamResponse,
+    WorkflowFinishStreamResponse,
+    WorkflowStartStreamResponse,
+    WorkflowTaskState,
+)
+from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
+from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
+from core.ops.ops_trace_manager import TraceQueueManager
+from core.workflow.enums import SystemVariableKey
+from extensions.ext_database import db
+from models.account import Account
+from models.enums import CreatedByRole
+from models.model import EndUser
+from models.workflow import (
+    Workflow,
+    WorkflowAppLog,
+    WorkflowAppLogCreatedFrom,
+    WorkflowRun,
+    WorkflowRunStatus,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowAppGenerateTaskPipeline:
+    """
+    WorkflowAppGenerateTaskPipeline is a class that generates stream output and state management for an Application.
+    """
+
+    def __init__(
+        self,
+        application_generate_entity: WorkflowAppGenerateEntity,
+        workflow: Workflow,
+        queue_manager: AppQueueManager,
+        user: Union[Account, EndUser],
+        stream: bool,
+    ) -> None:
+        self._base_task_pipeline = BasedGenerateTaskPipeline(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            stream=stream,
+        )
+
+        # Resolve the acting identity: EndUsers record their session_id as the
+        # workflow-visible user id; Accounts use their account id directly.
+        if isinstance(user, EndUser):
+            self._user_id = user.id
+            user_session_id = user.session_id
+            self._created_by_role = CreatedByRole.END_USER
+        elif isinstance(user, Account):
+            self._user_id = user.id
+            user_session_id = user.id
+            self._created_by_role = CreatedByRole.ACCOUNT
+        else:
+            raise ValueError(f"Invalid user type: {type(user)}")
+
+        self._workflow_cycle_manager = WorkflowCycleManage(
+            application_generate_entity=application_generate_entity,
+            workflow_system_variables={
+                SystemVariableKey.FILES: application_generate_entity.files,
+                SystemVariableKey.USER_ID: user_session_id,
+                SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+                SystemVariableKey.WORKFLOW_ID: workflow.id,
+                SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+            },
+        )
+
+        self._application_generate_entity = application_generate_entity
+        self._workflow_id = workflow.id
+        self._workflow_features_dict = workflow.features_dict
+        self._task_state = WorkflowTaskState()
+        # Set when the QueueWorkflowStartedEvent is processed; "" means not started yet.
+        self._workflow_run_id = ""
+
+    def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
+        """
+        Process generate task pipeline.
+        :return: a blocking response, or a stream-response generator, per the stream flag
+        """
+        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
+        if self._base_task_pipeline._stream:
+            return self._to_stream_response(generator)
+        else:
+            return self._to_blocking_response(generator)
+
+    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
+        """
+        To blocking response.
+
+        Drains the stream until the workflow-finish response, re-raising any
+        streamed error; all other chunk types are ignored.
+        :return:
+        """
+        for stream_response in generator:
+            if isinstance(stream_response, ErrorStreamResponse):
+                raise stream_response.err
+            elif isinstance(stream_response, WorkflowFinishStreamResponse):
+                response = WorkflowAppBlockingResponse(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run_id=stream_response.data.id,
+                    data=WorkflowAppBlockingResponse.Data(
+                        id=stream_response.data.id,
+                        workflow_id=stream_response.data.workflow_id,
+                        status=stream_response.data.status,
+                        outputs=stream_response.data.outputs,
+                        error=stream_response.data.error,
+                        elapsed_time=stream_response.data.elapsed_time,
+                        total_tokens=stream_response.data.total_tokens,
+                        total_steps=stream_response.data.total_steps,
+                        created_at=int(stream_response.data.created_at),
+                        finished_at=int(stream_response.data.finished_at),
+                    ),
+                )
+
+                return response
+            else:
+                continue
+
+        raise ValueError("queue listening stopped unexpectedly.")
+
+    def _to_stream_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Generator[WorkflowAppStreamResponse, None, None]:
+        """
+        To stream response.
+
+        Wraps each chunk with the workflow_run_id captured from the start chunk
+        (None until a WorkflowStartStreamResponse is seen).
+        :return:
+        """
+        workflow_run_id = None
+        for stream_response in generator:
+            if isinstance(stream_response, WorkflowStartStreamResponse):
+                workflow_run_id = stream_response.workflow_run_id
+
+            yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
+
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
+        # Poll the TTS publisher once; returns an audio chunk response or None
+        # (None also when the publisher reports the "finish" sentinel).
+        if not publisher:
+            return None
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
+            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
+        return None
+
+    def _wrapper_process_stream_response(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
+        # Interleaves TTS audio chunks (when text-to-speech auto-play is enabled)
+        # with the regular stream responses, then drains remaining audio at the end.
+        tts_publisher = None
+        task_id = self._application_generate_entity.task_id
+        tenant_id = self._application_generate_entity.app_config.tenant_id
+        features_dict = self._workflow_features_dict
+
+        if (
+            features_dict.get("text_to_speech")
+            and features_dict["text_to_speech"].get("enabled")
+            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
+        ):
+            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
+
+        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
+            while True:
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
+                if audio_response:
+                    yield audio_response
+                else:
+                    break
+            yield response
+
+        # After the main stream ends, keep flushing audio until "finish" or timeout.
+        start_listener_time = time.time()
+        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
+            try:
+                if not tts_publisher:
+                    break
+                audio_trunk = tts_publisher.check_and_get_audio()
+                if audio_trunk is None:
+                    # release cpu
+                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
+                    time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
+                    continue
+                if audio_trunk.status == "finish":
+                    break
+                else:
+                    yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
+            except Exception:
+                # NOTE(review): prefer lazy %-formatting over an f-string in logger calls.
+                logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
+                break
+        if tts_publisher:
+            yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
+
+    def _process_stream_response(
+        self,
+        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> Generator[StreamResponse, None, None]:
+        """
+        Process stream response.
+
+        Main event loop: consumes queue events, persists workflow/node state in
+        short-lived DB sessions, and yields the corresponding stream responses.
+        :return:
+        """
+        graph_runtime_state = None
+
+        for queue_message in self._base_task_pipeline._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueuePingEvent):
+                yield self._base_task_pipeline._ping_stream_response()
+            elif isinstance(event, QueueErrorEvent):
+                err = self._base_task_pipeline._handle_error(event=event)
+                yield self._base_task_pipeline._error_to_stream_response(err)
+                break
+            elif isinstance(event, QueueWorkflowStartedEvent):
+                # override graph runtime state
+                graph_runtime_state = event.graph_runtime_state
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    # init workflow run
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
+                        session=session,
+                        workflow_id=self._workflow_id,
+                        user_id=self._user_id,
+                        created_by_role=self._created_by_role,
+                    )
+                    self._workflow_run_id = workflow_run.id
+                    start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
+
+                yield start_resp
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
+                        session=session, workflow_run=workflow_run, event=event
+                    )
+                    response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
+
+                if response:
+                    yield response
+            elif isinstance(event, QueueNodeStartedEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
+                        session=session, workflow_run=workflow_run, event=event
+                    )
+                    node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
+
+                if node_start_response:
+                    yield node_start_response
+            elif isinstance(event, QueueNodeSucceededEvent):
+                # NOTE(review): unlike the other node branches, this one does not guard
+                # on self._workflow_run_id — confirm this asymmetry is intentional.
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+                        session=session, event=event
+                    )
+                    node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
+
+                if node_success_response:
+                    yield node_success_response
+            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+                        session=session,
+                        event=event,
+                    )
+                    node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
+
+                if node_failed_response:
+                    yield node_failed_response
+
+            elif isinstance(event, QueueParallelBranchRunStartedEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_start_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
+                    )
+
+                yield parallel_start_resp
+
+            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_finish_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
+                    )
+
+                yield parallel_finish_resp
+
+            elif isinstance(event, QueueIterationStartEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_start_resp
+
+            elif isinstance(event, QueueIterationNextEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_next_resp
+
+            elif isinstance(event, QueueIterationCompletedEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_finish_resp
+
+            elif isinstance(event, QueueWorkflowSucceededEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+                if not graph_runtime_state:
+                    raise ValueError("graph runtime state not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
+                        session=session,
+                        workflow_run_id=self._workflow_run_id,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                    )
+
+                    # save workflow app log
+                    self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                    )
+                    session.commit()
+
+                yield workflow_finish_resp
+            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+                if not graph_runtime_state:
+                    raise ValueError("graph runtime state not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
+                        session=session,
+                        workflow_run_id=self._workflow_run_id,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        exceptions_count=event.exceptions_count,
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                    )
+
+                    # save workflow app log
+                    self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
+
+                yield workflow_finish_resp
+            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
+                if not self._workflow_run_id:
+                    raise ValueError("workflow run not initialized.")
+                if not graph_runtime_state:
+                    raise ValueError("graph runtime state not initialized.")
+
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
+                        session=session,
+                        workflow_run_id=self._workflow_run_id,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        status=WorkflowRunStatus.FAILED
+                        if isinstance(event, QueueWorkflowFailedEvent)
+                        else WorkflowRunStatus.STOPPED,
+                        error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                        exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+                    )
+
+                    # save workflow app log
+                    self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
+
+                yield workflow_finish_resp
+            elif isinstance(event, QueueTextChunkEvent):
+                delta_text = event.text
+                if delta_text is None:
+                    continue
+
+                # only publish tts message at text chunk streaming
+                if tts_publisher:
+                    tts_publisher.publish(queue_message)
+
+                self._task_state.answer += delta_text
+                yield self._text_chunk_to_stream_response(
+                    delta_text, from_variable_selector=event.from_variable_selector
+                )
+            else:
+                continue
+
+        # None signals end-of-stream to the TTS publisher.
+        if tts_publisher:
+            tts_publisher.publish(None)
+
+    def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
+        """
+        Save workflow app log.
+
+        Debugger invocations are intentionally not logged.
+        :return:
+        """
+        invoke_from = self._application_generate_entity.invoke_from
+        if invoke_from == InvokeFrom.SERVICE_API:
+            created_from = WorkflowAppLogCreatedFrom.SERVICE_API
+        elif invoke_from == InvokeFrom.EXPLORE:
+            created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
+        elif invoke_from == InvokeFrom.WEB_APP:
+            created_from = WorkflowAppLogCreatedFrom.WEB_APP
+        else:
+            # not save log for debugging
+            return
+
+        workflow_app_log = WorkflowAppLog()
+        workflow_app_log.tenant_id = workflow_run.tenant_id
+        workflow_app_log.app_id = workflow_run.app_id
+        workflow_app_log.workflow_id = workflow_run.workflow_id
+        workflow_app_log.workflow_run_id = workflow_run.id
+        workflow_app_log.created_from = created_from.value
+        workflow_app_log.created_by_role = self._created_by_role
+        workflow_app_log.created_by = self._user_id
+
+        session.add(workflow_app_log)
+
+    def _text_chunk_to_stream_response(
+        self, text: str, from_variable_selector: Optional[list[str]] = None
+    ) -> TextChunkStreamResponse:
+        """
+        Build a text-chunk stream response for the given delta text.
+        :param text: text
+        :return:
+        """
+        response = TextChunkStreamResponse(
+            task_id=self._application_generate_entity.task_id,
+            data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
+        )
+
+        return response
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f516bcc6068284846820c4b3543d2c17b1dc63
--- /dev/null
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -0,0 +1,481 @@
+from collections.abc import Mapping
+from typing import Any, Optional, cast
+
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.base_app_runner import AppRunner
+from core.app.entities.queue_entities import (
+    AppQueueEvent,
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
+    QueueNodeExceptionEvent,
+    QueueNodeFailedEvent,
+    QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
+    QueueNodeStartedEvent,
+    QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
+    QueueRetrieverResourcesEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFailedEvent,
+    QueueWorkflowPartialSuccessEvent,
+    QueueWorkflowStartedEvent,
+    QueueWorkflowSucceededEvent,
+)
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    IterationRunFailedEvent,
+    IterationRunNextEvent,
+    IterationRunStartedEvent,
+    IterationRunSucceededEvent,
+    NodeInIterationFailedEvent,
+    NodeRunExceptionEvent,
+    NodeRunFailedEvent,
+    NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+
ParallelBranchRunFailedEvent,
+    ParallelBranchRunStartedEvent,
+    ParallelBranchRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes import NodeType
+from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.workflow_entry import WorkflowEntry
+from extensions.ext_database import db
+from models.model import App
+from models.workflow import Workflow
+
+
+class WorkflowBasedAppRunner(AppRunner):
+    def __init__(self, queue_manager: AppQueueManager):
+        self.queue_manager = queue_manager
+
+    def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
+        """
+        Init graph
+
+        Validates that the config carries list-typed "nodes" and "edges",
+        then builds the Graph; raises ValueError on any malformed input.
+        """
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
+
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
+
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
+        # init graph
+        graph = Graph.init(graph_config=graph_config)
+
+        if not graph:
+            raise ValueError("graph not found in workflow")
+
+        return graph
+
+    def _get_graph_and_variable_pool_of_single_iteration(
+        self,
+        workflow: Workflow,
+        node_id: str,
+        user_inputs: dict,
+    ) -> tuple[Graph, VariablePool]:
+        """
+        Get variable pool of single iteration
+
+        Restricts the workflow graph to the given iteration node and its
+        children, then maps user inputs into a fresh variable pool.
+        """
+        # fetch workflow graph
+        graph_config = workflow.graph_dict
+        if not graph_config:
+            raise ValueError("workflow graph not found")
+
+        graph_config = cast(dict[str, Any], graph_config)
+
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
+
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
+
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
+
+        # filter nodes only in iteration
+        node_configs = [
+            node
+            for node in graph_config.get("nodes", [])
+            if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
+        ]
+
+        # NOTE(review): this mutates graph_config in place; safe only if
+        # workflow.graph_dict returns a fresh copy per call — confirm.
+        graph_config["nodes"] = node_configs
+
+        node_ids = [node.get("id") for node in node_configs]
+
+        # filter edges only in iteration
+        edge_configs = [
+            edge
+            for edge in graph_config.get("edges", [])
+            if (edge.get("source") is None or edge.get("source") in node_ids)
+            and (edge.get("target") is None or edge.get("target") in node_ids)
+        ]
+
+        graph_config["edges"] = edge_configs
+
+        # init graph
+        graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
+
+        if not graph:
+            raise ValueError("graph not found in workflow")
+
+        # fetch node config from node id
+        iteration_node_config = None
+        for node in node_configs:
+            if node.get("id") == node_id:
+                iteration_node_config = node
+                break
+
+        if not iteration_node_config:
+            raise ValueError("iteration node id not found in workflow graph")
+
+        # Get node class
+        node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
+        node_version = iteration_node_config.get("data", {}).get("version", "1")
+        node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
+
+        # init variable pool
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+            environment_variables=workflow.environment_variables,
+        )
+
+        try:
+            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                graph_config=workflow.graph_dict, config=iteration_node_config
+            )
+        except NotImplementedError:
+            # node type does not declare variable selectors; proceed with no mapping
+            variable_mapping = {}
+
+        WorkflowEntry.mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            user_inputs=user_inputs,
+            variable_pool=variable_pool,
+            tenant_id=workflow.tenant_id,
+        )
+
+        return graph, variable_pool
+
+    def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
+        """
+        Handle event
+
+        Translates each graph engine event into its queue-event counterpart
+        and publishes it.
+        :param workflow_entry: workflow entry
+        :param event: event
+        """
+        if isinstance(event, GraphRunStartedEvent):
+            self._publish_event(
+                QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
+            )
+        elif isinstance(event, GraphRunSucceededEvent):
+            self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
+        elif isinstance(event, GraphRunPartialSucceededEvent):
+            self._publish_event(
+                QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
+            )
+        elif isinstance(event, GraphRunFailedEvent):
+            self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
+        elif isinstance(event, NodeRunRetryEvent):
+            node_run_result = event.route_node_state.node_run_result
+            inputs: Mapping[str, Any] | None = {}
+            process_data: Mapping[str, Any] | None = {}
+            outputs: Mapping[str, Any] | None = {}
+            execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            self._publish_event(
+                QueueNodeRetryEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    node_run_index=event.route_node_state.index,
+                    predecessor_node_id=event.predecessor_node_id,
+                    in_iteration_id=event.in_iteration_id,
+                    parallel_mode_run_id=event.parallel_mode_run_id,
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    error=event.error,
+                    execution_metadata=execution_metadata,
+                    retry_index=event.retry_index,
+                )
+            )
+        elif isinstance(event, NodeRunStartedEvent):
+            self._publish_event(
+                QueueNodeStartedEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    node_run_index=event.route_node_state.index,
+                    predecessor_node_id=event.predecessor_node_id,
+                    in_iteration_id=event.in_iteration_id,
+                    parallel_mode_run_id=event.parallel_mode_run_id,
+                )
+            )
+        elif isinstance(event, NodeRunSucceededEvent):
+            node_run_result = event.route_node_state.node_run_result
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            else:
+                inputs = {}
+                process_data = {}
+                outputs = {}
+                execution_metadata = {}
+            self._publish_event(
+                QueueNodeSucceededEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    execution_metadata=execution_metadata,
+                    in_iteration_id=event.in_iteration_id,
+                )
+            )
+        elif isinstance(event, NodeRunFailedEvent):
+            self._publish_event(
+                QueueNodeFailedEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result
+                    else {},
+                    outputs=event.route_node_state.node_run_result.outputs or {}
+                    if event.route_node_state.node_run_result
+                    else {},
+                    error=event.route_node_state.node_run_result.error
+                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
+                    else "Unknown error",
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
+                )
+            )
+        elif isinstance(event, NodeRunExceptionEvent):
+            self._publish_event(
+                QueueNodeExceptionEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result
+                    else {},
+                    outputs=event.route_node_state.node_run_result.outputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    error=event.route_node_state.node_run_result.error
+                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
+                    else "Unknown error",
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
+                )
+            )
+        elif isinstance(event, NodeInIterationFailedEvent):
+            self._publish_event(
+                QueueNodeInIterationFailedEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+
parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs or {} + if event.route_node_state.node_run_result + else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self._publish_event( + QueueTextChunkEvent( + text=event.chunk_content, + from_variable_selector=event.from_variable_selector, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunRetrieverResourceEvent): + self._publish_event( + QueueRetrieverResourcesEvent( + retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, ParallelBranchRunSucceededEvent): + self._publish_event( + QueueParallelBranchRunSucceededEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, ParallelBranchRunFailedEvent): + 
self._publish_event( + QueueParallelBranchRunFailedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) + elif isinstance(event, IterationRunStartedEvent): + self._publish_event( + QueueIterationStartEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id, + metadata=event.metadata, + ) + ) + elif isinstance(event, IterationRunNextEvent): + self._publish_event( + QueueIterationNextEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + index=event.index, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + output=event.pre_iteration_output, + parallel_mode_run_id=event.parallel_mode_run_id, + duration=event.duration, + ) + ) + elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): + self._publish_event( + QueueIterationCompletedEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + 
parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + outputs=event.outputs, + metadata=event.metadata, + steps=event.steps, + error=event.error if isinstance(event, IterationRunFailedEvent) else None, + ) + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow + + def _publish_event(self, event: AppQueueEvent) -> None: + self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/entities/__init__.py b/api/core/app/entities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb4e590326257628a01905ee699c6a25792608c --- /dev/null +++ b/api/core/app/entities/app_invoke_entities.py @@ -0,0 +1,208 @@ +from collections.abc import Mapping, Sequence +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator + +from constants import UUID_NIL +from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.entities.provider_configuration import ProviderModelBundle +from core.file import File, FileUploadConfig +from core.model_runtime.entities.model_entities import AIModelEntity +from 
core.ops.ops_trace_manager import TraceQueueManager + + +class InvokeFrom(Enum): + """ + Invoke From. + """ + + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + DEBUGGER = "debugger" + + @classmethod + def value_of(cls, value: str): + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid invoke from value {value}") + + def to_source(self) -> str: + """ + Get source of invoke from. + + :return: source + """ + if self == InvokeFrom.WEB_APP: + return "web_app" + elif self == InvokeFrom.DEBUGGER: + return "dev" + elif self == InvokeFrom.EXPLORE: + return "explore_app" + elif self == InvokeFrom.SERVICE_API: + return "api" + + return "dev" + + +class ModelConfigWithCredentialsEntity(BaseModel): + """ + Model Config With Credentials Entity. + """ + + provider: str + model: str + model_schema: AIModelEntity + mode: str + provider_model_bundle: ProviderModelBundle + credentials: dict[str, Any] = {} + parameters: dict[str, Any] = {} + stop: list[str] = [] + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + +class AppGenerateEntity(BaseModel): + """ + App Generate Entity. + """ + + task_id: str + + # app config + app_config: Any + file_upload_config: Optional[FileUploadConfig] = None + + inputs: Mapping[str, Any] + files: Sequence[File] + user_id: str + + # extras + stream: bool + invoke_from: InvokeFrom + + # invoke call depth + call_depth: int = 0 + + # extra parameters, like: auto_generate_conversation_name + extras: dict[str, Any] = {} + + # tracing instance + trace_manager: Optional[TraceQueueManager] = None + + class Config: + arbitrary_types_allowed = True + + +class EasyUIBasedAppGenerateEntity(AppGenerateEntity): + """ + Chat Application Generate Entity. 
+ """ + + # app config + app_config: EasyUIBasedAppConfig + model_conf: ModelConfigWithCredentialsEntity + + query: Optional[str] = None + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + +class ConversationAppGenerateEntity(AppGenerateEntity): + """ + Base entity for conversation-based app generation. + """ + + conversation_id: Optional[str] = None + parent_message_id: Optional[str] = Field( + default=None, + description=( + "Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API." + "For service API, we need to ensure its forward compatibility, " + "so passing in the parent_message_id as request arg is not supported for now. " + "It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages." + ), + ) + + @field_validator("parent_message_id") + @classmethod + def validate_parent_message_id(cls, v, info: ValidationInfo): + if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL: + raise ValueError("parent_message_id should be UUID_NIL for service API") + return v + + +class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity): + """ + Chat Application Generate Entity. + """ + + pass + + +class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): + """ + Completion Application Generate Entity. + """ + + pass + + +class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity): + """ + Agent Chat Application Generate Entity. + """ + + pass + + +class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): + """ + Advanced Chat Application Generate Entity. + """ + + # app config + app_config: WorkflowUIBasedAppConfig + + workflow_run_id: Optional[str] = None + query: str + + class SingleIterationRunEntity(BaseModel): + """ + Single Iteration Run Entity. 
+ """ + + node_id: str + inputs: dict + + single_iteration_run: Optional[SingleIterationRunEntity] = None + + +class WorkflowAppGenerateEntity(AppGenerateEntity): + """ + Workflow Application Generate Entity. + """ + + # app config + app_config: WorkflowUIBasedAppConfig + workflow_run_id: str + + class SingleIterationRunEntity(BaseModel): + """ + Single Iteration Run Entity. + """ + + node_id: str + inputs: dict + + single_iteration_run: Optional[SingleIterationRunEntity] = None diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..a93e533ff45d267b3e841bc645466251ede3efa7 --- /dev/null +++ b/api/core/app/entities/queue_entities.py @@ -0,0 +1,568 @@ +from collections.abc import Mapping +from datetime import datetime +from enum import Enum, StrEnum +from typing import Any, Optional + +from pydantic import BaseModel + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNodeData + + +class QueueEvent(StrEnum): + """ + QueueEvent enum + """ + + LLM_CHUNK = "llm_chunk" + TEXT_CHUNK = "text_chunk" + AGENT_MESSAGE = "agent_message" + MESSAGE_REPLACE = "message_replace" + MESSAGE_END = "message_end" + ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end" + WORKFLOW_STARTED = "workflow_started" + WORKFLOW_SUCCEEDED = "workflow_succeeded" + WORKFLOW_FAILED = "workflow_failed" + WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded" + ITERATION_START = "iteration_start" + ITERATION_NEXT = "iteration_next" + ITERATION_COMPLETED = "iteration_completed" + NODE_STARTED = "node_started" + NODE_SUCCEEDED = "node_succeeded" + NODE_FAILED = "node_failed" + NODE_EXCEPTION = "node_exception" + 
RETRIEVER_RESOURCES = "retriever_resources" + ANNOTATION_REPLY = "annotation_reply" + AGENT_THOUGHT = "agent_thought" + MESSAGE_FILE = "message_file" + PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" + PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" + PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" + ERROR = "error" + PING = "ping" + STOP = "stop" + RETRY = "retry" + + +class AppQueueEvent(BaseModel): + """ + QueueEvent abstract entity + """ + + event: QueueEvent + + +class QueueLLMChunkEvent(AppQueueEvent): + """ + QueueLLMChunkEvent entity + Only for basic mode apps + """ + + event: QueueEvent = QueueEvent.LLM_CHUNK + chunk: LLMResultChunk + + +class QueueIterationStartEvent(AppQueueEvent): + """ + QueueIterationStartEvent entity + """ + + event: QueueEvent = QueueEvent.ITERATION_START + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime + + node_run_index: int + inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: Optional[str] = None + metadata: Optional[Mapping[str, Any]] = None + + +class QueueIterationNextEvent(AppQueueEvent): + """ + QueueIterationNextEvent entity + """ + + event: QueueEvent = QueueEvent.ITERATION_NEXT + + index: int + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node 
is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + parallel_mode_run_id: Optional[str] = None + """iteratoin run in parallel mode run id""" + node_run_index: int + output: Optional[Any] = None # output for the current iteration + duration: Optional[float] = None + + +class QueueIterationCompletedEvent(AppQueueEvent): + """ + QueueIterationCompletedEvent entity + """ + + event: QueueEvent = QueueEvent.ITERATION_COMPLETED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime + + node_run_index: int + inputs: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None + steps: int = 0 + + error: Optional[str] = None + + +class QueueTextChunkEvent(AppQueueEvent): + """ + QueueTextChunkEvent entity + """ + + event: QueueEvent = QueueEvent.TEXT_CHUNK + text: str + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueAgentMessageEvent(AppQueueEvent): + """ + QueueMessageEvent entity + """ + + event: QueueEvent = QueueEvent.AGENT_MESSAGE + chunk: LLMResultChunk + + +class QueueMessageReplaceEvent(AppQueueEvent): + """ + QueueMessageReplaceEvent entity + """ + + event: QueueEvent = QueueEvent.MESSAGE_REPLACE + text: str + + +class QueueRetrieverResourcesEvent(AppQueueEvent): + """ + QueueRetrieverResourcesEvent entity + """ + + event: QueueEvent = 
QueueEvent.RETRIEVER_RESOURCES + retriever_resources: list[dict] + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueAnnotationReplyEvent(AppQueueEvent): + """ + QueueAnnotationReplyEvent entity + """ + + event: QueueEvent = QueueEvent.ANNOTATION_REPLY + message_annotation_id: str + + +class QueueMessageEndEvent(AppQueueEvent): + """ + QueueMessageEndEvent entity + """ + + event: QueueEvent = QueueEvent.MESSAGE_END + llm_result: Optional[LLMResult] = None + + +class QueueAdvancedChatMessageEndEvent(AppQueueEvent): + """ + QueueAdvancedChatMessageEndEvent entity + """ + + event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END + + +class QueueWorkflowStartedEvent(AppQueueEvent): + """ + QueueWorkflowStartedEvent entity + """ + + event: QueueEvent = QueueEvent.WORKFLOW_STARTED + graph_runtime_state: GraphRuntimeState + + +class QueueWorkflowSucceededEvent(AppQueueEvent): + """ + QueueWorkflowSucceededEvent entity + """ + + event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED + outputs: Optional[dict[str, Any]] = None + + +class QueueWorkflowFailedEvent(AppQueueEvent): + """ + QueueWorkflowFailedEvent entity + """ + + event: QueueEvent = QueueEvent.WORKFLOW_FAILED + error: str + exceptions_count: int + + +class QueueWorkflowPartialSuccessEvent(AppQueueEvent): + """ + QueueWorkflowPartialSuccessEvent entity + """ + + event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED + exceptions_count: int + outputs: Optional[dict[str, Any]] = None + + +class QueueNodeStartedEvent(AppQueueEvent): + """ + QueueNodeStartedEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_STARTED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + node_run_index: int = 1 + predecessor_node_id: Optional[str] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + 
parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + parallel_mode_run_id: Optional[str] = None + """iteration run in parallel mode run id""" + + +class QueueNodeSucceededEvent(AppQueueEvent): + """ + QueueNodeSucceededEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_SUCCEEDED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: Optional[str] = None + """single iteration duration map""" + iteration_duration_map: Optional[dict[str, float]] = None + + +class QueueNodeRetryEvent(QueueNodeStartedEvent): + """QueueNodeRetryEvent entity""" + + event: QueueEvent = QueueEvent.RETRY + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: str + retry_index: int # retry index + + +class QueueNodeInIterationFailedEvent(AppQueueEvent): + """ + QueueNodeInIterationFailedEvent entity + """ + + event: 
QueueEvent = QueueEvent.NODE_FAILED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: str + + +class QueueNodeExceptionEvent(AppQueueEvent): + """ + QueueNodeExceptionEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_EXCEPTION + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: str + + +class QueueNodeFailedEvent(AppQueueEvent): + """ + QueueNodeFailedEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_FAILED + + node_execution_id: str + node_id: str + 
node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: str + + +class QueueAgentThoughtEvent(AppQueueEvent): + """ + QueueAgentThoughtEvent entity + """ + + event: QueueEvent = QueueEvent.AGENT_THOUGHT + agent_thought_id: str + + +class QueueMessageFileEvent(AppQueueEvent): + """ + QueueMessageFileEvent entity + """ + + event: QueueEvent = QueueEvent.MESSAGE_FILE + message_file_id: str + + +class QueueErrorEvent(AppQueueEvent): + """ + QueueErrorEvent entity + """ + + event: QueueEvent = QueueEvent.ERROR + error: Any = None + + +class QueuePingEvent(AppQueueEvent): + """ + QueuePingEvent entity + """ + + event: QueueEvent = QueueEvent.PING + + +class QueueStopEvent(AppQueueEvent): + """ + QueueStopEvent entity + """ + + class StopBy(Enum): + """ + Stop by enum + """ + + USER_MANUAL = "user-manual" + ANNOTATION_REPLY = "annotation-reply" + OUTPUT_MODERATION = "output-moderation" + INPUT_MODERATION = "input-moderation" + + event: QueueEvent = QueueEvent.STOP + stopped_by: StopBy + + def get_stop_reason(self) -> str: + """ + To stop reason + """ + reason_mapping = { + QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.", + QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.", + QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.", 
+ QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.", + } + + return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") + + +class QueueMessage(BaseModel): + """ + QueueMessage abstract entity + """ + + task_id: str + app_mode: str + event: AppQueueEvent + + +class MessageQueueMessage(QueueMessage): + """ + MessageQueueMessage entity + """ + + message_id: str + conversation_id: str + + +class WorkflowQueueMessage(QueueMessage): + """ + WorkflowQueueMessage entity + """ + + pass + + +class QueueParallelBranchRunStartedEvent(AppQueueEvent): + """ + QueueParallelBranchRunStartedEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunSucceededEvent(AppQueueEvent): + """ + QueueParallelBranchRunSucceededEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunFailedEvent(AppQueueEvent): + """ + QueueParallelBranchRunFailedEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in 
parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..5e845eba2da1d319c3f38db636d469a31f7bcceb --- /dev/null +++ b/api/core/app/entities/task_entities.py @@ -0,0 +1,698 @@ +from collections.abc import Mapping, Sequence +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.utils.encoders import jsonable_encoder +from models.workflow import WorkflowNodeExecutionStatus + + +class TaskState(BaseModel): + """ + TaskState entity + """ + + metadata: dict = {} + + +class EasyUITaskState(TaskState): + """ + EasyUITaskState entity + """ + + llm_result: LLMResult + + +class WorkflowTaskState(TaskState): + """ + WorkflowTaskState entity + """ + + answer: str = "" + + +class StreamEvent(Enum): + """ + Stream event + """ + + PING = "ping" + ERROR = "error" + MESSAGE = "message" + MESSAGE_END = "message_end" + TTS_MESSAGE = "tts_message" + TTS_MESSAGE_END = "tts_message_end" + MESSAGE_FILE = "message_file" + MESSAGE_REPLACE = "message_replace" + AGENT_THOUGHT = "agent_thought" + AGENT_MESSAGE = "agent_message" + WORKFLOW_STARTED = "workflow_started" + WORKFLOW_FINISHED = "workflow_finished" + NODE_STARTED = "node_started" + NODE_FINISHED = "node_finished" + NODE_RETRY = "node_retry" + PARALLEL_BRANCH_STARTED = "parallel_branch_started" + PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" + ITERATION_STARTED = "iteration_started" + ITERATION_NEXT = "iteration_next" + ITERATION_COMPLETED = "iteration_completed" + TEXT_CHUNK = "text_chunk" + TEXT_REPLACE = "text_replace" + + +class StreamResponse(BaseModel): + """ + StreamResponse entity + """ + + event: StreamEvent + task_id: str + + def to_dict(self): + return 
jsonable_encoder(self) + + +class ErrorStreamResponse(StreamResponse): + """ + ErrorStreamResponse entity + """ + + event: StreamEvent = StreamEvent.ERROR + err: Exception + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class MessageStreamResponse(StreamResponse): + """ + MessageStreamResponse entity + """ + + event: StreamEvent = StreamEvent.MESSAGE + id: str + answer: str + from_variable_selector: Optional[list[str]] = None + + +class MessageAudioStreamResponse(StreamResponse): + """ + MessageAudioStreamResponse entity + """ + + event: StreamEvent = StreamEvent.TTS_MESSAGE + audio: str + + +class MessageAudioEndStreamResponse(StreamResponse): + """ + MessageAudioEndStreamResponse entity + """ + + event: StreamEvent = StreamEvent.TTS_MESSAGE_END + audio: str + + +class MessageEndStreamResponse(StreamResponse): + """ + MessageEndStreamResponse entity + """ + + event: StreamEvent = StreamEvent.MESSAGE_END + id: str + metadata: dict = {} + files: Optional[Sequence[Mapping[str, Any]]] = None + + +class MessageFileStreamResponse(StreamResponse): + """ + MessageFileStreamResponse entity + """ + + event: StreamEvent = StreamEvent.MESSAGE_FILE + id: str + type: str + belongs_to: str + url: str + + +class MessageReplaceStreamResponse(StreamResponse): + """ + MessageReplaceStreamResponse entity + """ + + event: StreamEvent = StreamEvent.MESSAGE_REPLACE + answer: str + + +class AgentThoughtStreamResponse(StreamResponse): + """ + AgentThoughtStreamResponse entity + """ + + event: StreamEvent = StreamEvent.AGENT_THOUGHT + id: str + position: int + thought: Optional[str] = None + observation: Optional[str] = None + tool: Optional[str] = None + tool_labels: Optional[dict] = None + tool_input: Optional[str] = None + message_files: Optional[list[str]] = None + + +class AgentMessageStreamResponse(StreamResponse): + """ + AgentMessageStreamResponse entity + """ + + event: StreamEvent = StreamEvent.AGENT_MESSAGE + id: str + answer: str + + +class 
class WorkflowStartStreamResponse(StreamResponse):
    """WorkflowStartStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        id: str
        workflow_id: str
        sequence_number: int
        inputs: dict
        created_at: int

    event: StreamEvent = StreamEvent.WORKFLOW_STARTED
    workflow_run_id: str
    data: Data


class WorkflowFinishStreamResponse(StreamResponse):
    """WorkflowFinishStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        id: str
        workflow_id: str
        sequence_number: int
        status: str
        outputs: Optional[dict] = None
        error: Optional[str] = None
        elapsed_time: float
        total_tokens: int
        total_steps: int
        created_by: Optional[dict] = None
        created_at: int
        finished_at: int
        exceptions_count: Optional[int] = 0
        files: Optional[Sequence[Mapping[str, Any]]] = []

    event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
    workflow_run_id: str
    data: Data


class NodeStartStreamResponse(StreamResponse):
    """NodeStartStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        id: str
        node_id: str
        node_type: str
        title: str
        index: int
        predecessor_node_id: Optional[str] = None
        inputs: Optional[dict] = None
        created_at: int
        extras: dict = {}
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None
        parallel_run_id: Optional[str] = None

    event: StreamEvent = StreamEvent.NODE_STARTED
    workflow_run_id: str
    data: Data

    def to_ignore_detail_dict(self):
        """Serialized form with inputs/extras redacted (for consumers without detail access)."""
        return {
            "event": self.event.value,
            "task_id": self.task_id,
            "workflow_run_id": self.workflow_run_id,
            "data": {
                "id": self.data.id,
                "node_id": self.data.node_id,
                "node_type": self.data.node_type,
                "title": self.data.title,
                "index": self.data.index,
                "predecessor_node_id": self.data.predecessor_node_id,
                "inputs": None,
                "created_at": self.data.created_at,
                "extras": {},
                "parallel_id": self.data.parallel_id,
                "parallel_start_node_id": self.data.parallel_start_node_id,
                "parent_parallel_id": self.data.parent_parallel_id,
                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                "iteration_id": self.data.iteration_id,
            },
        }


class NodeFinishStreamResponse(StreamResponse):
    """NodeFinishStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        id: str
        node_id: str
        node_type: str
        title: str
        index: int
        predecessor_node_id: Optional[str] = None
        inputs: Optional[dict] = None
        process_data: Optional[dict] = None
        outputs: Optional[dict] = None
        status: str
        error: Optional[str] = None
        elapsed_time: float
        execution_metadata: Optional[dict] = None
        created_at: int
        finished_at: int
        files: Optional[Sequence[Mapping[str, Any]]] = []
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None

    event: StreamEvent = StreamEvent.NODE_FINISHED
    workflow_run_id: str
    data: Data

    def to_ignore_detail_dict(self):
        """Serialized form with inputs/outputs/process data redacted."""
        return {
            "event": self.event.value,
            "task_id": self.task_id,
            "workflow_run_id": self.workflow_run_id,
            "data": {
                "id": self.data.id,
                "node_id": self.data.node_id,
                "node_type": self.data.node_type,
                "title": self.data.title,
                "index": self.data.index,
                "predecessor_node_id": self.data.predecessor_node_id,
                "inputs": None,
                "process_data": None,
                "outputs": None,
                "status": self.data.status,
                "error": None,
                "elapsed_time": self.data.elapsed_time,
                "execution_metadata": None,
                "created_at": self.data.created_at,
                "finished_at": self.data.finished_at,
                "files": [],
                "parallel_id": self.data.parallel_id,
                "parallel_start_node_id": self.data.parallel_start_node_id,
                "parent_parallel_id": self.data.parent_parallel_id,
                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                "iteration_id": self.data.iteration_id,
            },
        }


class NodeRetryStreamResponse(StreamResponse):
    """NodeRetryStreamResponse entity."""
    # FIX: docstring previously said "NodeFinishStreamResponse entity" (copy-paste).

    class Data(BaseModel):
        """Data entity."""

        id: str
        node_id: str
        node_type: str
        title: str
        index: int
        predecessor_node_id: Optional[str] = None
        inputs: Optional[dict] = None
        process_data: Optional[dict] = None
        outputs: Optional[dict] = None
        status: str
        error: Optional[str] = None
        elapsed_time: float
        execution_metadata: Optional[dict] = None
        created_at: int
        finished_at: int
        files: Optional[Sequence[Mapping[str, Any]]] = []
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None
        retry_index: int = 0

    event: StreamEvent = StreamEvent.NODE_RETRY
    workflow_run_id: str
    data: Data

    def to_ignore_detail_dict(self):
        """Serialized form with inputs/outputs/process data redacted."""
        return {
            "event": self.event.value,
            "task_id": self.task_id,
            "workflow_run_id": self.workflow_run_id,
            "data": {
                "id": self.data.id,
                "node_id": self.data.node_id,
                "node_type": self.data.node_type,
                "title": self.data.title,
                "index": self.data.index,
                "predecessor_node_id": self.data.predecessor_node_id,
                "inputs": None,
                "process_data": None,
                "outputs": None,
                "status": self.data.status,
                "error": None,
                "elapsed_time": self.data.elapsed_time,
                "execution_metadata": None,
                "created_at": self.data.created_at,
                "finished_at": self.data.finished_at,
                "files": [],
                "parallel_id": self.data.parallel_id,
                "parallel_start_node_id": self.data.parallel_start_node_id,
                "parent_parallel_id": self.data.parent_parallel_id,
                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                "iteration_id": self.data.iteration_id,
                "retry_index": self.data.retry_index,
            },
        }
class ParallelBranchStartStreamResponse(StreamResponse):
    """ParallelBranchStartStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        parallel_id: str
        parallel_branch_id: str
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None
        created_at: int

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
    workflow_run_id: str
    data: Data


class ParallelBranchFinishedStreamResponse(StreamResponse):
    """ParallelBranchFinishedStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        parallel_id: str
        parallel_branch_id: str
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None
        status: str
        error: Optional[str] = None
        created_at: int

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
    workflow_run_id: str
    data: Data


class IterationNodeStartStreamResponse(StreamResponse):
    """IterationNodeStartStreamResponse entity."""
    # FIX: docstring previously said "NodeStartStreamResponse entity" (copy-paste).

    class Data(BaseModel):
        """Data entity."""

        id: str
        node_id: str
        node_type: str
        title: str
        created_at: int
        extras: dict = {}
        metadata: Mapping = {}
        inputs: Mapping = {}
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_STARTED
    workflow_run_id: str
    data: Data


class IterationNodeNextStreamResponse(StreamResponse):
    """IterationNodeNextStreamResponse entity."""
    # FIX: docstring previously said "NodeStartStreamResponse entity" (copy-paste).

    class Data(BaseModel):
        """Data entity."""

        id: str
        node_id: str
        node_type: str
        title: str
        index: int
        created_at: int
        pre_iteration_output: Optional[Any] = None
        extras: dict = {}
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None
        parallel_mode_run_id: Optional[str] = None
        duration: Optional[float] = None

    event: StreamEvent = StreamEvent.ITERATION_NEXT
    workflow_run_id: str
    data: Data


class IterationNodeCompletedStreamResponse(StreamResponse):
    """IterationNodeCompletedStreamResponse entity."""
    # FIX: docstring previously said "NodeCompletedStreamResponse entity" (copy-paste).

    class Data(BaseModel):
        """Data entity."""

        id: str
        node_id: str
        node_type: str
        title: str
        outputs: Optional[Mapping] = None
        created_at: int
        extras: Optional[dict] = None
        inputs: Optional[Mapping] = None
        status: WorkflowNodeExecutionStatus
        error: Optional[str] = None
        elapsed_time: float
        total_tokens: int
        execution_metadata: Optional[Mapping] = None
        finished_at: int
        steps: int
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_COMPLETED
    workflow_run_id: str
    data: Data


class TextChunkStreamResponse(StreamResponse):
    """TextChunkStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        text: str
        from_variable_selector: Optional[list[str]] = None

    event: StreamEvent = StreamEvent.TEXT_CHUNK
    data: Data


class TextReplaceStreamResponse(StreamResponse):
    """TextReplaceStreamResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        text: str

    event: StreamEvent = StreamEvent.TEXT_REPLACE
    data: Data


class PingStreamResponse(StreamResponse):
    """PingStreamResponse entity."""

    event: StreamEvent = StreamEvent.PING


class AppStreamResponse(BaseModel):
    """AppStreamResponse entity."""

    stream_response: StreamResponse


class ChatbotAppStreamResponse(AppStreamResponse):
    """ChatbotAppStreamResponse entity."""

    conversation_id: str
    message_id: str
    created_at: int


class CompletionAppStreamResponse(AppStreamResponse):
    """CompletionAppStreamResponse entity."""

    message_id: str
    created_at: int


class WorkflowAppStreamResponse(AppStreamResponse):
    """WorkflowAppStreamResponse entity."""

    workflow_run_id: Optional[str] = None


class AppBlockingResponse(BaseModel):
    """AppBlockingResponse entity."""

    task_id: str

    def to_dict(self):
        return jsonable_encoder(self)


class ChatbotAppBlockingResponse(AppBlockingResponse):
    """ChatbotAppBlockingResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        id: str
        mode: str
        conversation_id: str
        message_id: str
        answer: str
        metadata: dict = {}
        created_at: int

    data: Data


class CompletionAppBlockingResponse(AppBlockingResponse):
    """CompletionAppBlockingResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        id: str
        mode: str
        message_id: str
        answer: str
        metadata: dict = {}
        created_at: int

    data: Data


class WorkflowAppBlockingResponse(AppBlockingResponse):
    """WorkflowAppBlockingResponse entity."""

    class Data(BaseModel):
        """Data entity."""

        id: str
        workflow_id: str
        status: str
        outputs: Optional[dict] = None
        error: Optional[str] = None
        elapsed_time: float
        total_tokens: int
        total_steps: int
        created_at: int
        finished_at: int

    workflow_run_id: str
    data: Data
class AnnotationReplyFeature:
    def query(
        self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
    ) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return: the matched annotation, or None if no setting / no hit / lookup failed
        """
        annotation_setting = (
            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
        )

        if not annotation_setting:
            return None

        collection_binding_detail = annotation_setting.collection_binding_detail

        try:
            # score_threshold of 0/None falls back to 1 (exact-ish match only)
            score_threshold = annotation_setting.score_threshold or 1
            embedding_provider_name = collection_binding_detail.provider_name
            embedding_model_name = collection_binding_detail.model_name

            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_provider_name, embedding_model_name, "annotation"
            )

            # Transient Dataset keyed by the app id; annotations are stored in a
            # dedicated "annotation" collection binding.
            dataset = Dataset(
                id=app_record.id,
                tenant_id=app_record.tenant_id,
                indexing_technique="high_quality",
                embedding_model_provider=embedding_provider_name,
                embedding_model=embedding_model_name,
                collection_binding_id=dataset_collection_binding.id,
            )

            vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])

            documents = vector.search_by_vector(
                query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
            )

            if documents and documents[0].metadata:
                annotation_id = documents[0].metadata["annotation_id"]
                score = documents[0].metadata["score"]
                annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                if annotation:
                    if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
                        from_source = "api"
                    else:
                        from_source = "console"

                    # insert annotation history
                    AppAnnotationService.add_annotation_history(
                        annotation.id,
                        app_record.id,
                        annotation.question,
                        annotation.content,
                        query,
                        user_id,
                        message.id,
                        from_source,
                        score,
                    )

                    return annotation
        except Exception as e:
            # Best-effort feature: never let annotation lookup break the request.
            # FIX: use lazy %-style logging args instead of an eager f-string.
            logger.warning("Query annotation failed, exception: %s.", e)
            return None

        return None


class HostingModerationFeature:
    def check(
        self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
    ) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param prompt_messages: prompt messages
        :return: moderation result flag
        """
        model_config = application_generate_entity.model_conf

        # Only plain-string contents are concatenated; multimodal parts are skipped.
        text = ""
        for prompt_message in prompt_messages:
            if isinstance(prompt_message.content, str):
                text += prompt_message.content + "\n"

        moderation_result = moderation.check_moderation(model_config, text)

        return moderation_result
class RateLimit:
    """Per-client concurrency limiter backed by Redis.

    Active request ids are tracked in a Redis hash; a new request is rejected
    once `max_active_requests` are in flight. Instances are memoized per
    client_id within the process.
    """

    _MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
    _ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
    _UNLIMITED_REQUEST_ID = "unlimited_request_id"
    _REQUEST_MAX_ALIVE_TIME = 10 * 60  # 10 minutes
    _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
    _instance_dict: dict[str, "RateLimit"] = {}

    def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
        # One shared instance per client_id (process-local memoization).
        if client_id not in cls._instance_dict:
            instance = super().__new__(cls)
            cls._instance_dict[client_id] = instance
        return cls._instance_dict[client_id]

    def __init__(self, client_id: str, max_active_requests: int):
        # Always refresh the limit, even when __init__ runs on a memoized instance.
        self.max_active_requests = max_active_requests
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.client_id = client_id
        self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
        self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
        self.last_recalculate_time = float("-inf")
        self.flush_cache(use_local_value=True)

    def flush_cache(self, use_local_value=False):
        """Sync the limit value and active-request hash with Redis, evicting stale entries."""
        self.last_recalculate_time = time.time()
        # flush max active requests
        if use_local_value or not redis_client.exists(self.max_active_requests_key):
            with redis_client.pipeline() as pipe:
                pipe.set(self.max_active_requests_key, self.max_active_requests)
                pipe.expire(self.max_active_requests_key, timedelta(days=1))
                pipe.execute()
        else:
            # FIX: the original wrapped these calls in a `redis_client.pipeline()`
            # context but issued the commands on redis_client directly, so the
            # pipeline object was dead code; issue the commands directly.
            self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
            redis_client.expire(self.max_active_requests_key, timedelta(days=1))

        # flush in-transit request list, dropping entries alive longer than allowed
        if not redis_client.exists(self.active_requests_key):
            return
        request_details = redis_client.hgetall(self.active_requests_key)
        redis_client.expire(self.active_requests_key, timedelta(days=1))
        timeout_requests = [
            k
            for k, v in request_details.items()
            if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
        ]
        if timeout_requests:
            redis_client.hdel(self.active_requests_key, *timeout_requests)

    def enter(self, request_id: Optional[str] = None) -> str:
        """Register a request; raises AppInvokeQuotaExceededError when at capacity."""
        if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
            self.flush_cache()
        # Non-positive limit means "unlimited": skip bookkeeping entirely.
        if self.max_active_requests <= 0:
            return RateLimit._UNLIMITED_REQUEST_ID
        if not request_id:
            request_id = RateLimit.gen_request_key()

        active_requests_count = redis_client.hlen(self.active_requests_key)
        if active_requests_count >= self.max_active_requests:
            raise AppInvokeQuotaExceededError(
                "Too many requests. Please try again later. The current maximum "
                "concurrent requests allowed is {}.".format(self.max_active_requests)
            )
        redis_client.hset(self.active_requests_key, request_id, str(time.time()))
        return request_id

    def exit(self, request_id: str):
        """Release a previously registered request id (no-op for the unlimited sentinel)."""
        if request_id == RateLimit._UNLIMITED_REQUEST_ID:
            return
        redis_client.hdel(self.active_requests_key, request_id)

    @staticmethod
    def gen_request_key() -> str:
        return str(uuid.uuid4())

    def generate(self, generator: Union[Generator[str, None, None], Mapping[str, Any]], request_id: str):
        """Wrap a streaming generator so the slot is released on exhaustion/close; mappings pass through."""
        if isinstance(generator, Mapping):
            return generator
        else:
            return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)


class RateLimitGenerator:
    """Generator proxy that releases its rate-limit slot exactly once, on error, exhaustion, or close."""

    def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
        self.rate_limit = rate_limit
        self.generator = generator
        self.request_id = request_id
        self.closed = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.closed:
            raise StopIteration
        try:
            return next(self.generator)
        except Exception:
            # StopIteration is an Exception subclass, so normal exhaustion also
            # triggers the release before the exception propagates.
            self.close()
            raise

    def close(self):
        if not self.closed:
            self.closed = True
            self.rate_limit.exit(self.request_id)
            if self.generator is not None and hasattr(self.generator, "close"):
                self.generator.close()
import logging
import time
from typing import Optional

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import AppGenerateEntity
from core.app.entities.queue_entities import QueueErrorEvent
from core.app.entities.task_entities import ErrorStreamResponse, PingStreamResponse
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from models.model import Message

logger = logging.getLogger(__name__)


class BasedGenerateTaskPipeline:
    """Shared stream-output generation and state management for app task pipelines."""

    def __init__(
        self,
        application_generate_entity: AppGenerateEntity,
        queue_manager: AppQueueManager,
        stream: bool,
    ) -> None:
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()
        self._stream = stream

    def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
        """Translate a queue error event into an exception; optionally mark the message record as errored."""
        logger.debug("error: %s", event.error)
        cause = event.error

        err: Exception
        if isinstance(cause, InvokeAuthorizationError):
            # Never leak the real credential error to the caller.
            err = InvokeAuthorizationError("Incorrect API key provided")
        elif isinstance(cause, InvokeError | ValueError):
            err = cause
        else:
            description = getattr(cause, "description", None)
            err = Exception(str(cause) if description is None else description)

        if message_id and session:
            record = session.scalar(select(Message).where(Message.id == message_id))
            if record:
                record.status = "error"
                record.error = self._error_to_desc(err)

        return err

    def _error_to_desc(self, e: Exception) -> str:
        """Map an exception to a user-facing description."""
        if isinstance(e, QuotaExceededError):
            return (
                "Your quota for Dify Hosted Model Provider has been exhausted. "
                "Please go to Settings -> Model Provider to complete your own provider credentials."
            )

        message = getattr(e, "description", str(e))
        return message or "Internal Server Error, please contact support."

    def _error_to_stream_response(self, e: Exception):
        """Wrap an exception into an error stream response."""
        return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)

    def _ping_stream_response(self) -> PingStreamResponse:
        """Build a keep-alive ping stream response."""
        return PingStreamResponse(task_id=self._application_generate_entity.task_id)

    def _init_output_moderation(self) -> Optional[OutputModeration]:
        """Build the output-moderation handler when sensitive-word avoidance is configured, else None."""
        app_config = self._application_generate_entity.app_config
        avoidance = app_config.sensitive_word_avoidance
        if not avoidance:
            return None
        return OutputModeration(
            tenant_id=app_config.tenant_id,
            app_id=app_config.app_id,
            rule=ModerationRule(type=avoidance.type, config=avoidance.config),
            queue_manager=self._queue_manager,
        )

    def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
        """Run final moderation over the full completion; returns moderated text, or None when disabled."""
        handler = self._output_moderation_handler
        if handler is None:
            return None
        handler.stop_thread()
        moderated = handler.moderation_completion(completion=completion, public_event=False)
        self._output_moderation_handler = None
        return moderated
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
    """
    EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.

    Consumes events from the queue manager and turns them into blocking or
    streaming responses for chat / completion / agent-chat apps.
    """

    # Accumulated LLM result + metadata for the current task.
    _task_state: EasyUITaskState
    _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]

    def __init__(
        self,
        application_generate_entity: Union[
            ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
        ],
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
        stream: bool,
    ) -> None:
        super().__init__(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            stream=stream,
        )
        self._model_config = application_generate_entity.model_conf
        self._app_config = application_generate_entity.app_config

        # Snapshot conversation/message identifiers so the pipeline does not
        # hold live ORM objects across the streaming loop.
        self._conversation_id = conversation.id
        self._conversation_mode = conversation.mode

        self._message_id = message.id
        self._message_created_at = int(message.created_at.timestamp())

        # Start with an empty LLM result; filled incrementally from queue events.
        self._task_state = EasyUITaskState(
            llm_result=LLMResult(
                model=self._model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage(),
            )
        )

        self._conversation_name_generate_thread: Optional[Thread] = None

    def process(
        self,
    ) -> Union[
        ChatbotAppBlockingResponse,
        CompletionAppBlockingResponse,
        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
    ]:
        """Run the pipeline: returns a blocking response or a stream-response generator depending on `stream`."""
        if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
            # start generate conversation name thread
            self._conversation_name_generate_thread = self._generate_conversation_name(
                conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
            )

        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
        if self._stream:
            return self._to_stream_response(generator)
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
        """
        Process blocking response.

        Drains the stream until MESSAGE_END (success) or an error response
        (re-raised); all intermediate chunks are discarded.
        :return:
        """
        for stream_response in generator:
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, MessageEndStreamResponse):
                extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
                if self._task_state.metadata:
                    extras["metadata"] = self._task_state.metadata
                response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
                if self._conversation_mode == AppMode.COMPLETION.value:
                    response = CompletionAppBlockingResponse(
                        task_id=self._application_generate_entity.task_id,
                        data=CompletionAppBlockingResponse.Data(
                            id=self._message_id,
                            mode=self._conversation_mode,
                            message_id=self._message_id,
                            answer=cast(str, self._task_state.llm_result.message.content),
                            created_at=self._message_created_at,
                            **extras,
                        ),
                    )
                else:
                    response = ChatbotAppBlockingResponse(
                        task_id=self._application_generate_entity.task_id,
                        data=ChatbotAppBlockingResponse.Data(
                            id=self._message_id,
                            mode=self._conversation_mode,
                            conversation_id=self._conversation_id,
                            message_id=self._message_id,
                            answer=cast(str, self._task_state.llm_result.message.content),
                            created_at=self._message_created_at,
                            **extras,
                        ),
                    )

                return response
            else:
                continue

        raise RuntimeError("queue listening stopped unexpectedly.")

    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
        """
        To stream response.

        Wraps each raw stream response with the app-level envelope
        (conversation/message ids + created timestamp).
        :return:
        """
        for stream_response in generator:
            if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
                yield CompletionAppStreamResponse(
                    message_id=self._message_id,
                    created_at=self._message_created_at,
                    stream_response=stream_response,
                )
            else:
                yield ChatbotAppStreamResponse(
                    conversation_id=self._conversation_id,
                    message_id=self._message_id,
                    created_at=self._message_created_at,
                    stream_response=stream_response,
                )

    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
        # Poll the TTS publisher; returns an audio stream response for a pending
        # (non-final) chunk, else None.
        if publisher is None:
            return None
        audio_msg = publisher.check_and_get_audio()
        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
            # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """Interleave TTS audio responses with the main event stream, then drain remaining audio at the end."""
        tenant_id = self._application_generate_entity.app_config.tenant_id
        task_id = self._application_generate_entity.task_id
        publisher = None
        text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
        # TTS is active only when both enabled and set to auto-play.
        if (
            text_to_speech_dict
            and text_to_speech_dict.get("autoPlay") == "enabled"
            and text_to_speech_dict.get("enabled")
        ):
            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
            # Flush any audio that became available before yielding the next event.
            while True:
                audio_response = self._listen_audio_msg(publisher, task_id)
                if audio_response:
                    yield audio_response
                else:
                    break
            yield response

        start_listener_time = time.time()
        # timeout
        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
            if publisher is None:
                break
            audio = publisher.check_and_get_audio()
            if audio is None:
                # release cpu
                # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
                time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
                continue
            if audio.status == "finish":
                break
            else:
                # Reset the timeout window on every received chunk.
                start_listener_time = time.time()
                yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
        if publisher:
            yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

    def _process_stream_response(
        self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.

        Core event loop: consumes queue events, updates task state, and yields
        the corresponding stream responses until end/stop/error.
        :return:
        """
        for message in self._queue_manager.listen():
            if publisher:
                publisher.publish(message)
            event = message.event

            if isinstance(event, QueueErrorEvent):
                with Session(db.engine) as session:
                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
                    session.commit()
                yield self._error_to_stream_response(err)
                break
            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                if isinstance(event, QueueMessageEndEvent):
                    if event.llm_result:
                        self._task_state.llm_result = event.llm_result
                else:
                    self._handle_stop(event)

                # handle output moderation
                output_moderation_answer = self._handle_output_moderation_when_task_finished(
                    cast(str, self._task_state.llm_result.message.content)
                )
                if output_moderation_answer:
                    self._task_state.llm_result.message.content = output_moderation_answer
                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)

                with Session(db.engine) as session:
                    # Save message
                    self._save_message(session=session, trace_manager=trace_manager)
                    session.commit()
                message_end_resp = self._message_end_to_stream_response()
                yield message_end_resp
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._handle_retriever_resources(event)
            elif isinstance(event, QueueAnnotationReplyEvent):
                annotation = self._handle_annotation_reply(event)
                if annotation:
                    # An annotation hit replaces the LLM answer entirely.
                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought_response = self._agent_thought_to_stream_response(event)
                if agent_thought_response is not None:
                    yield agent_thought_response
            elif isinstance(event, QueueMessageFileEvent):
                response = self._message_file_to_stream_response(event)
                if response:
                    yield response
            elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                # First chunk carries the prompt messages; keep them once.
                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                # handle output moderation chunk
                should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
                if should_direct_answer:
                    continue

                current_content = cast(str, self._task_state.llm_result.message.content)
                current_content += cast(str, delta_text)
                self._task_state.llm_result.message.content = current_content

                if isinstance(event, QueueLLMChunkEvent):
                    yield self._message_to_stream_response(
                        answer=cast(str, delta_text),
                        message_id=self._message_id,
                    )
                else:
                    yield self._agent_message_to_stream_response(
                        answer=cast(str, delta_text),
                        message_id=self._message_id,
                    )
            elif isinstance(event, QueueMessageReplaceEvent):
                yield self._message_replace_to_stream_response(answer=event.text)
            elif isinstance(event, QueuePingEvent):
                yield self._ping_stream_response()
            else:
                continue
        if publisher:
            # Sentinel: tell the TTS publisher the stream is over.
            publisher.publish(None)
        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

    def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
        """
        Save message.
    def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
        """
        Save message.

        Persists the accumulated LLM result (tokens, prices, latency, metadata)
        onto the Message row and notifies listeners via message_was_created.
        :param session: SQLAlchemy session; caller commits
        :param trace_manager: optional tracing manager to enqueue a MESSAGE_TRACE task
        :return:
        :raises ValueError: if the message or conversation row cannot be found
        """
        llm_result = self._task_state.llm_result
        usage = llm_result.usage

        message_stmt = select(Message).where(Message.id == self._message_id)
        message = session.scalar(message_stmt)
        if not message:
            raise ValueError(f"message {self._message_id} not found")
        conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
        conversation = session.scalar(conversation_stmt)
        if not conversation:
            raise ValueError(f"Conversation {self._conversation_id} not found")

        message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            self._model_config.mode, self._task_state.llm_result.prompt_messages
        )
        message.message_tokens = usage.prompt_tokens
        message.message_unit_price = usage.prompt_unit_price
        message.message_price_unit = usage.prompt_price_unit
        # strip template variables so the stored answer is plain text
        message.answer = (
            PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
            if llm_result.message.content
            else ""
        )
        message.answer_tokens = usage.completion_tokens
        message.answer_unit_price = usage.completion_unit_price
        message.answer_price_unit = usage.completion_price_unit
        message.provider_response_latency = time.perf_counter() - self._start_at
        message.total_price = usage.total_price
        message.currency = usage.currency
        message.message_metadata = (
            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
        )

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
                )
            )

        message_was_created.send(
            message,
            application_generate_entity=self._application_generate_entity,
        )
    def _handle_stop(self, event: QueueStopEvent) -> None:
        """
        Handle stop.

        Recomputes token counts and usage for a prematurely stopped generation,
        since no final usage arrived from the provider.
        :param event: stop event carrying the stop reason
        :return:
        """
        model_config = self._model_config
        model = model_config.model

        model_instance = ModelInstance(
            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

        # calculate num tokens
        prompt_tokens = 0
        if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
            prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)

        completion_tokens = 0
        if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
            completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])

        credentials = model_config.credentials

        # transform usage
        # NOTE(review): reaches into the model runtime's private _calc_response_usage;
        # there appears to be no public equivalent — confirm before changing.
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens
        )

    def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
        """
        Message end to stream response.

        Records final usage into the task-state metadata and emits the end event.
        :return:
        """
        self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)

        extras = {}
        if self._task_state.metadata:
            extras["metadata"] = self._task_state.metadata

        return MessageEndStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=self._message_id,
            metadata=extras.get("metadata", {}),
        )

    def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
        """
        Agent message to stream response.
        :param answer: answer
        :param message_id: message id
        :return:
        """
        return AgentMessageStreamResponse(
            task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
        )
+ :param event: agent thought event + :return: + """ + agent_thought: Optional[MessageAgentThought] = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() + ) + db.session.refresh(agent_thought) + db.session.close() + + if agent_thought: + return AgentThoughtStreamResponse( + task_id=self._application_generate_entity.task_id, + id=agent_thought.id, + position=agent_thought.position, + thought=agent_thought.thought, + observation=agent_thought.observation, + tool=agent_thought.tool, + tool_labels=agent_thought.tool_labels, + tool_input=agent_thought.tool_input, + message_files=agent_thought.files, + ) + + return None + + def _handle_output_moderation_chunk(self, text: str) -> bool: + """ + Handle output moderation chunk. + :param text: text + :return: True if output moderation should direct output, otherwise False + """ + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() + self._queue_manager.publish( + QueueLLMChunkEvent( + chunk=LLMResultChunk( + model=self._task_state.llm_result.model, + prompt_messages=self._task_state.llm_result.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content), + ), + ) + ), + PublishFrom.TASK_PIPELINE, + ) + + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE + ) + return True + else: + self._output_moderation_handler.append_new_token(text) + + return False diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b4168d0881e03cf4e87eb41d65f947a907dbfc --- /dev/null +++ b/api/core/app/task_pipeline/exc.py @@ -0,0 
class TaskPipilineError(ValueError):
    """Base error for task-pipeline failures.

    NOTE: the class name misspells "Pipeline"; it is kept unchanged for
    backward compatibility with existing except-clauses. New code should use
    the correctly spelled ``TaskPipelineError`` alias below.
    """

    pass


# Correctly spelled, backward-compatible alias for the base error.
TaskPipelineError = TaskPipilineError


class RecordNotFoundError(TaskPipilineError):
    """Raised when a database record required by the pipeline is missing."""

    def __init__(self, record_name: str, record_id: str):
        super().__init__(f"{record_name} with id {record_id} not found")


class WorkflowRunNotFoundError(RecordNotFoundError):
    """Raised when a WorkflowRun row cannot be found by id."""

    def __init__(self, workflow_run_id: str):
        super().__init__("WorkflowRun", workflow_run_id)


class WorkflowNodeExecutionNotFoundError(RecordNotFoundError):
    """Raised when a WorkflowNodeExecution row cannot be found by id."""

    def __init__(self, workflow_node_execution_id: str):
        super().__init__("WorkflowNodeExecution", workflow_node_execution_id)
AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + task_state: Union[EasyUITaskState, WorkflowTaskState], + ) -> None: + self._application_generate_entity = application_generate_entity + self._task_state = task_state + + def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: + """ + Generate conversation name. + :param conversation: conversation + :param query: query + :return: thread + """ + if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): + return None + + is_first_message = self._application_generate_entity.conversation_id is None + extras = self._application_generate_entity.extras + auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) + + if auto_generate_conversation_name and is_first_message: + # start generate thread + thread = Thread( + target=self._generate_conversation_name_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "conversation_id": conversation_id, + "query": query, + }, + ) + + thread.start() + + return thread + + return None + + def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): + with flask_app.app_context(): + # get conversation and message + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + + if not conversation: + return + + if conversation.mode != AppMode.COMPLETION.value: + app_model = conversation.app + if not app_model: + return + + # generate conversation name + try: + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query) + conversation.name = name + except Exception as e: + if dify_config.DEBUG: + logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}") + pass + + db.session.merge(conversation) + db.session.commit() + db.session.close() + + def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> 
    def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
        """
        Handle annotation reply.

        Looks up the annotation and records its author in the task-state metadata.
        :param event: event
        :return: the annotation when found, otherwise None
        """
        annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
        if annotation:
            account = annotation.account
            self._task_state.metadata["annotation_reply"] = {
                "id": annotation.id,
                # fall back to a generic name when the authoring account is gone
                "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
            }

            return annotation

        return None

    def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
        """
        Handle retriever resources.

        Stores citation sources in metadata, but only when the app is configured
        to show retrieval sources.
        :param event: event
        :return:
        """
        if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
            self._task_state.metadata["retriever_resources"] = event.retriever_resources

    def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
        """
        Message file to stream response.

        Resolves the file's public URL (signing local tool files) and wraps it.
        :param event: event
        :return: stream response, or None when the file row is missing or has no URL
        """
        message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()

        if message_file and message_file.url is not None:
            # get tool file id
            tool_file_id = message_file.url.split("/")[-1]
            # trim extension
            tool_file_id = tool_file_id.split(".")[0]

            # get extension
            if "." in message_file.url:
                extension = f".{message_file.url.split('.')[-1]}"
                # implausibly long "extension" means the URL had no real one
                if len(extension) > 10:
                    extension = ".bin"
            else:
                extension = ".bin"
            # add sign url to local file
            if message_file.url.startswith("http"):
                url = message_file.url
            else:
                url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)

            return MessageFileStreamResponse(
                task_id=self._application_generate_entity.task_id,
                id=message_file.id,
                type=message_file.type,
                belongs_to=message_file.belongs_to or "user",
                url=url,
            )

        return None
in message_file.url: + extension = f".{message_file.url.split('.')[-1]}" + if len(extension) > 10: + extension = ".bin" + else: + extension = ".bin" + # add sign url to local file + if message_file.url.startswith("http"): + url = message_file.url + else: + url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) + + return MessageFileStreamResponse( + task_id=self._application_generate_entity.task_id, + id=message_file.id, + type=message_file.type, + belongs_to=message_file.belongs_to or "user", + url=url, + ) + + return None + + def _message_to_stream_response( + self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + ) -> MessageStreamResponse: + """ + Message to stream response. + :param answer: answer + :param message_id: message id + :return: + """ + return MessageStreamResponse( + task_id=self._application_generate_entity.task_id, + id=message_id, + answer=answer, + from_variable_selector=from_variable_selector, + ) + + def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: + """ + Message replace to stream response. 
+ :param answer: answer + :return: + """ + return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc364d22766e626dde529ccd8f6c03fff5a4cfc --- /dev/null +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -0,0 +1,845 @@ +import json +import time +from collections.abc import Mapping, Sequence +from datetime import UTC, datetime +from typing import Any, Optional, Union, cast +from uuid import uuid4 + +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, +) +from core.app.entities.task_entities import ( + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, + NodeFinishStreamResponse, + NodeRetryStreamResponse, + NodeStartStreamResponse, + ParallelBranchFinishedStreamResponse, + ParallelBranchStartStreamResponse, + WorkflowFinishStreamResponse, + WorkflowStartStreamResponse, +) +from core.file import FILE_MODEL_IDENTITY, File +from core.model_runtime.utils.encoders import jsonable_encoder +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.tools.tool_manager import ToolManager +from core.workflow.entities.node_entities import NodeRunMetadataKey 
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowRunStatus,
)

from .exc import WorkflowRunNotFoundError


class WorkflowCycleManage:
    """Tracks a workflow run and its node executions through the task pipeline."""

    def __init__(
        self,
        *,
        application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
        workflow_system_variables: dict[SystemVariableKey, Any],
    ) -> None:
        # current workflow run; populated by _handle_workflow_run_start
        self._workflow_run: WorkflowRun | None = None
        # cache of node executions keyed by node_execution_id
        self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
        self._application_generate_entity = application_generate_entity
        self._workflow_system_variables = workflow_system_variables

    def _handle_workflow_run_start(
        self,
        *,
        session: Session,
        workflow_id: str,
        user_id: str,
        created_by_role: CreatedByRole,
    ) -> WorkflowRun:
        """
        Create and add a new RUNNING WorkflowRun row for the given workflow.

        :param session: SQLAlchemy session; caller commits
        :raises ValueError: if the workflow row does not exist
        :return: the new (uncommitted) WorkflowRun
        """
        workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
        workflow = session.scalar(workflow_stmt)
        if not workflow:
            raise ValueError(f"Workflow not found: {workflow_id}")

        # next per-app sequence number
        max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
            WorkflowRun.tenant_id == workflow.tenant_id,
            WorkflowRun.app_id == workflow.app_id,
        )
        max_sequence = session.scalar(max_sequence_stmt) or 0
        new_sequence_number = max_sequence + 1

        # merge user inputs with system variables (prefixed "sys."), skipping conversation
        inputs = {**self._application_generate_entity.inputs}
        for key, value in (self._workflow_system_variables or {}).items():
            if key.value == "conversation":
                continue
            inputs[f"sys.{key.value}"] = value

        triggered_from = (
            WorkflowRunTriggeredFrom.DEBUGGING
            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
            else WorkflowRunTriggeredFrom.APP_RUN
        )

        # handle special values
        inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})

        # init workflow run
        # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
        workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())

        workflow_run = WorkflowRun()
        workflow_run.id = workflow_run_id
        workflow_run.tenant_id = workflow.tenant_id
        workflow_run.app_id = workflow.app_id
        workflow_run.sequence_number = new_sequence_number
        workflow_run.workflow_id = workflow.id
        workflow_run.type = workflow.type
        workflow_run.triggered_from = triggered_from.value
        workflow_run.version = workflow.version
        workflow_run.graph = workflow.graph
        workflow_run.inputs = json.dumps(inputs)
        workflow_run.status = WorkflowRunStatus.RUNNING
        workflow_run.created_by_role = created_by_role
        workflow_run.created_by = user_id
        workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)

        session.add(workflow_run)

        return workflow_run

    def _handle_workflow_run_success(
        self,
        *,
        session: Session,
        workflow_run_id: str,
        start_at: float,
        total_tokens: int,
        total_steps: int,
        outputs: Mapping[str, Any] | None = None,
        conversation_id: Optional[str] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> WorkflowRun:
        """
        Workflow run success
        :param workflow_run_id: workflow run id
        :param start_at: start time
        :param total_tokens: total tokens
        :param total_steps: total steps
        :param outputs: outputs
        :param conversation_id: conversation id
        :return:
        """
        workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)

        outputs = WorkflowEntry.handle_special_values(outputs)

        workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
        workflow_run.outputs = json.dumps(outputs or {})
        workflow_run.elapsed_time = time.perf_counter() - start_at
        workflow_run.total_tokens = total_tokens
        workflow_run.total_steps = total_steps
        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_run=workflow_run,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                )
            )

        return workflow_run
    def _handle_workflow_run_partial_success(
        self,
        *,
        session: Session,
        workflow_run_id: str,
        start_at: float,
        total_tokens: int,
        total_steps: int,
        outputs: Mapping[str, Any] | None = None,
        exceptions_count: int = 0,
        conversation_id: Optional[str] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> WorkflowRun:
        """
        Mark the run PARTIAL_SUCCESSED: it finished, but some nodes raised.

        :param exceptions_count: number of node exceptions tolerated during the run
        :return: the updated WorkflowRun (caller commits)
        """
        workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
        outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)

        workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
        workflow_run.outputs = json.dumps(outputs or {})
        workflow_run.elapsed_time = time.perf_counter() - start_at
        workflow_run.total_tokens = total_tokens
        workflow_run.total_steps = total_steps
        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_run.exceptions_count = exceptions_count

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_run=workflow_run,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                )
            )

        return workflow_run

    def _handle_workflow_run_failed(
        self,
        *,
        session: Session,
        workflow_run_id: str,
        start_at: float,
        total_tokens: int,
        total_steps: int,
        status: WorkflowRunStatus,
        error: str,
        conversation_id: Optional[str] = None,
        trace_manager: Optional[TraceQueueManager] = None,
        exceptions_count: int = 0,
    ) -> WorkflowRun:
        """
        Workflow run failed
        :param workflow_run_id: workflow run id
        :param start_at: start time
        :param total_tokens: total tokens
        :param total_steps: total steps
        :param status: status
        :param error: error message
        :return:
        """
        workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)

        workflow_run.status = status.value
        workflow_run.error = error
        workflow_run.elapsed_time = time.perf_counter() - start_at
        workflow_run.total_tokens = total_tokens
        workflow_run.total_steps = total_steps
        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_run.exceptions_count = exceptions_count

        # fail any node executions still marked RUNNING for this run
        stmt = select(WorkflowNodeExecution.node_execution_id).where(
            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
            WorkflowNodeExecution.app_id == workflow_run.app_id,
            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
        )
        ids = session.scalars(stmt).all()
        # Use self._get_workflow_node_execution here to make sure the cache is updated
        running_workflow_node_executions = [
            self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
        ]

        for workflow_node_execution in running_workflow_node_executions:
            now = datetime.now(UTC).replace(tzinfo=None)
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error
            workflow_node_execution.finished_at = now
            workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_run=workflow_run,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                )
            )

        return workflow_run
total_steps: total steps + :param status: status + :param error: error message + :return: + """ + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + + workflow_run.status = status.value + workflow_run.error = error + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run.exceptions_count = exceptions_count + + stmt = select(WorkflowNodeExecution.node_execution_id).where( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + ) + ids = session.scalars(stmt).all() + # Use self._get_workflow_node_execution here to make sure the cache is updated + running_workflow_node_executions = [ + self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id + ] + + for workflow_node_execution in running_workflow_node_executions: + now = datetime.now(UTC).replace(tzinfo=None) + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.finished_at = now + workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() + + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_run=workflow_run, + conversation_id=conversation_id, + user_id=trace_manager.user_id, + ) + ) + + return workflow_run + + def _handle_node_execution_start( + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + ) -> 
WorkflowNodeExecution: + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(workflow_node_execution) + + self._workflow_node_executions[event.node_execution_id] = workflow_node_execution + return workflow_node_execution + + def _handle_workflow_node_execution_success( + self, *, session: Session, event: QueueNodeSucceededEvent + ) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) + inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) + outputs = WorkflowEntry.handle_special_values(event.outputs) + execution_metadata = ( + 
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) + finished_at = datetime.now(UTC).replace(tzinfo=None) + elapsed_time = (finished_at - event.start_at).total_seconds() + + process_data = WorkflowEntry.handle_special_values(event.process_data) + + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = execution_metadata + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time + + workflow_node_execution = session.merge(workflow_node_execution) + return workflow_node_execution + + def _handle_workflow_node_execution_failed( + self, + *, + session: Session, + event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent, + ) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param event: queue node failed event + :return: + """ + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) + outputs = WorkflowEntry.handle_special_values(event.outputs) + finished_at = datetime.now(UTC).replace(tzinfo=None) + elapsed_time = (finished_at - event.start_at).total_seconds() + execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) + process_data = WorkflowEntry.handle_special_values(event.process_data) + workflow_node_execution.status = ( + WorkflowNodeExecutionStatus.FAILED.value + if not isinstance(event, QueueNodeExceptionEvent) + else 
    def _handle_workflow_node_execution_retried(
        self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
    ) -> WorkflowNodeExecution:
        """
        Workflow node execution retried.

        (Docstring fixed: previously copy-pasted from the "failed" handler.)
        Creates a NEW WorkflowNodeExecution row with status RETRY recording one
        retry attempt, and caches it by node_execution_id.
        :param event: queue node retry event
        :return: the new (uncommitted) WorkflowNodeExecution
        """
        created_at = event.start_at
        finished_at = datetime.now(UTC).replace(tzinfo=None)
        elapsed_time = (finished_at - created_at).total_seconds()
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        outputs = WorkflowEntry.handle_special_values(event.outputs)
        origin_metadata = {
            NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
            NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
        }
        # origin_metadata keys override any same-named keys from the event metadata
        merged_metadata = (
            {**jsonable_encoder(event.execution_metadata), **origin_metadata}
            if event.execution_metadata is not None
            else origin_metadata
        )
        execution_metadata = json.dumps(merged_metadata)

        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.id = str(uuid4())
        workflow_node_execution.tenant_id = workflow_run.tenant_id
        workflow_node_execution.app_id = workflow_run.app_id
        workflow_node_execution.workflow_id = workflow_run.workflow_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
        workflow_node_execution.workflow_run_id = workflow_run.id
        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
        workflow_node_execution.node_execution_id = event.node_execution_id
        workflow_node_execution.node_id = event.node_id
        workflow_node_execution.node_type = event.node_type.value
        workflow_node_execution.title = event.node_data.title
        workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
        workflow_node_execution.created_by_role = workflow_run.created_by_role
        workflow_node_execution.created_by = workflow_run.created_by
        workflow_node_execution.created_at = created_at
        workflow_node_execution.finished_at = finished_at
        workflow_node_execution.elapsed_time = elapsed_time
        workflow_node_execution.error = event.error
        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
        workflow_node_execution.execution_metadata = execution_metadata
        workflow_node_execution.index = event.node_run_index

        session.add(workflow_node_execution)

        self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
        return workflow_node_execution
    #################################################
    #             to stream responses               #
    #################################################

    def _workflow_start_to_stream_response(
        self,
        *,
        session: Session,
        task_id: str,
        workflow_run: WorkflowRun,
    ) -> WorkflowStartStreamResponse:
        """Convert a just-started WorkflowRun into its stream response."""
        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
        _ = session
        return WorkflowStartStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_run.id,
            data=WorkflowStartStreamResponse.Data(
                id=workflow_run.id,
                workflow_id=workflow_run.workflow_id,
                sequence_number=workflow_run.sequence_number,
                inputs=dict(workflow_run.inputs_dict or {}),
                created_at=int(workflow_run.created_at.timestamp()),
            ),
        )

    def _workflow_finish_to_stream_response(
        self,
        *,
        session: Session,
        task_id: str,
        workflow_run: WorkflowRun,
    ) -> WorkflowFinishStreamResponse:
        """
        Convert a finished WorkflowRun into its stream response.

        Resolves the creator (account or end user) for attribution.
        :raises NotImplementedError: for an unknown created_by_role value
        """
        created_by = None
        if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
            stmt = select(Account).where(Account.id == workflow_run.created_by)
            account = session.scalar(stmt)
            if account:
                created_by = {
                    "id": account.id,
                    "name": account.name,
                    "email": account.email,
                }
        elif workflow_run.created_by_role == CreatedByRole.END_USER:
            stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
            end_user = session.scalar(stmt)
            if end_user:
                created_by = {
                    "id": end_user.id,
                    "user": end_user.session_id,
                }
        else:
            raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")

        return WorkflowFinishStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_run.id,
            data=WorkflowFinishStreamResponse.Data(
                id=workflow_run.id,
                workflow_id=workflow_run.workflow_id,
                sequence_number=workflow_run.sequence_number,
                status=workflow_run.status,
                outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
                error=workflow_run.error,
                elapsed_time=workflow_run.elapsed_time,
                total_tokens=workflow_run.total_tokens,
                total_steps=workflow_run.total_steps,
                created_by=created_by,
                created_at=int(workflow_run.created_at.timestamp()),
                finished_at=int(workflow_run.finished_at.timestamp()),
                files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
                exceptions_count=workflow_run.exceptions_count,
            ),
        )
-> WorkflowFinishStreamResponse: + created_by = None + if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + stmt = select(Account).where(Account.id == workflow_run.created_by) + account = session.scalar(stmt) + if account: + created_by = { + "id": account.id, + "name": account.name, + "email": account.email, + } + elif workflow_run.created_by_role == CreatedByRole.END_USER: + stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) + end_user = session.scalar(stmt) + if end_user: + created_by = { + "id": end_user.id, + "user": end_user.session_id, + } + else: + raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") + + return WorkflowFinishStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=WorkflowFinishStreamResponse.Data( + id=workflow_run.id, + workflow_id=workflow_run.workflow_id, + sequence_number=workflow_run.sequence_number, + status=workflow_run.status, + outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, + error=workflow_run.error, + elapsed_time=workflow_run.elapsed_time, + total_tokens=workflow_run.total_tokens, + total_steps=workflow_run.total_steps, + created_by=created_by, + created_at=int(workflow_run.created_at.timestamp()), + finished_at=int(workflow_run.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), + exceptions_count=workflow_run.exceptions_count, + ), + ) + + def _workflow_node_start_to_stream_response( + self, + *, + session: Session, + event: QueueNodeStartedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeStartStreamResponse]: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + if not workflow_node_execution.workflow_run_id: + return None + + response = 
NodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeStartStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + title=workflow_node_execution.title, + index=workflow_node_execution.index, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + created_at=int(workflow_node_execution.created_at.timestamp()), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + parallel_run_id=event.parallel_mode_run_id, + ), + ) + + # extras logic + if event.node_type == NodeType.TOOL: + node_data = cast(ToolNodeData, event.node_data) + response.data.extras["icon"] = ToolManager.get_tool_icon( + tenant_id=self._application_generate_entity.app_config.tenant_id, + provider_type=node_data.provider_type, + provider_id=node_data.provider_id, + ) + + return response + + def _workflow_node_finish_to_stream_response( + self, + *, + session: Session, + event: QueueNodeSucceededEvent + | QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeExceptionEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeFinishStreamResponse]: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None + + return NodeFinishStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeFinishStreamResponse.Data( + 
id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + title=workflow_node_execution.title, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + process_data=workflow_node_execution.process_data_dict, + outputs=workflow_node_execution.outputs_dict, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.execution_metadata_dict, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), + ) + + def _workflow_node_retry_to_stream_response( + self, + *, + session: Session, + event: QueueNodeRetryEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None + + return NodeRetryStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeRetryStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + 
index=workflow_node_execution.index, + title=workflow_node_execution.title, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + process_data=workflow_node_execution.process_data_dict, + outputs=workflow_node_execution.outputs_dict, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.execution_metadata_dict, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + retry_index=event.retry_index, + ), + ) + + def _workflow_parallel_branch_start_to_stream_response( + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + return ParallelBranchStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchStartStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + created_at=int(time.time()), + ), + ) + + def _workflow_parallel_branch_finished_to_stream_response( + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, + ) -> 
ParallelBranchFinishedStreamResponse: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + return ParallelBranchFinishedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchFinishedStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", + error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, + created_at=int(time.time()), + ), + ) + + def _workflow_iteration_start_to_stream_response( + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + metadata=event.metadata or {}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), + ) + + def _workflow_iteration_next_to_stream_response( + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + ) -> IterationNodeNextStreamResponse: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + 
data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parallel_mode_run_id=event.parallel_mode_run_id, + duration=event.duration, + ), + ) + + def _workflow_iteration_completed_to_stream_response( + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + ) -> IterationNodeCompletedStreamResponse: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + outputs=event.outputs, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + status=WorkflowNodeExecutionStatus.SUCCEEDED + if event.error is None + else WorkflowNodeExecutionStatus.FAILED, + error=None, + elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), + total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, + execution_metadata=event.metadata, + finished_at=int(time.time()), + steps=event.steps, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), + ) + + def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: + """ + Fetch files from node outputs + :param outputs_dict: node outputs dict + :return: + """ + if not outputs_dict: + return [] + + files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + # Remove None + files = [file for file in 
files if file] + # Flatten list + # Flatten the list of sequences into a single list of mappings + flattened_files = [file for sublist in files if sublist for file in sublist] + + # Convert to tuple to match Sequence type + return tuple(flattened_files) + + def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: + """ + Fetch files from variable value + :param value: variable value + :return: + """ + if not value: + return [] + + files = [] + if isinstance(value, list): + for item in value: + file = self._get_file_var_from_value(item) + if file: + files.append(file) + elif isinstance(value, dict): + file = self._get_file_var_from_value(value) + if file: + files.append(file) + + return files + + def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None: + """ + Get file var from value + :param value: variable value + :return: + """ + if not value: + return None + + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + return value + elif isinstance(value, File): + return value.to_dict() + + return None + + def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: + if self._workflow_run and self._workflow_run.id == workflow_run_id: + cached_workflow_run = self._workflow_run + cached_workflow_run = session.merge(cached_workflow_run) + return cached_workflow_run + stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) + if not workflow_run: + raise WorkflowRunNotFoundError(workflow_run_id) + self._workflow_run = workflow_run + + return workflow_run + + def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: + if node_execution_id not in self._workflow_node_executions: + raise ValueError(f"Workflow node execution not found: {node_execution_id}") + cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] + 
return cached_workflow_node_execution diff --git a/api/core/callback_handler/__init__.py b/api/core/callback_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..effc7eff9179ae21c9631859699e2a461d52fe5c --- /dev/null +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -0,0 +1,116 @@ +from collections.abc import Mapping, Sequence +from typing import Any, Optional, TextIO, Union + +from pydantic import BaseModel + +from configs import dify_config +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.tools.entities.tool_entities import ToolInvokeMessage + +_TEXT_COLOR_MAPPING = { + "blue": "36;1", + "yellow": "33;1", + "pink": "38;5;200", + "green": "32;1", + "red": "31;1", +} + + +def get_colored_text(text: str, color: str) -> str: + """Get colored text.""" + color_str = _TEXT_COLOR_MAPPING[color] + return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" + + +def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: + """Print text with highlighting and no end characters.""" + text_to_print = get_colored_text(text, color) if color else text + print(text_to_print, end=end, file=file) + if file: + file.flush() # ensure all printed content are written to file + + +class DifyAgentCallbackHandler(BaseModel): + """Callback Handler that prints to std out.""" + + color: Optional[str] = "" + current_loop: int = 1 + + def __init__(self, color: Optional[str] = None) -> None: + super().__init__() + """Initialize callback handler.""" + # use a specific color is not specified + self.color = color or "green" + self.current_loop = 1 + + def on_tool_start( + self, + 
        tool_name: str,
        tool_inputs: Mapping[str, Any],
    ) -> None:
        """Log the tool invocation (name and inputs) to stdout when DEBUG is enabled; otherwise a no-op."""
        if dify_config.DEBUG:
            print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)

    def on_tool_end(
        self,
        tool_name: str,
        tool_inputs: Mapping[str, Any],
        tool_outputs: Sequence[ToolInvokeMessage] | str,
        message_id: Optional[str] = None,
        timer: Optional[Any] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> None:
        """Log the tool result when DEBUG is enabled and, if a trace manager is given, enqueue a tool trace task.

        :param tool_name: name of the tool that finished
        :param tool_inputs: arguments the tool was called with
        :param tool_outputs: messages (or raw string) the tool produced
        :param message_id: message the invocation belongs to, forwarded to tracing
        :param timer: timing info for the invocation, forwarded to tracing
        :param trace_manager: optional queue manager for ops tracing
        """
        if dify_config.DEBUG:
            print_text("\n[on_tool_end]\n", color=self.color)
            print_text("Tool: " + tool_name + "\n", color=self.color)
            print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color)
            # Outputs are truncated to 1000 chars to keep debug logs readable.
            print_text("Outputs: " + str(tool_outputs)[:1000] + "\n", color=self.color)
            print_text("\n")

        # Tracing happens regardless of DEBUG — only the stdout logging is gated.
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.TOOL_TRACE,
                    message_id=message_id,
                    tool_name=tool_name,
                    tool_inputs=tool_inputs,
                    tool_outputs=tool_outputs,
                    timer=timer,
                )
            )

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        """Log the tool error to stdout in red when DEBUG is enabled; otherwise a no-op."""
        if dify_config.DEBUG:
            print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")

    def on_agent_start(self, thought: str) -> None:
        """Log the start of an agent loop iteration (and its thought, if any) when DEBUG is enabled."""
        if dify_config.DEBUG:
            if thought:
                print_text(
                    "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
                    color=self.color,
                )
            else:
                print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)

    def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
        """Log the end of an agent loop iteration and advance the loop counter.

        Note: the loop counter increments even when DEBUG is off — only logging is gated.
        """
        if dify_config.DEBUG:
            print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)

        self.current_loop += 1

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks (True unless DEBUG logging is enabled)."""
        return not dify_config.DEBUG

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks (True unless DEBUG logging is enabled)."""
        return not dify_config.DEBUG
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f8aaa93d6f98628cec65edfbcd754fca98f5923
--- /dev/null
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -0,0 +1,83 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
from models.model import DatasetRetrieverResource


class DatasetIndexToolCallbackHandler:
    """Callback handler for dataset tool."""

    def __init__(
        self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
    ) -> None:
        # Context identifying which app / message / user triggered the retrieval;
        # used when persisting queries and citations below.
        self._queue_manager = queue_manager
        self._app_id = app_id
        self._message_id = message_id
        self._user_id = user_id
        self._invoke_from = invoke_from

    def on_query(self, query: str, dataset_id: str) -> None:
        """
        Handle query.
+ """ + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=query, + source="app", + source_app_id=self._app_id, + created_by_role=( + "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + ), + created_by=self._user_id, + ) + + db.session.add(dataset_query) + db.session.commit() + + def on_tool_end(self, documents: list[Document]) -> None: + """Handle tool end.""" + for document in documents: + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) + + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + + db.session.commit() + + def return_retriever_resource_info(self, resource: list): + """Handle return_retriever_resource_info.""" + if resource and len(resource) > 0: + for item in resource: + dataset_retriever_resource = DatasetRetrieverResource( + message_id=self._message_id, + position=item.get("position") or 0, + dataset_id=item.get("dataset_id"), + dataset_name=item.get("dataset_name"), + document_id=item.get("document_id"), + document_name=item.get("document_name"), + data_source_type=item.get("data_source_type"), + segment_id=item.get("segment_id"), + score=item.get("score") if "score" in item else None, + hit_count=item.get("hit_count") if "hit_count" in item else None, + word_count=item.get("word_count") if "word_count" in item else None, + segment_position=item.get("segment_position") if "segment_position" in item else None, + index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, + content=item.get("content"), + retriever_from=item.get("retriever_from"), + created_by=self._user_id, + ) + db.session.add(dataset_retriever_resource) + db.session.commit() + + 
self._queue_manager.publish( + QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac12f72f29d6c0d91fcd5c0b069280e4ec0209b --- /dev/null +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -0,0 +1,5 @@ +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler + + +class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): + """Callback Handler that prints to std out.""" diff --git a/api/core/entities/__init__.py b/api/core/entities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..656bf4aa724893c072668adb4ae6f3d9dda18a3f --- /dev/null +++ b/api/core/entities/agent_entities.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class PlanningStrategy(Enum): + ROUTER = "router" + REACT_ROUTER = "react_router" + REACT = "react" + FUNCTION_CALL = "function_call" diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4934646bc0e8b011b126d688eba5e442049645 --- /dev/null +++ b/api/core/entities/embedding_type.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class EmbeddingInputType(Enum): + """ + Enum for embedding input type. 
+ """ + + DOCUMENT = "document" + QUERY = "query" diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..90c98797338270c9701c6583166c1f9fe753e6bb --- /dev/null +++ b/api/core/entities/knowledge_entities.py @@ -0,0 +1,19 @@ +from typing import Optional + +from pydantic import BaseModel + + +class PreviewDetail(BaseModel): + content: str + child_chunks: Optional[list[str]] = None + + +class QAPreviewDetail(BaseModel): + question: str + answer: str + + +class IndexingEstimate(BaseModel): + total_segments: int + preview: list[PreviewDetail] + qa_preview: Optional[list[QAPreviewDetail]] = None diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..5017835565789cb1216f26a65ee1682060f3e225 --- /dev/null +++ b/api/core/entities/model_entities.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ModelType, ProviderModel +from core.model_runtime.entities.provider_entities import ProviderEntity + + +class ModelStatus(Enum): + """ + Enum class for model status. + """ + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" + QUOTA_EXCEEDED = "quota-exceeded" + NO_PERMISSION = "no-permission" + DISABLED = "disabled" + + +class SimpleModelProviderEntity(BaseModel): + """ + Simple provider. + """ + + provider: str + label: I18nObject + icon_small: Optional[I18nObject] = None + icon_large: Optional[I18nObject] = None + supported_model_types: list[ModelType] + + def __init__(self, provider_entity: ProviderEntity) -> None: + """ + Init simple provider. 
+ + :param provider_entity: provider entity + """ + super().__init__( + provider=provider_entity.provider, + label=provider_entity.label, + icon_small=provider_entity.icon_small, + icon_large=provider_entity.icon_large, + supported_model_types=provider_entity.supported_model_types, + ) + + +class ProviderModelWithStatusEntity(ProviderModel): + """ + Model class for model response. + """ + + status: ModelStatus + load_balancing_enabled: bool = False + + +class ModelWithProviderEntity(ProviderModelWithStatusEntity): + """ + Model with provider entity. + """ + + provider: SimpleModelProviderEntity + + +class DefaultModelProviderEntity(BaseModel): + """ + Default model provider entity. + """ + + provider: str + label: I18nObject + icon_small: Optional[I18nObject] = None + icon_large: Optional[I18nObject] = None + supported_model_types: Sequence[ModelType] = [] + + +class DefaultModelEntity(BaseModel): + """ + Default model entity. + """ + + model: str + model_type: ModelType + provider: DefaultModelProviderEntity + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..bff5a0ec9c6be7639a9a882d677e28f6048c9493 --- /dev/null +++ b/api/core/entities/provider_configuration.py @@ -0,0 +1,1068 @@ +import datetime +import json +import logging +from collections import defaultdict +from collections.abc import Iterator +from json import JSONDecodeError +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from constants import HIDDEN_VALUE +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity +from core.entities.provider_entities import ( + CustomConfiguration, + ModelSettings, + SystemConfiguration, + SystemConfigurationStatus, +) +from core.helper import encrypter +from core.helper.model_provider_cache import 
ProviderCredentialsCache, ProviderCredentialsCacheType +from core.model_runtime.entities.model_entities import FetchFrom, ModelType +from core.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from core.model_runtime.model_providers import model_provider_factory +from core.model_runtime.model_providers.__base.ai_model import AIModel +from core.model_runtime.model_providers.__base.model_provider import ModelProvider +from extensions.ext_database import db +from models.provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderType, + TenantPreferredModelProvider, +) + +logger = logging.getLogger(__name__) + +original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} + + +class ProviderConfiguration(BaseModel): + """ + Model class for provider configuration. + """ + + tenant_id: str + provider: ProviderEntity + preferred_provider_type: ProviderType + using_provider_type: ProviderType + system_configuration: SystemConfiguration + custom_configuration: CustomConfiguration + model_settings: list[ModelSettings] + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + def __init__(self, **data): + super().__init__(**data) + + if self.provider.provider not in original_provider_configurate_methods: + original_provider_configurate_methods[self.provider.provider] = [] + for configurate_method in self.provider.configurate_methods: + original_provider_configurate_methods[self.provider.provider].append(configurate_method) + + if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + if ( + any( + len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations + ) + and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods + ): + 
                self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        :param model_type: model type
        :param model: model name
        :return: credentials dict, or None when the system provider has no credentials configured
        :raises ValueError: if an admin has disabled this model via model settings
        """
        if self.model_settings:
            # check if model is disabled by admin
            for model_setting in self.model_settings:
                if model_setting.model_type == model_type and model_setting.model == model:
                    if not model_setting.enabled:
                        raise ValueError(f"Model {model} is disabled.")

        if self.using_provider_type == ProviderType.SYSTEM:
            # System (hosted) provider: only restrictions from the quota configuration
            # matching the current quota type apply.
            restrict_models = []
            for quota_configuration in self.system_configuration.quota_configurations:
                if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                    continue

                restrict_models = quota_configuration.restrict_models
            if self.system_configuration.credentials is None:
                return None
            # Copy so the shared system credentials are never mutated in place.
            copy_credentials = self.system_configuration.credentials.copy()
            if restrict_models:
                for restrict_model in restrict_models:
                    if (
                        restrict_model.model_type == model_type
                        and restrict_model.model == model
                        and restrict_model.base_model_name
                    ):
                        copy_credentials["base_model_name"] = restrict_model.base_model_name

            return copy_credentials
        else:
            # Custom provider: model-specific credentials win; fall back to provider-level ones.
            credentials = None
            if self.custom_configuration.models:
                for model_configuration in self.custom_configuration.models:
                    if model_configuration.model_type == model_type and model_configuration.model == model:
                        credentials = model_configuration.credentials
                        break

            if not credentials and self.custom_configuration.provider:
                credentials = self.custom_configuration.provider.credentials

            return credentials

    def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
        """
        Get system configuration status.
+ :return: + """ + if self.system_configuration.enabled is False: + return SystemConfigurationStatus.UNSUPPORTED + + current_quota_type = self.system_configuration.current_quota_type + current_quota_configuration = next( + (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None + ) + if current_quota_configuration is None: + return None + + return ( + SystemConfigurationStatus.ACTIVE + if current_quota_configuration.is_valid + else SystemConfigurationStatus.QUOTA_EXCEEDED + ) + + def is_custom_configuration_available(self) -> bool: + """ + Check custom configuration available. + :return: + """ + return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 + + def get_custom_credentials(self, obfuscated: bool = False): + """ + Get custom credentials. + + :param obfuscated: obfuscated secret data in credentials + :return: + """ + if self.custom_configuration.provider is None: + return None + + credentials = self.custom_configuration.provider.credentials + if not obfuscated: + return credentials + + # Obfuscate credentials + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [], + ) + + def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: + """ + Validate custom credentials. 
    def add_or_update_custom_credentials(self, credentials: dict) -> None:
        """
        Add or update custom provider credentials, then switch the preferred
        provider type to CUSTOM.

        :param credentials: provider credentials to validate and persist
        :return:
        """
        # validate custom provider config
        provider_record, credentials = self.custom_credentials_validate(credentials)

        # save provider
        # Note: Do not switch the preferred provider, which allows users to use quotas first
        if provider_record:
            provider_record.encrypted_config = json.dumps(credentials)
            provider_record.is_valid = True
            # stored naive in UTC to match the column type
            provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            provider_record = Provider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_record)
            db.session.commit()

        # invalidate the cached credentials for this provider
        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
        )

        provider_model_credentials_cache.delete()

        self.switch_preferred_provider_type(ProviderType.CUSTOM)
    def delete_custom_credentials(self) -> None:
        """
        Delete custom provider credentials and fall back to the SYSTEM provider type.

        :return:
        """
        # get provider
        provider_record = (
            db.session.query(Provider)
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value,
            )
            .first()
        )

        # delete provider
        if provider_record:
            # switch back to SYSTEM before removing the custom record
            self.switch_preferred_provider_type(ProviderType.SYSTEM)

            db.session.delete(provider_record)
            db.session.commit()

            # invalidate the cached credentials for this provider
            provider_model_credentials_cache = ProviderCredentialsCache(
                tenant_id=self.tenant_id,
                identity_id=provider_record.id,
                cache_type=ProviderCredentialsCacheType.PROVIDER,
            )

            provider_model_credentials_cache.delete()

    def get_custom_model_credentials(
        self, model_type: ModelType, model: str, obfuscated: bool = False
    ) -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param obfuscated: obfuscate secret values in the returned credentials
        :return: credentials dict, or None if the model has no custom configuration
        """
        if not self.custom_configuration.models:
            return None

        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                if not obfuscated:
                    return credentials

                # Obfuscate credentials
                return self.obfuscated_credentials(
                    credentials=credentials,
                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                    if self.provider.model_credential_schema
                    else [],
                )

        return None
    def custom_model_credentials_validate(
        self, model_type: ModelType, model: str, credentials: dict
    ) -> tuple[Optional[ProviderModel], dict]:
        """
        Validate custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials (secret values may be the
            HIDDEN_VALUE placeholder, meaning "keep the stored value")
        :return: (existing provider model record or None, validated credentials
            with secret values encrypted)
        """
        # get provider model
        provider_model_record = (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.model_credential_schema.credential_form_schemas
            if self.provider.model_credential_schema
            else []
        )

        if provider_model_record:
            try:
                original_credentials = (
                    json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
                )
            except JSONDecodeError:
                original_credentials = {}

            # decrypt credentials
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
                    if value == HIDDEN_VALUE and key in original_credentials:
                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

        credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )

        # encrypt secret values before returning them for persistence
        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

        return provider_model_record, credentials
    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
        """
        Add or update custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials to validate and persist
        :return:
        """
        # validate custom model config
        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

        # save provider model
        # Note: Do not switch the preferred provider, which allows users to use quotas first
        if provider_model_record:
            provider_model_record.encrypted_config = json.dumps(credentials)
            provider_model_record.is_valid = True
            # stored naive in UTC to match the column type
            provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_model_record)
            db.session.commit()

        # invalidate the cached credentials for this model
        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record.id,
            cache_type=ProviderCredentialsCacheType.MODEL,
        )

        provider_model_credentials_cache.delete()

    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
        """
        Delete custom model credentials.

        :param model_type: model type
        :param model: model name
        :return:
        """
        # get provider model
        provider_model_record = (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

        # delete provider model
        if provider_model_record:
            db.session.delete(provider_model_record)
            db.session.commit()

            # invalidate the cached credentials for this model
            provider_model_credentials_cache = ProviderCredentialsCache(
                tenant_id=self.tenant_id,
                identity_id=provider_model_record.id,
                cache_type=ProviderCredentialsCacheType.MODEL,
            )

            provider_model_credentials_cache.delete()
+ :param model_type: model type + :param model: model name + :return: + """ + # get provider model + provider_model_record = ( + db.session.query(ProviderModel) + .filter( + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) + + # delete provider model + if provider_model_record: + db.session.delete(provider_model_record) + db.session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + + provider_model_credentials_cache.delete() + + def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Enable model. + :param model_type: model type + :param model: model name + :return: + """ + model_setting = ( + db.session.query(ProviderModelSetting) + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) + + if model_setting: + model_setting.enabled = True + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=True, + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + + def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Disable model. 
    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
        """
        Get provider model setting.

        :param model_type: model type
        :param model: model name
        :return: the model setting row, or None if none exists
        """
        return (
            db.session.query(ProviderModelSetting)
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model,
            )
            .first()
        )

    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the persisted model setting
        :raises ValueError: if fewer than two load balancing configs exist
        """
        # load balancing only makes sense with at least two configs
        load_balancing_config_count = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == self.tenant_id,
                LoadBalancingModelConfig.provider_name == self.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .count()
        )

        if load_balancing_config_count <= 1:
            raise ValueError("Model load balancing configuration must be more than 1.")

        model_setting = (
            db.session.query(ProviderModelSetting)
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model,
            )
            .first()
        )

        if model_setting:
            model_setting.load_balancing_enabled = True
            # stored naive in UTC to match the column type
            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                load_balancing_enabled=True,
            )
            db.session.add(model_setting)
            db.session.commit()

        return model_setting
    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the persisted model setting
        """
        model_setting = (
            db.session.query(ProviderModelSetting)
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model,
            )
            .first()
        )

        if model_setting:
            model_setting.load_balancing_enabled = False
            # stored naive in UTC to match the column type
            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                load_balancing_enabled=False,
            )
            db.session.add(model_setting)
            db.session.commit()

        return model_setting
    def get_provider_instance(self) -> ModelProvider:
        """
        Get provider instance.

        :return: the provider implementation for this configuration's provider
        """
        return model_provider_factory.get_provider_instance(self.provider.provider)

    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
        """
        Get current model type instance.

        :param model_type: model type
        :return: the model-type implementation (e.g. LLM) for this provider
        """
        # Get provider instance
        provider_instance = self.get_provider_instance()

        # Get model instance of LLM
        return provider_instance.get_model_instance(model_type)

    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
        """
        Switch preferred provider type.

        No-op when already preferred, or when switching to SYSTEM while the
        system configuration is disabled.

        :param provider_type: target preferred provider type
        :return:
        """
        if provider_type == self.preferred_provider_type:
            return

        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
            return

        # get preferred provider
        preferred_model_provider = (
            db.session.query(TenantPreferredModelProvider)
            .filter(
                TenantPreferredModelProvider.tenant_id == self.tenant_id,
                TenantPreferredModelProvider.provider_name == self.provider.provider,
            )
            .first()
        )

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = provider_type.value
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value,
            )
            db.session.add(preferred_model_provider)

        db.session.commit()
+ :param provider_type: + :return: + """ + if provider_type == self.preferred_provider_type: + return + + if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: + return + + # get preferred provider + preferred_model_provider = ( + db.session.query(TenantPreferredModelProvider) + .filter( + TenantPreferredModelProvider.tenant_id == self.tenant_id, + TenantPreferredModelProvider.provider_name == self.provider.provider, + ) + .first() + ) + + if preferred_model_provider: + preferred_model_provider.preferred_provider_type = provider_type.value + else: + preferred_model_provider = TenantPreferredModelProvider( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + preferred_provider_type=provider_type.value, + ) + db.session.add(preferred_model_provider) + + db.session.commit() + + def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + """ + Extract secret input form variables. + + :param credential_form_schemas: + :return: + """ + secret_input_form_variables = [] + for credential_form_schema in credential_form_schemas: + if credential_form_schema.type == FormType.SECRET_INPUT: + secret_input_form_variables.append(credential_form_schema.variable) + + return secret_input_form_variables + + def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: + """ + Obfuscated credentials. 
    def get_provider_model(
        self, model_type: ModelType, model: str, only_active: bool = False
    ) -> Optional[ModelWithProviderEntity]:
        """
        Get a single provider model by name.

        :param model_type: model type
        :param model: model name
        :param only_active: return active model only
        :return: the matching model, or None
        """
        provider_models = self.get_provider_models(model_type, only_active)

        for provider_model in provider_models:
            if provider_model.model == model:
                return provider_model

        return None

    def get_provider_models(
        self, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get provider models.

        :param model_type: model type; all supported types when omitted
        :param only_active: only active models
        :return: models sorted by model type
        """
        provider_instance = self.get_provider_instance()

        model_types = []
        if model_type:
            model_types.append(model_type)
        else:
            model_types = list(provider_instance.get_provider_schema().supported_model_types)

        # Group model settings by model type and model
        model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
        for model_setting in self.model_settings:
            model_setting_map[model_setting.model_type][model_setting.model] = model_setting

        # dispatch on the provider type actually in use
        if self.using_provider_type == ProviderType.SYSTEM:
            provider_models = self._get_system_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )
        else:
            provider_models = self._get_custom_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )

        if only_active:
            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

        # resort provider_models
        return sorted(provider_models, key=lambda x: x.model_type.value)
    def _get_system_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get system provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: models with statuses reflecting admin disablement and quota restrictions
        """
        provider_models = []
        for model_type in model_types:
            for m in provider_instance.models(model_type):
                status = ModelStatus.ACTIVE
                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                    model_setting = model_setting_map[m.model_type][m.model]
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                    )
                )

        # remember the provider's original configurate methods (module-level cache)
        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = []
            for configurate_method in provider_instance.get_provider_schema().configurate_methods:
                original_provider_configurate_methods[self.provider.provider].append(configurate_method)

        should_use_custom_model = False
        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
            should_use_custom_model = True

        for quota_configuration in self.system_configuration.quota_configurations:
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_models = quota_configuration.restrict_models
            if len(restrict_models) == 0:
                break

            if should_use_custom_model:
                if original_provider_configurate_methods[self.provider.provider] == [
                    ConfigurateMethod.CUSTOMIZABLE_MODEL
                ]:
                    # only customizable model: materialize each restricted model
                    # via its customizable schema using system credentials
                    for restrict_model in restrict_models:
                        if self.system_configuration.credentials is not None:
                            copy_credentials = self.system_configuration.credentials.copy()
                            if restrict_model.base_model_name:
                                copy_credentials["base_model_name"] = restrict_model.base_model_name

                            try:
                                custom_model_schema = provider_instance.get_model_instance(
                                    restrict_model.model_type
                                ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
                            except Exception as ex:
                                logger.warning(f"get custom model schema failed, {ex}")
                                continue

                            if not custom_model_schema:
                                continue

                            if custom_model_schema.model_type not in model_types:
                                continue

                            status = ModelStatus.ACTIVE
                            if (
                                custom_model_schema.model_type in model_setting_map
                                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
                            ):
                                model_setting = model_setting_map[custom_model_schema.model_type][
                                    custom_model_schema.model
                                ]
                                if model_setting.enabled is False:
                                    status = ModelStatus.DISABLED

                            provider_models.append(
                                ModelWithProviderEntity(
                                    model=custom_model_schema.model,
                                    label=custom_model_schema.label,
                                    model_type=custom_model_schema.model_type,
                                    features=custom_model_schema.features,
                                    fetch_from=FetchFrom.PREDEFINED_MODEL,
                                    model_properties=custom_model_schema.model_properties,
                                    deprecated=custom_model_schema.deprecated,
                                    provider=SimpleModelProviderEntity(self.provider),
                                    status=status,
                                )
                            )

            # if llm name not in restricted llm list, remove it
            restrict_model_names = [rm.model for rm in restrict_models]
            for model in provider_models:
                # NOTE(review): the elif means QUOTA_EXCEEDED is only applied to
                # models that are non-LLM or already in the restrict list —
                # confirm this branch ordering is intentional
                if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
                    model.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    model.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models
    def _get_custom_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get custom provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: predefined models plus manually-added custom models
        """
        provider_models = []

        credentials = None
        if self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        for model_type in model_types:
            if model_type not in self.provider.supported_model_types:
                continue

            models = provider_instance.models(model_type)
            for m in models:
                # without provider-level credentials, predefined models are unconfigured
                status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
                load_balancing_enabled = False
                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                    model_setting = model_setting_map[m.model_type][m.model]
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                    if len(model_setting.load_balancing_configs) > 1:
                        load_balancing_enabled = True

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                        load_balancing_enabled=load_balancing_enabled,
                    )
                )

        # custom models
        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type not in model_types:
                continue

            try:
                custom_model_schema = provider_instance.get_model_instance(
                    model_configuration.model_type
                ).get_customizable_model_schema_from_credentials(
                    model_configuration.model, model_configuration.credentials
                )
            except Exception as ex:
                # best-effort: skip models whose schema cannot be resolved
                logger.warning(f"get custom model schema failed, {ex}")
                continue

            if not custom_model_schema:
                continue

            status = ModelStatus.ACTIVE
            load_balancing_enabled = False
            if (
                custom_model_schema.model_type in model_setting_map
                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
            ):
                model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                if model_setting.enabled is False:
                    status = ModelStatus.DISABLED

                if len(model_setting.load_balancing_configs) > 1:
                    load_balancing_enabled = True

            provider_models.append(
                ModelWithProviderEntity(
                    model=custom_model_schema.model,
                    label=custom_model_schema.label,
                    model_type=custom_model_schema.model_type,
                    features=custom_model_schema.features,
                    fetch_from=custom_model_schema.fetch_from,
                    model_properties=custom_model_schema.model_properties,
                    deprecated=custom_model_schema.deprecated,
                    provider=SimpleModelProviderEntity(self.provider),
                    status=status,
                    load_balancing_enabled=load_balancing_enabled,
                )
            )

        return provider_models
class ProviderConfigurations(BaseModel):
    """
    Model class for provider configuration dict.

    Behaves like a read-mostly mapping of provider name -> ProviderConfiguration.
    """

    tenant_id: str
    configurations: dict[str, ProviderConfiguration] = {}

    def __init__(self, tenant_id: str):
        super().__init__(tenant_id=tenant_id)

    def get_models(
        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get available models.

        If preferred provider type is `system`:
          Get the current **system mode** if provider supported,
          if all system modes are not available (no quota), it is considered
          to be the **custom credential mode**.
          If there is no model configured in custom mode, it is treated as no_configure.
          system > custom > no_configure

        If preferred provider type is `custom`:
          If custom credentials are configured, it is treated as custom mode.
          Otherwise, get the current **system mode** if supported,
          If all system modes are not available (no quota), it is treated as no_configure.
          custom > system > no_configure

        If real mode is `system`, use system credentials to get models,
          paid quotas > provider free quotas > system free quotas
          include pre-defined models (exclude GPT-4, status marked as `no_permission`).
        If real mode is `custom`, use workspace custom credentials to get models,
          include pre-defined models, custom models(manual append).
        If real mode is `no_configure`, only return pre-defined models from `model runtime`.
          (model status marked as `no_configure` if preferred provider type is
          `custom` otherwise `quota_exceeded`)
        model status marked as `active` is available.

        :param provider: provider name
        :param model_type: model type
        :param only_active: only active models
        :return:
        """
        all_models = []
        for provider_configuration in self.values():
            if provider and provider_configuration.provider.provider != provider:
                continue

            all_models.extend(provider_configuration.get_provider_models(model_type, only_active))

        return all_models

    def to_list(self) -> list[ProviderConfiguration]:
        """
        Convert to list.

        :return: all configurations as a list
        """
        return list(self.values())

    # dict-like protocol over the inner configurations mapping
    def __getitem__(self, key):
        return self.configurations[key]

    def __setitem__(self, key, value):
        self.configurations[key] = value

    def __iter__(self):
        # iterates keys, mirroring dict semantics
        return iter(self.configurations)

    def values(self) -> Iterator[ProviderConfiguration]:
        return iter(self.configurations.values())

    def get(self, key, default=None):
        return self.configurations.get(key, default)


class ProviderModelBundle(BaseModel):
    """
    Provider model bundle.

    Bundles a provider configuration with its provider and model-type instances.
    """

    configuration: ProviderConfiguration
    provider_instance: ModelProvider
    model_type_instance: AIModel

    # pydantic configs
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+ """ + + configuration: ProviderConfiguration + provider_instance: ModelProvider + model_type_instance: AIModel + + # pydantic configs + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..44725623dc4bd4e34c0d754f06e2d6f45ee255af --- /dev/null +++ b/api/core/entities/provider_entities.py @@ -0,0 +1,110 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from core.model_runtime.entities.model_entities import ModelType +from models.provider import ProviderQuotaType + + +class QuotaUnit(Enum): + TIMES = "times" + TOKENS = "tokens" + CREDITS = "credits" + + +class SystemConfigurationStatus(Enum): + """ + Enum class for system configuration status. + """ + + ACTIVE = "active" + QUOTA_EXCEEDED = "quota-exceeded" + UNSUPPORTED = "unsupported" + + +class RestrictModel(BaseModel): + model: str + base_model_name: Optional[str] = None + model_type: ModelType + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + +class QuotaConfiguration(BaseModel): + """ + Model class for provider quota configuration. + """ + + quota_type: ProviderQuotaType + quota_unit: QuotaUnit + quota_limit: int + quota_used: int + is_valid: bool + restrict_models: list[RestrictModel] = [] + + +class SystemConfiguration(BaseModel): + """ + Model class for provider system configuration. + """ + + enabled: bool + current_quota_type: Optional[ProviderQuotaType] = None + quota_configurations: list[QuotaConfiguration] = [] + credentials: Optional[dict] = None + + +class CustomProviderConfiguration(BaseModel): + """ + Model class for provider custom configuration. + """ + + credentials: dict + + +class CustomModelConfiguration(BaseModel): + """ + Model class for provider custom model configuration. 
+ """ + + model: str + model_type: ModelType + credentials: dict + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + +class CustomConfiguration(BaseModel): + """ + Model class for provider custom configuration. + """ + + provider: Optional[CustomProviderConfiguration] = None + models: list[CustomModelConfiguration] = [] + + +class ModelLoadBalancingConfiguration(BaseModel): + """ + Class for model load balancing configuration. + """ + + id: str + name: str + credentials: dict + + +class ModelSettings(BaseModel): + """ + Model class for model settings. + """ + + model: str + model_type: ModelType + enabled: bool = True + load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/errors/__init__.py b/api/core/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/errors/error.py b/api/core/errors/error.py new file mode 100644 index 0000000000000000000000000000000000000000..ad921bc2556ffe6f6fece3b3ca7db6974976bb30 --- /dev/null +++ b/api/core/errors/error.py @@ -0,0 +1,57 @@ +from typing import Optional + + +class LLMError(ValueError): + """Base class for all LLM exceptions.""" + + description: Optional[str] = None + + def __init__(self, description: Optional[str] = None) -> None: + self.description = description + + +class LLMBadRequestError(LLMError): + """Raised when the LLM returns bad request.""" + + description = "Bad Request" + + +class ProviderTokenNotInitError(ValueError): + """ + Custom exception raised when the provider token is not initialized. + """ + + description = "Provider Token Not Init" + + def __init__(self, *args, **kwargs): + self.description = args[0] if args else self.description + + +class QuotaExceededError(ValueError): + """ + Custom exception raised when the quota for a provider has been exceeded. 
+ """ + + description = "Quota Exceeded" + + +class AppInvokeQuotaExceededError(ValueError): + """ + Custom exception raised when the quota for an app has been exceeded. + """ + + description = "App Invoke Quota Exceeded" + + +class ModelCurrentlyNotSupportError(ValueError): + """ + Custom exception raised when the model not support + """ + + description = "Model Currently Not Support" + + +class InvokeRateLimitError(ValueError): + """Raised when the Invoke returns rate limit error.""" + + description = "Rate Limit Error" diff --git a/api/core/extension/__init__.py b/api/core/extension/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4e20ec245302c328114d20c37cd27aaeee7f79 --- /dev/null +++ b/api/core/extension/api_based_extension_requestor.py @@ -0,0 +1,56 @@ +from typing import cast + +import requests + +from configs import dify_config +from models.api_based_extension import APIBasedExtensionPoint + + +class APIBasedExtensionRequestor: + timeout: tuple[int, int] = (5, 60) + """timeout for request connect and read""" + + def __init__(self, api_endpoint: str, api_key: str) -> None: + self.api_endpoint = api_endpoint + self.api_key = api_key + + def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: + """ + Request the api. 
+ + :param point: the api point + :param params: the request params + :return: the response json + """ + headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)} + + url = self.api_endpoint + + try: + # proxy support for security + proxies = None + if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: + proxies = { + "http": dify_config.SSRF_PROXY_HTTP_URL, + "https": dify_config.SSRF_PROXY_HTTPS_URL, + } + + response = requests.request( + method="POST", + url=url, + json={"point": point.value, "params": params}, + headers=headers, + timeout=self.timeout, + proxies=proxies, + ) + except requests.exceptions.Timeout: + raise ValueError("request timeout") + except requests.exceptions.ConnectionError: + raise ValueError("request connection error") + + if response.status_code != 200: + raise ValueError( + "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) + ) + + return cast(dict, response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py new file mode 100644 index 0000000000000000000000000000000000000000..231743bf2a948c351c4afd3ee9e848336feefce6 --- /dev/null +++ b/api/core/extension/extensible.py @@ -0,0 +1,120 @@ +import enum +import importlib.util +import json +import logging +import os +from pathlib import Path +from typing import Any, Optional + +from pydantic import BaseModel + +from core.helper.position_helper import sort_to_dict_by_position_map + + +class ExtensionModule(enum.Enum): + MODERATION = "moderation" + EXTERNAL_DATA_TOOL = "external_data_tool" + + +class ModuleExtension(BaseModel): + extension_class: Any = None + name: str + label: Optional[dict] = None + form_schema: Optional[list] = None + builtin: bool = True + position: Optional[int] = None + + +class Extensible: + module: ExtensionModule + + name: str + tenant_id: str + config: Optional[dict] = None + + def __init__(self, tenant_id: str, config: 
class Extensible:
    module: ExtensionModule

    name: str
    tenant_id: str
    config: Optional[dict] = None

    def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
        self.tenant_id = tenant_id
        self.config = config

    @classmethod
    def scan_extensions(cls):
        """Discover extension implementations living next to this class's module.

        Each immediate subdirectory of the module's directory is a candidate
        extension: it must contain ``<dirname>.py`` defining a subclass of
        ``cls``, and (for non-builtin extensions) a ``schema.json``.  Builtin
        extensions are marked by a ``__builtin__`` file whose content is their
        sort position.  Returns the discovered extensions sorted by position.
        """
        discovered = []
        position_map: dict[str, int] = {}

        # Directory holding the concrete subclass's module file.
        module_file = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
        base_dir = os.path.dirname(module_file)

        for entry in os.listdir(base_dir):
            if entry.startswith("__"):
                continue

            candidate_dir = os.path.join(base_dir, entry)
            extension_name = entry
            if not os.path.isdir(candidate_dir):
                continue

            contents = os.listdir(candidate_dir)

            # Builtin extensions get special treatment in the front-end page
            # and business logic; their __builtin__ marker holds the position.
            is_builtin = "__builtin__" in contents
            # Default position is 0; it must not be None for sort_to_dict_by_position_map.
            position = 0
            if is_builtin:
                marker_path = os.path.join(candidate_dir, "__builtin__")
                if os.path.exists(marker_path):
                    position = int(Path(marker_path).read_text(encoding="utf-8").strip())
                position_map[extension_name] = position

            if (extension_name + ".py") not in contents:
                logging.warning(f"Missing {extension_name}.py file in {candidate_dir}, Skip.")
                continue

            # Dynamically load <name>.py and locate the first subclass of cls.
            py_path = os.path.join(candidate_dir, extension_name + ".py")
            spec = importlib.util.spec_from_file_location(extension_name, py_path)
            if not spec or not spec.loader:
                raise Exception(f"Failed to load module {extension_name} from {py_path}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            extension_class = next(
                (
                    obj
                    for obj in vars(module).values()
                    if isinstance(obj, type) and issubclass(obj, cls) and obj is not cls
                ),
                None,
            )

            if not extension_class:
                logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
                continue

            schema: dict[str, Any] = {}
            if not is_builtin:
                if "schema.json" not in contents:
                    logging.warning(f"Missing schema.json file in {candidate_dir}, Skip.")
                    continue

                schema_path = os.path.join(candidate_dir, "schema.json")
                if os.path.exists(schema_path):
                    with open(schema_path, encoding="utf-8") as f:
                        schema = json.load(f)

            discovered.append(
                ModuleExtension(
                    extension_class=extension_class,
                    name=extension_name,
                    label=schema.get("label"),
                    form_schema=schema.get("form_schema"),
                    builtin=is_builtin,
                    position=position,
                )
            )

        return sort_to_dict_by_position_map(position_map=position_map, data=discovered, name_func=lambda x: x.name)
self.__module_extensions.get(module.value) + + if not module_extensions: + raise ValueError(f"Extension Module {module} not found") + + module_extension = module_extensions.get(extension_name) + + if not module_extension: + raise ValueError(f"Extension {extension_name} not found") + + return module_extension + + def extension_class(self, module: ExtensionModule, extension_name: str) -> type: + module_extension = self.module_extension(module, extension_name) + t: type = module_extension.extension_class + return t + + def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None: + module_extension = self.module_extension(module, extension_name) + form_schema = module_extension.form_schema + + # TODO validate form_schema diff --git a/api/core/external_data_tool/__init__.py b/api/core/external_data_tool/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/external_data_tool/api/__builtin__ b/api/core/external_data_tool/api/__builtin__ new file mode 100644 index 0000000000000000000000000000000000000000..56a6051ca2b02b04ef92d5150c9ef600403cb1de --- /dev/null +++ b/api/core/external_data_tool/api/__builtin__ @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/api/core/external_data_tool/api/__init__.py b/api/core/external_data_tool/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py new file mode 100644 index 0000000000000000000000000000000000000000..53acdf075f8cf0bf4235db5daa2d3beec6bb0564 --- /dev/null +++ b/api/core/external_data_tool/api/api.py @@ -0,0 +1,96 @@ +from typing import Optional + +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor +from core.external_data_tool.base import ExternalDataTool +from core.helper import encrypter 
class ApiExternalDataTool(ExternalDataTool):
    """External data tool backed by an API-based extension."""

    # The unique name of this external data tool.
    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        extension_id = config.get("api_based_extension_id")
        if not extension_id:
            raise ValueError("api_based_extension_id is required")

        # The referenced extension must exist within the workspace.
        record = (
            db.session.query(APIBasedExtension)
            .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == extension_id)
            .first()
        )
        if not record:
            raise ValueError("api_based_extension_id is invalid")

    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        """
        if not self.config:
            raise ValueError("config is required, config: {}".format(self.config))
        extension_id = self.config.get("api_based_extension_id")
        assert extension_id is not None, "api_based_extension_id is required"

        record = (
            db.session.query(APIBasedExtension)
            .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == extension_id)
            .first()
        )
        if not record:
            raise ValueError(
                "[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
                    self.variable
                )
            )

        # API keys are stored encrypted per tenant.
        api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=record.api_key)

        try:
            # request api
            requestor = APIBasedExtensionRequestor(api_endpoint=record.api_endpoint, api_key=api_key)
        except Exception as e:
            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))

        response_json = requestor.request(
            point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
            params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query},
        )

        if "result" not in response_json:
            raise ValueError(
                "[External data tool] API query failed, variable: {}, error: result not found in response".format(
                    self.variable
                )
            )

        if not isinstance(response_json["result"], str):
            raise ValueError(
                "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
            )

        return response_json["result"]
class ExternalDataTool(Extensible, ABC):
    """Abstract base class for external data tools."""

    module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL

    # The id of the app this tool belongs to.
    app_id: str
    # The tool variable name of the app tool.
    variable: str

    def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
        super().__init__(tenant_id, config)
        self.app_id = app_id
        self.variable = variable

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        """
        raise NotImplementedError
class ExternalDataFetch:
    def fetch(
        self,
        tenant_id: str,
        app_id: str,
        external_data_tools: list[ExternalDataVariableEntity],
        inputs: Mapping[str, Any],
        query: str,
    ) -> Mapping[str, Any]:
        """Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        fetched: dict[str, Any] = {}
        merged_inputs = dict(inputs)

        # Query every tool concurrently; each worker runs inside the app context.
        with ThreadPoolExecutor() as executor:
            pending: dict[Future[tuple[str | None, str | None]], ExternalDataVariableEntity] = {}
            for tool in external_data_tools:
                future: Future[tuple[str | None, str | None]] = executor.submit(
                    self._query_external_data_tool,
                    current_app._get_current_object(),  # type: ignore
                    tenant_id,
                    app_id,
                    tool,
                    merged_inputs,
                    query,
                )
                pending[future] = tool

            for completed in as_completed(pending):
                variable, value = completed.result()
                if variable is not None:
                    fetched[variable] = value

        merged_inputs.update(fetched)
        return merged_inputs

    def _query_external_data_tool(
        self,
        flask_app: Flask,
        tenant_id: str,
        app_id: str,
        external_data_tool: ExternalDataVariableEntity,
        inputs: Mapping[str, Any],
        query: str,
    ) -> tuple[Optional[str], Optional[str]]:
        """Query a single external data tool inside the given app context.

        :param flask_app: flask app
        :param tenant_id: tenant id
        :param app_id: app id
        :param external_data_tool: external data tool
        :param inputs: inputs
        :param query: query
        :return: (tool variable name, query result)
        """
        with flask_app.app_context():
            factory = ExternalDataToolFactory(
                name=external_data_tool.type,
                tenant_id=tenant_id,
                app_id=app_id,
                variable=external_data_tool.variable,
                config=external_data_tool.config,
            )

            # query external data tool
            result = factory.query(inputs=inputs, query=query)

            return external_data_tool.variable, result
+ + :param name: the name of external data tool + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config) + extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) + # FIXME mypy issue here, figure out how to fix it + extension_class.validate_config(tenant_id, config) # type: ignore + + def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: + """ + Query the external data tool. + + :param inputs: user inputs + :param query: the query of chat app + :return: the tool query result + """ + return cast(str, self.__extension_instance.query(inputs, query)) diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44749ebec3567e56ffb3a93865ac01f87508ac2e --- /dev/null +++ b/api/core/file/__init__.py @@ -0,0 +1,19 @@ +from .constants import FILE_MODEL_IDENTITY +from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType +from .models import ( + File, + FileUploadConfig, + ImageConfig, +) + +__all__ = [ + "FILE_MODEL_IDENTITY", + "ArrayFileAttribute", + "File", + "FileAttribute", + "FileBelongsTo", + "FileTransferMethod", + "FileType", + "FileUploadConfig", + "ImageConfig", +] diff --git a/api/core/file/constants.py b/api/core/file/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ce1d238e93742ba14af32f6a8d94f3661872ff5e --- /dev/null +++ b/api/core/file/constants.py @@ -0,0 +1 @@ +FILE_MODEL_IDENTITY = "__dify__file__" diff --git a/api/core/file/enums.py b/api/core/file/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..06b99d3eb0b4b4ac4f9567fce6940c6fa493099c --- /dev/null +++ b/api/core/file/enums.py @@ -0,0 +1,55 @@ +from enum import StrEnum + + +class FileType(StrEnum): + IMAGE = "image" + DOCUMENT = 
"document" + AUDIO = "audio" + VIDEO = "video" + CUSTOM = "custom" + + @staticmethod + def value_of(value): + for member in FileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileTransferMethod(StrEnum): + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" + + @staticmethod + def value_of(value): + for member in FileTransferMethod: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileBelongsTo(StrEnum): + USER = "user" + ASSISTANT = "assistant" + + @staticmethod + def value_of(value): + for member in FileBelongsTo: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileAttribute(StrEnum): + TYPE = "type" + SIZE = "size" + NAME = "name" + MIME_TYPE = "mime_type" + TRANSFER_METHOD = "transfer_method" + URL = "url" + EXTENSION = "extension" + + +class ArrayFileAttribute(StrEnum): + LENGTH = "length" diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4a50fb85c9cca3a10dc36d0668d21a238b2cc82d --- /dev/null +++ b/api/core/file/file_manager.py @@ -0,0 +1,133 @@ +import base64 +from collections.abc import Mapping + +from configs import dify_config +from core.helper import ssrf_proxy +from core.model_runtime.entities import ( + AudioPromptMessageContent, + DocumentPromptMessageContent, + ImagePromptMessageContent, + MultiModalPromptMessageContent, + VideoPromptMessageContent, +) +from extensions.ext_storage import storage + +from . 
import helpers +from .enums import FileAttribute +from .models import File, FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +def get_attr(*, file: File, attr: FileAttribute): + match attr: + case FileAttribute.TYPE: + return file.type.value + case FileAttribute.SIZE: + return file.size + case FileAttribute.NAME: + return file.filename + case FileAttribute.MIME_TYPE: + return file.mime_type + case FileAttribute.TRANSFER_METHOD: + return file.transfer_method.value + case FileAttribute.URL: + return file.remote_url + case FileAttribute.EXTENSION: + return file.extension + + +def to_prompt_message_content( + f: File, + /, + *, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, +) -> MultiModalPromptMessageContent: + if f.extension is None: + raise ValueError("Missing file extension") + if f.mime_type is None: + raise ValueError("Missing file mime_type") + + params = { + "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "", + "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "", + "format": f.extension.removeprefix("."), + "mime_type": f.mime_type, + } + if f.type == FileType.IMAGE: + params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { + FileType.IMAGE: ImagePromptMessageContent, + FileType.AUDIO: AudioPromptMessageContent, + FileType.VIDEO: VideoPromptMessageContent, + FileType.DOCUMENT: DocumentPromptMessageContent, + } + + try: + return prompt_class_map[f.type].model_validate(params) + except KeyError: + raise ValueError(f"file type {f.type} is not supported") + + +def download(f: File, /): + if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): + return _download_file_content(f._storage_key) + elif f.transfer_method == FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + 
response.raise_for_status() + return response.content + raise ValueError(f"unsupported transfer method: {f.transfer_method}") + + +def _download_file_content(path: str, /): + """ + Download and return the contents of a file as bytes. + + This function loads the file from storage and ensures it's in bytes format. + + Args: + path (str): The path to the file in storage. + + Returns: + bytes: The contents of the file as a bytes object. + + Raises: + ValueError: If the loaded file is not a bytes object. + """ + data = storage.load(path, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {path} is not a bytes object") + return data + + +def _get_encoded_string(f: File, /): + match f.transfer_method: + case FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + response.raise_for_status() + data = response.content + case FileTransferMethod.LOCAL_FILE: + data = _download_file_content(f._storage_key) + case FileTransferMethod.TOOL_FILE: + data = _download_file_content(f._storage_key) + + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + + +def _to_url(f: File, /): + if f.transfer_method == FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") + return f.remote_url + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + if f.related_id is None: + raise ValueError("Missing file related_id") + return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id) + elif f.transfer_method == FileTransferMethod.TOOL_FILE: + # add sign url + if f.related_id is None or f.extension is None: + raise ValueError("Missing file related_id or extension") + return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) + else: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py new file mode 100644 
index 0000000000000000000000000000000000000000..12123cf3f74630622d9203220c305f9211303e37 --- /dev/null +++ b/api/core/file/helpers.py @@ -0,0 +1,48 @@ +import base64 +import hashlib +import hmac +import os +import time + +from configs import dify_config + + +def get_signed_file_url(upload_file_id: str) -> str: + url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + key = dify_config.SECRET_KEY.encode() + msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + +def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/file/models.py b/api/core/file/models.py new file mode 100644 
index 0000000000000000000000000000000000000000..0de0089430ef32f2a7b4aed06ec13da37dbb2ef3 --- /dev/null +++ b/api/core/file/models.py @@ -0,0 +1,141 @@ +from collections.abc import Mapping, Sequence +from typing import Optional + +from pydantic import BaseModel, Field, model_validator + +from core.model_runtime.entities.message_entities import ImagePromptMessageContent + +from . import helpers +from .constants import FILE_MODEL_IDENTITY +from .enums import FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +class ImageConfig(BaseModel): + """ + NOTE: This part of validation is deprecated, but still used in app features "Image Upload". + """ + + number_limits: int = 0 + transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + detail: ImagePromptMessageContent.DETAIL | None = None + + +class FileUploadConfig(BaseModel): + """ + File Upload Entity. + """ + + image_config: Optional[ImageConfig] = None + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_file_extensions: Sequence[str] = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + number_limits: int = 0 + + +class File(BaseModel): + dify_model_identity: str = FILE_MODEL_IDENTITY + + id: Optional[str] = None # message file id + tenant_id: str + type: FileType + transfer_method: FileTransferMethod + remote_url: Optional[str] = None # remote url + related_id: Optional[str] = None + filename: Optional[str] = None + extension: Optional[str] = Field(default=None, description="File extension, should contains dot") + mime_type: Optional[str] = None + size: int = -1 + + # Those properties are private, should not be exposed to the outside. 
+ _storage_key: str + + def __init__( + self, + *, + id: Optional[str] = None, + tenant_id: str, + type: FileType, + transfer_method: FileTransferMethod, + remote_url: Optional[str] = None, + related_id: Optional[str] = None, + filename: Optional[str] = None, + extension: Optional[str] = None, + mime_type: Optional[str] = None, + size: int = -1, + storage_key: str, + ): + super().__init__( + id=id, + tenant_id=tenant_id, + type=type, + transfer_method=transfer_method, + remote_url=remote_url, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + self._storage_key = storage_key + + def to_dict(self) -> Mapping[str, str | int | None]: + data = self.model_dump(mode="json") + return { + **data, + "url": self.generate_url(), + } + + @property + def markdown(self) -> str: + url = self.generate_url() + if self.type == FileType.IMAGE: + text = f"![{self.filename or ''}]({url})" + else: + text = f"[{self.filename or url}]({url})" + + return text + + def generate_url(self) -> Optional[str]: + if self.type == FileType.IMAGE: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + else: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + 
assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + + @model_validator(mode="after") + def validate_after(self): + match self.transfer_method: + case FileTransferMethod.REMOTE_URL: + if not self.remote_url: + raise ValueError("Missing file url") + if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): + raise ValueError("Invalid file url") + case FileTransferMethod.LOCAL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + case FileTransferMethod.TOOL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + return self diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa101cf36192bc4db2a87b90251e48093b9a219 --- /dev/null +++ b/api/core/file/tool_file_parser.py @@ -0,0 +1,12 @@ +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from core.tools.tool_file_manager import ToolFileManager + +tool_file_manager: dict[str, Any] = {"manager": None} + + +class ToolFileParser: + @staticmethod + def get_tool_file_manager() -> "ToolFileManager": + return cast("ToolFileManager", tool_file_manager["manager"]) diff --git a/api/core/helper/__init__.py b/api/core/helper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/helper/code_executor/__init__.py b/api/core/helper/code_executor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec885c221879cbc941efad63b57aed08de314f66 --- /dev/null +++ b/api/core/helper/code_executor/__init__.py @@ -0,0 +1,3 @@ +from .code_executor import CodeExecutor, CodeLanguage + +__all__ = ["CodeExecutor", "CodeLanguage"] diff --git a/api/core/helper/code_executor/code_executor.py 
class CodeExecutionError(Exception):
    """Raised when the sandbox rejects, fails to run, or errors on submitted code."""

    pass


class CodeExecutionResponse(BaseModel):
    """Shape of the sandbox /v1/sandbox/run response body."""

    class Data(BaseModel):
        stdout: Optional[str] = None
        error: Optional[str] = None

    code: int
    message: str
    data: Data


class CodeLanguage(StrEnum):
    PYTHON3 = "python3"
    JINJA2 = "jinja2"
    JAVASCRIPT = "javascript"


class CodeExecutor:
    dependencies_cache: dict[str, str] = {}
    dependencies_cache_lock = Lock()

    code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = {
        CodeLanguage.PYTHON3: Python3TemplateTransformer,
        CodeLanguage.JINJA2: Jinja2TemplateTransformer,
        CodeLanguage.JAVASCRIPT: NodeJsTemplateTransformer,
    }

    # Node language -> language name understood by the sandbox runtime.
    code_language_to_running_language = {
        CodeLanguage.JAVASCRIPT: "nodejs",
        CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
        CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
    }

    supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3}

    @classmethod
    def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str:
        """
        Execute code in the sandbox service and return its stdout.

        :param language: code language
        :param preload: preload script run before the code
        :param code: code
        :return: stdout produced by the code
        :raises CodeExecutionError: on transport failure, an unparseable
            response, or a non-zero sandbox result
        """
        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"

        headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}

        data = {
            "language": cls.code_language_to_running_language.get(language),
            "code": code,
            "preload": preload,
            "enable_network": True,
        }

        try:
            response = post(
                str(url),
                json=data,
                headers=headers,
                timeout=Timeout(
                    connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
                    read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
                    write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
                    pool=None,
                ),
            )
            if response.status_code == 503:
                raise CodeExecutionError("Code execution service is unavailable")
            elif response.status_code != 200:
                raise Exception(
                    f"Failed to execute code, got status code {response.status_code},"
                    f" please check if the sandbox service is running"
                )
        except CodeExecutionError:
            raise
        except Exception as e:
            raise CodeExecutionError(
                "Failed to execute code, which is likely a network issue,"
                " please check if the sandbox service is running."
                f" ( Error: {str(e)} )"
            )

        try:
            response_data = response.json()
        except Exception:
            # Was a bare `except:`, which would also swallow KeyboardInterrupt
            # and SystemExit; narrowed to Exception.
            raise CodeExecutionError("Failed to parse response")

        if (code := response_data.get("code")) != 0:
            raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")

        response_code = CodeExecutionResponse(**response_data)

        if response_code.data.error:
            raise CodeExecutionError(response_code.data.error)

        return response_code.data.stdout or ""

    @classmethod
    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
        """
        Transform a workflow code template and run it in the sandbox.

        :param language: code language
        :param code: code
        :param inputs: inputs
        :return: the transformed response of the execution
        """
        template_transformer = cls.code_template_transformers.get(language)
        if not template_transformer:
            raise CodeExecutionError(f"Unsupported language {language}")

        runner, preload = template_transformer.transform_caller(code, inputs)

        # CodeExecutionError from execute_code propagates as-is; the previous
        # try/except that re-raised the same exception was a no-op.
        response = cls.execute_code(language, preload, runner)

        return template_transformer.transform_response(response)
Got error msg: {response_data.get('message')}") + + response_code = CodeExecutionResponse(**response_data) + + if response_code.data.error: + raise CodeExecutionError(response_code.data.error) + + return response_code.data.stdout or "" + + @classmethod + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): + """ + Execute code + :param language: code language + :param code: code + :param inputs: inputs + :return: + """ + template_transformer = cls.code_template_transformers.get(language) + if not template_transformer: + raise CodeExecutionError(f"Unsupported language {language}") + + runner, preload = template_transformer.transform_caller(code, inputs) + + try: + response = cls.execute_code(language, preload, runner) + except CodeExecutionError as e: + raise e + + return template_transformer.transform_response(response) diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..e233a596b9da0e605e70d8928024a5a469499708 --- /dev/null +++ b/api/core/helper/code_executor/code_node_provider.py @@ -0,0 +1,34 @@ +from abc import abstractmethod + +from pydantic import BaseModel + + +class CodeNodeProvider(BaseModel): + @staticmethod + @abstractmethod + def get_language() -> str: + pass + + @classmethod + def is_accept_language(cls, language: str) -> bool: + return language == cls.get_language() + + @classmethod + @abstractmethod + def get_default_code(cls) -> str: + """ + get default code in specific programming language for the code node + """ + pass + + @classmethod + def get_default_config(cls) -> dict: + return { + "type": "code", + "config": { + "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}], + "code_language": cls.get_language(), + "code": cls.get_default_code(), + "outputs": {"result": {"type": "string", "children": None}}, + }, + } 
# --- file: api/core/helper/code_executor/javascript/javascript_code_provider.py ---
from textwrap import dedent

from core.helper.code_executor.code_executor import CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider


class JavascriptCodeProvider(CodeNodeProvider):
    """Default-code provider for the JavaScript code node."""

    @staticmethod
    def get_language() -> str:
        return CodeLanguage.JAVASCRIPT

    @classmethod
    def get_default_code(cls) -> str:
        """Return the starter snippet shown to the user for a new JS code node."""
        return dedent(
            """
            function main({arg1, arg2}) {
                return {
                    result: arg1 + arg2
                }
            }
            """
        )


# --- file: api/core/helper/code_executor/javascript/javascript_transformer.py ---
from textwrap import dedent

from core.helper.code_executor.template_transformer import TemplateTransformer


class NodeJsTemplateTransformer(TemplateTransformer):
    """Wraps user JS code in a Node.js runner that decodes inputs and prints tagged JSON."""

    @classmethod
    def get_runner_script(cls) -> str:
        """Return the Node.js runner; placeholders are filled by assemble_runner_script."""
        runner_script = dedent(
            f"""
            // declare main function
            {cls._code_placeholder}

            // decode and prepare input object
            var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))

            // execute main function
            var output_obj = main(inputs_obj)

            // convert output to json and print
            var output_json = JSON.stringify(output_obj)
            var result = `<>${{output_json}}<>`
            console.log(result)
            """
        )
        return runner_script


# --- file: api/core/helper/code_executor/jinja2/jinja2_formatter.py ---
from collections.abc import Mapping

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage


class Jinja2Formatter:
    """Renders a Jinja2 template inside the sandbox."""

    @classmethod
    def format(cls, template: str, inputs: Mapping[str, str]) -> str:
        """
        Format template
        :param template: template
        :param inputs: inputs
        :return: rendered text ("" when the sandbox returned no result)
        """
        result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
        return str(result.get("result", ""))


# --- file: api/core/helper/code_executor/jinja2/jinja2_transformer.py ---
from textwrap import dedent

from core.helper.code_executor.template_transformer import TemplateTransformer


class Jinja2TemplateTransformer(TemplateTransformer):
    """Wraps a Jinja2 template in a Python runner; result is plain rendered text."""

    @classmethod
    def transform_response(cls, response: str) -> dict:
        """
        Transform response to dict
        :param response: response
        :return: dict with the rendered text under "result"
        """
        return {"result": cls.extract_result_str_from_response(response)}

    @classmethod
    def get_runner_script(cls) -> str:
        """Return the Python runner that renders the template with the decoded inputs."""
        runner_script = dedent(f"""
            # declare main function
            def main(**inputs):
                import jinja2
                template = jinja2.Template('''{cls._code_placeholder}''')
                return template.render(**inputs)

            import json
            from base64 import b64decode

            # decode and prepare input dict
            inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))

            # execute main function
            output = main(**inputs_obj)

            # convert output and print
            result = f'''<>{{output}}<>'''
            print(result)

            """)
        return runner_script

    @classmethod
    def get_preload_script(cls) -> str:
        """Return a warm-up script that imports and renders jinja2 once inside the sandbox."""
        preload_script = dedent("""
            import jinja2
            from base64 import b64decode

            def _jinja2_preload_():
                # prepare jinja2 environment, load template and render before to avoid sandbox issue
                template = jinja2.Template('{{s}}')
                template.render(s='a')

            if __name__ == '__main__':
                _jinja2_preload_()

            """)

        return preload_script


# --- file: api/core/helper/code_executor/python3/python3_code_provider.py ---
from textwrap import dedent

from core.helper.code_executor.code_executor import CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider


class Python3CodeProvider(CodeNodeProvider):
    """Default-code provider for the Python 3 code node."""

    @staticmethod
    def get_language() -> str:
        return CodeLanguage.PYTHON3

    @classmethod
    def get_default_code(cls) -> str:
        """Return the starter snippet shown to the user for a new Python code node."""
        return dedent(
            """
            def main(arg1: str, arg2: str) -> dict:
                return {
                    "result": arg1 + arg2,
                }
            """
        )
# --- file: api/core/helper/code_executor/python3/python3_transformer.py ---
from textwrap import dedent

from core.helper.code_executor.template_transformer import TemplateTransformer


class Python3TemplateTransformer(TemplateTransformer):
    """Wraps user Python code in a runner that decodes inputs and prints tagged JSON."""

    @classmethod
    def get_runner_script(cls) -> str:
        """Return the Python runner; placeholders are filled by assemble_runner_script."""
        runner_script = dedent(f"""
            # declare main function
            {cls._code_placeholder}

            import json
            from base64 import b64decode

            # decode and prepare input dict
            inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))

            # execute main function
            output_obj = main(**inputs_obj)

            # convert output to json and print
            output_json = json.dumps(output_obj, indent=4)
            result = f'''<>{{output_json}}<>'''
            print(result)
            """)
        return runner_script


# --- file: api/core/helper/code_executor/template_transformer.py ---
import json
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from collections.abc import Mapping
from typing import Any


class TemplateTransformer(ABC):
    """Turns user code plus inputs into a runnable sandbox script and parses its output."""

    _code_placeholder: str = "{{code}}"
    _inputs_placeholder: str = "{{inputs}}"
    _result_tag: str = "<>"

    @classmethod
    def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
        """
        Transform code to python runner
        :param code: code
        :param inputs: inputs
        :return: runner, preload
        """
        runner_script = cls.assemble_runner_script(code, inputs)
        preload_script = cls.get_preload_script()

        return runner_script, preload_script

    @classmethod
    def extract_result_str_from_response(cls, response: str):
        """Return the text between the two result tags, raising ValueError if absent."""
        # re.escape so the tag matches literally even if it ever contains regex
        # metacharacters (the previous code interpolated it into the pattern raw).
        result = re.search(rf"{re.escape(cls._result_tag)}(.*){re.escape(cls._result_tag)}", response, re.DOTALL)
        if not result:
            raise ValueError("Failed to parse result")
        return result.group(1)

    @classmethod
    def transform_response(cls, response: str) -> Mapping[str, Any]:
        """
        Transform response to dict
        :param response: response
        :return: parsed result dict
        :raises ValueError: when the payload is not a JSON object with string keys
        """
        try:
            result = json.loads(cls.extract_result_str_from_response(response))
        except json.JSONDecodeError as e:
            raise ValueError("failed to parse response") from e
        if not isinstance(result, dict):
            raise ValueError("result must be a dict")
        if not all(isinstance(k, str) for k in result):
            raise ValueError("result keys must be strings")
        return result

    @classmethod
    @abstractmethod
    def get_runner_script(cls) -> str:
        """
        Get runner script
        """
        pass

    @classmethod
    def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
        """Serialize inputs to base64-encoded JSON, safe to embed in any runner script."""
        inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
        input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
        return input_base64_encoded

    @classmethod
    def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
        """Substitute the code and base64-encoded inputs into the runner template."""
        script = cls.get_runner_script()
        script = script.replace(cls._code_placeholder, code)
        inputs_str = cls.serialize_inputs(inputs)
        script = script.replace(cls._inputs_placeholder, inputs_str)
        return script

    @classmethod
    def get_preload_script(cls) -> str:
        """
        Get preload script
        """
        return ""


# --- file: api/core/helper/encrypter.py ---
import base64

from libs import rsa


def obfuscated_token(token: str):
    """Return a masked form of *token* safe for display (keeps 6+2 chars of long tokens)."""
    if not token:
        return token
    if len(token) <= 8:
        return "*" * 20
    return token[:6] + "*" * 12 + token[-2:]
def encrypt_token(tenant_id: str, token: str):
    """RSA-encrypt *token* with the tenant's public key; returns base64 text."""
    from models.account import Tenant
    from models.engine import db

    tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
    if not tenant:
        raise ValueError(f"Tenant with id {tenant_id} not found")
    ciphertext = rsa.encrypt(token, tenant.encrypt_public_key)
    return base64.b64encode(ciphertext).decode()


def decrypt_token(tenant_id: str, token: str):
    """Decrypt a base64-encoded, RSA-encrypted token for the given tenant."""
    raw = base64.b64decode(token)
    return rsa.decrypt(raw, tenant_id)


def batch_decrypt_token(tenant_id: str, tokens: list[str]):
    """Decrypt many tokens while paying the key-loading cost only once."""
    rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)

    plaintexts = []
    for token in tokens:
        plaintexts.append(rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa))
    return plaintexts


def get_decrypt_decoding(tenant_id: str):
    """Load and return the tenant's (rsa_key, cipher_rsa) pair for reuse."""
    return rsa.get_decrypt_decoding(tenant_id)


def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
    """Decrypt one base64 token with an already-loaded key pair."""
    raw = base64.b64decode(token)
    return rsa.decrypt_token_with_decoding(raw, rsa_key, cipher_rsa)


# --- file: api/core/helper/lru_cache.py ---
from collections import OrderedDict
from typing import Any


class LRUCache:
    """Fixed-capacity least-recently-used cache backed by an OrderedDict."""

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Any, Any] = OrderedDict()
        self.capacity = capacity

    def get(self, key: Any) -> Any:
        """Return the cached value (marking it most-recent), or None when absent."""
        try:
            value = self.cache[key]
        except KeyError:
            return None
        # a hit refreshes recency
        self.cache.move_to_end(key)
        return value

    def put(self, key: Any, value: Any) -> None:
        """Insert or update *key*, evicting the least-recently-used entry if full."""
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        while len(self.cache) > self.capacity:
            # oldest entry sits at the front of the OrderedDict
            self.cache.popitem(last=False)
# --- file: api/core/helper/model_provider_cache.py ---
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client

# cached credentials live for one day
CREDENTIALS_CACHE_TTL_SECONDS = 86400


class ProviderCredentialsCacheType(Enum):
    PROVIDER = "provider"
    MODEL = "provider_model"
    LOAD_BALANCING_MODEL = "load_balancing_provider_model"


class ProviderCredentialsCache:
    """Redis-backed cache for model provider credentials, keyed by tenant and identity."""

    def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        """
        Get cached model provider credentials.

        :return: credentials dict, or None when absent or unparseable
        """
        cached_provider_credentials = redis_client.get(self.cache_key)
        if not cached_provider_credentials:
            return None
        try:
            return dict(json.loads(cached_provider_credentials.decode("utf-8")))
        except JSONDecodeError:
            # stale/corrupt entry — treat as a miss
            return None

    def set(self, credentials: dict) -> None:
        """
        Cache model provider credentials.

        :param credentials: provider credentials
        """
        redis_client.setex(self.cache_key, CREDENTIALS_CACHE_TTL_SECONDS, json.dumps(credentials))

    def delete(self) -> None:
        """
        Delete cached model provider credentials.
        """
        redis_client.delete(self.cache_key)


# --- file: api/core/helper/moderation.py ---
import logging
import random

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType

logger = logging.getLogger(__name__)


def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
    """
    Run a sampled chunk of *text* through hosted OpenAI moderation.

    Returns True when the sampled chunk is flagged; False when moderation is
    disabled, not applicable to this provider, or the check passes.
    :raises InvokeBadRequestError: when the moderation call itself fails
    """
    moderation_config = hosting_configuration.moderation_config
    if (
        moderation_config
        and moderation_config.enabled is True
        and "openai" in hosting_configuration.provider_map
        and hosting_configuration.provider_map["openai"].enabled is True
    ):
        using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
        provider_name = model_config.provider
        if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
            hosting_openai_config = hosting_configuration.provider_map["openai"]
            assert hosting_openai_config is not None

            # 2000 text per chunk
            length = 2000
            text_chunks = [text[i : i + length] for i in range(0, len(text), length)]

            if not text_chunks:
                # empty text short-circuits with True — preserved as-is; TODO confirm intent
                return True

            # sample a single random chunk instead of moderating everything
            text_chunk = random.choice(text_chunks)

            try:
                model_type_instance = OpenAIModerationModel()
                # FIXME, for type hint using assert or raise ValueError is better here?
                moderation_result = model_type_instance.invoke(
                    model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
                )

                if moderation_result is True:
                    return True
            except Exception as ex:
                # lazy %-args instead of an eager f-string; chain the cause
                logger.exception("Fails to check moderation, provider_name: %s", provider_name)
                raise InvokeBadRequestError("Rate limit exceeded, please try again later.") from ex

    return False


# --- file: api/core/helper/module_import_helper.py ---
import importlib.util
import logging
import sys
from types import ModuleType
from typing import AnyStr


def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
    """
    Importing a module from the source file directly
    """
    try:
        existed_spec = importlib.util.find_spec(module_name)
        if existed_spec:
            spec = existed_spec
            if not spec.loader:
                raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
        else:
            # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            # FIXME: mypy does not support the type of spec.loader
            spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore
            if not spec or not spec.loader:
                raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
            if use_lazy_loader:
                # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
                spec.loader = importlib.util.LazyLoader(spec.loader)
        module = importlib.util.module_from_spec(spec)
        if not existed_spec:
            sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module
    except Exception:
        # lazy %-args; bare `raise` preserves the original traceback
        logging.exception("Failed to load module %s from script file %r", module_name, py_file_path)
        raise
+ raise e + + +def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]: + """ + Get all the subclasses of the parent type from the module + """ + classes = [ + x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type) + ] + return classes + + +def load_single_subclass_from_source( + *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False +) -> type: + """ + Load a single subclass from the source + """ + module = import_module_from_source( + module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader + ) + subclasses = get_subclasses_from_module(module, parent_type) + match len(subclasses): + case 1: + return subclasses[0] + case 0: + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}") + case _: + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}") diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..3efdc8aa471697e42b75a1ff71978cc31c26d8be --- /dev/null +++ b/api/core/helper/position_helper.py @@ -0,0 +1,137 @@ +import os +from collections import OrderedDict +from collections.abc import Callable +from typing import Any + +from configs import dify_config +from core.tools.utils.yaml_utils import load_yaml_file + + +def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping from name to index from a YAML file + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_file_path = os.path.join(folder_path, file_name) + yaml_content = load_yaml_file(file_path=position_file_path, default_value=[]) + positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()] + return {name: index 
for index, name in enumerate(positions)} + + +def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for tools from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_TOOL_PINS_LIST, + ) + + +def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for providers from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, + ) + + +def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: + """ + Pin the items in the pin list to the beginning of the position map. 
+ Overall logic: exclude > include > pin + :param position_map: the position map to be sorted and filtered + :param pin_list: the list of pins to be put at the beginning + :return: the sorted position map + """ + positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) + + # Add pins to position map + position_map = {name: idx for idx, name in enumerate(pin_list)} + + # Add remaining positions to position map + start_idx = len(position_map) + for name in positions: + if name not in position_map: + position_map[name] = start_idx + start_idx += 1 + + return position_map + + +def is_filtered( + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], +) -> bool: + """ + Check if the object should be filtered out. + Overall logic: exclude > include > pin + :param include_set: the set of names to be included + :param exclude_set: the set of names to be excluded + :param name_func: the function to get the name of the object + :param data: the data to be filtered + :return: True if the object should be filtered out, False otherwise + """ + if not data: + return False + if not include_set and not exclude_set: + return False + + name = name_func(data) + + if name in exclude_set: # exclude_set is prioritized + return True + if include_set and name not in include_set: # filter out only if include_set is not empty + return True + return False + + +def sort_by_position_map( + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], +) -> list[Any]: + """ + Sort the objects by the position map. + If the name of the object is not in the position map, it will be put at the end. 
def sort_to_dict_by_position_map(
    position_map: dict[str, int],
    data: list[Any],
    name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
    """
    Sort the objects into a ordered dict by the position map.
    If the name of the object is not in the position map, it will be put at the end.
    :param position_map: the map holding positions in the form of {name: index}
    :param data: the data to be sorted
    :param name_func: the function to get the name of the object
    :return: an OrderedDict with the sorted pairs of name and object
    """
    sorted_items = sort_by_position_map(position_map, data, name_func)
    return OrderedDict([(name_func(item), item) for item in sorted_items])


# --- file: api/core/helper/ssrf_proxy.py ---
"""
Proxy requests to avoid SSRF
"""

import logging
import time

import httpx

from configs import dify_config

SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES

BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]


class MaxRetriesExceededError(ValueError):
    """Raised when the maximum number of retries is exceeded."""

    pass


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """
    Issue an HTTP request through the configured SSRF proxy, retrying with
    exponential backoff (BACKOFF_FACTOR * 2**attempt) on transport errors and
    on status codes in STATUS_FORCELIST.

    :param method: HTTP method name
    :param url: target URL
    :param max_retries: extra attempts after the first (0 disables retrying)
    :raises MaxRetriesExceededError: when all attempts are exhausted
    """
    # translate requests-style `allow_redirects` to httpx's `follow_redirects`
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        if "follow_redirects" not in kwargs:
            kwargs["follow_redirects"] = allow_redirects

    if "timeout" not in kwargs:
        kwargs["timeout"] = httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
        )

    retries = 0
    # NOTE(review): `stream` is popped so httpx.Client.request never sees it, and the
    # value is otherwise unused — presumably discarded intentionally; confirm.
    stream = kwargs.pop("stream", False)
    while retries <= max_retries:
        try:
            if dify_config.SSRF_PROXY_ALL_URL:
                with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
                    response = client.request(method=method, url=url, **kwargs)
            elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
                proxy_mounts = {
                    "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
                    "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
                }
                with httpx.Client(mounts=proxy_mounts) as client:
                    response = client.request(method=method, url=url, **kwargs)
            else:
                with httpx.Client() as client:
                    response = client.request(method=method, url=url, **kwargs)

            if response.status_code not in STATUS_FORCELIST:
                return response
            else:
                # lazy %-args instead of eager f-string interpolation
                logging.warning(
                    "Received status code %s for URL %s which is in the force list", response.status_code, url
                )

        except httpx.RequestError as e:
            logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
            if max_retries == 0:
                raise

        retries += 1
        if retries <= max_retries:
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")


def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """GET through the SSRF proxy."""
    return make_request("GET", url, max_retries=max_retries, **kwargs)


def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """POST through the SSRF proxy."""
    return make_request("POST", url, max_retries=max_retries, **kwargs)


def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """PUT through the SSRF proxy."""
    return make_request("PUT", url, max_retries=max_retries, **kwargs)


def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """PATCH through the SSRF proxy."""
    return make_request("PATCH", url, max_retries=max_retries, **kwargs)
def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """DELETE through the SSRF proxy."""
    return make_request("DELETE", url, max_retries=max_retries, **kwargs)


def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """HEAD through the SSRF proxy."""
    return make_request("HEAD", url, max_retries=max_retries, **kwargs)


# --- file: api/core/helper/tool_parameter_cache.py ---
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ToolParameterCacheType(Enum):
    PARAMETER = "tool_parameter"


class ToolParameterCache:
    """Redis-backed cache for secret tool parameters, keyed by tenant/provider/tool/identity."""

    def __init__(
        self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
    ):
        self.cache_key = (
            f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
            f":identity_id:{identity_id}"
        )

    def get(self) -> Optional[dict]:
        """
        Get cached tool parameters.

        :return: parameters dict, or None when absent or unparseable
        """
        # (docstring corrected: this caches tool parameters, not provider credentials)
        cached_tool_parameter = redis_client.get(self.cache_key)
        if cached_tool_parameter:
            try:
                cached_tool_parameter = cached_tool_parameter.decode("utf-8")
                cached_tool_parameter = json.loads(cached_tool_parameter)
            except JSONDecodeError:
                return None

            return dict(cached_tool_parameter)
        else:
            return None

    def set(self, parameters: dict) -> None:
        """
        Cache tool parameters for one day.

        :param parameters: tool parameters
        """
        redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

    def delete(self) -> None:
        """
        Delete cached tool parameters.
        """
        redis_client.delete(self.cache_key)


# --- file: api/core/helper/tool_provider_cache.py ---
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ToolProviderCredentialsCacheType(Enum):
    PROVIDER = "tool_provider"


class ToolProviderCredentialsCache:
    """Redis-backed cache for tool provider credentials, keyed by tenant and identity."""

    def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        """
        Get cached tool provider credentials.

        :return: credentials dict, or None when absent or unparseable
        """
        cached_provider_credentials = redis_client.get(self.cache_key)
        if cached_provider_credentials:
            try:
                cached_provider_credentials = cached_provider_credentials.decode("utf-8")
                cached_provider_credentials = json.loads(cached_provider_credentials)
            except JSONDecodeError:
                return None

            return dict(cached_provider_credentials)
        else:
            return None

    def set(self, credentials: dict) -> None:
        """
        Cache tool provider credentials for one day.

        :param credentials: provider credentials
        """
        redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

    def delete(self) -> None:
        """
        Delete cached tool provider credentials.
        """
        redis_client.delete(self.cache_key)
+ + :return: + """ + redis_client.delete(self.cache_key) diff --git a/api/core/llm_generator/__init__.py b/api/core/llm_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe3f68f2a8af59abc9d8169b33fb60fcaa5a0db --- /dev/null +++ b/api/core/llm_generator/llm_generator.py @@ -0,0 +1,342 @@ +import json +import logging +import re +from typing import Optional, cast + +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser +from core.llm_generator.prompts import ( + CONVERSATION_TITLE_PROMPT, + GENERATOR_QA_PROMPT, + JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, + PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, + WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, +) +from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.utils import measure_time +from core.prompt.utils.prompt_template_parser import PromptTemplateParser + + +class LLMGenerator: + @classmethod + def generate_conversation_name( + cls, tenant_id: str, query, conversation_id: Optional[str] = None, app_id: Optional[str] = None + ): + prompt = CONVERSATION_TITLE_PROMPT + + if len(query) > 2000: + query = query[:300] + "...[TRUNCATED]..." 
+ query[-300:] + + query = query.replace("\n", " ") + + prompt += query + "\n" + + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + prompts = [UserPromptMessage(content=prompt)] + + with measure_time() as timer: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + ), + ) + answer = cast(str, response.message.content) + cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) + if cleaned_answer is None: + return "" + result_dict = json.loads(cleaned_answer) + answer = result_dict["Your Output"] + name = answer.strip() + + if len(name) > 75: + name = name[:75] + "..." + + # get tracing instance + trace_manager = TraceQueueManager(app_id=app_id) + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.GENERATE_NAME_TRACE, + conversation_id=conversation_id, + generate_conversation_name=name, + inputs=prompt, + timer=timer, + tenant_id=tenant_id, + ) + ) + + return name + + @classmethod + def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): + output_parser = SuggestedQuestionsAfterAnswerOutputParser() + format_instructions = output_parser.get_format_instructions() + + prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") + + prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) + + try: + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + except InvokeAuthorizationError: + return [] + + prompt_messages = [UserPromptMessage(content=prompt)] + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"max_tokens": 256, "temperature": 0}, + stream=False, + ), + 
) + + questions = output_parser.parse(cast(str, response.message.content)) + except InvokeError: + questions = [] + except Exception as e: + logging.exception("Failed to generate suggested questions after answer") + questions = [] + + return questions + + @classmethod + def generate_rule_config( + cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 + ) -> dict: + output_parser = RuleConfigGeneratorOutputParser() + + error = "" + error_step = "" + rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} + model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} + + if no_variable: + prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) + + prompt_generate = prompt_template.format( + inputs={ + "TASK_DESCRIPTION": instruction, + }, + remove_template_variables=False, + ) + + prompt_messages = [UserPromptMessage(content=prompt_generate)] + + model_manager = ModelManager() + + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), + ) + + rule_config["prompt"] = cast(str, response.message.content) + + except InvokeError as e: + error = str(e) + error_step = "generate rule config" + except Exception as e: + logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}") + rule_config["error"] = str(e) + + rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" + + return rule_config + + # get rule config prompt, parameter and statement + prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions() + + prompt_template = PromptTemplateParser(prompt_generate) + + parameter_template = PromptTemplateParser(parameter_generate) + + statement_template = PromptTemplateParser(statement_generate) + + # format the prompt_generate_prompt + prompt_generate_prompt = prompt_template.format( + inputs={ + "TASK_DESCRIPTION": instruction, + }, + remove_template_variables=False, + ) + prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] + + # get model instance + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + + try: + try: + # the first step to generate the task prompt + prompt_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), + ) + except InvokeError as e: + error = str(e) + error_step = "generate prefix prompt" + rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" + + return rule_config + + rule_config["prompt"] = cast(str, prompt_content.message.content) + + if not isinstance(prompt_content.message.content, str): + raise NotImplementedError("prompt content is not a string") + parameter_generate_prompt = parameter_template.format( + inputs={ + "INPUT_TEXT": prompt_content.message.content, + }, + remove_template_variables=False, + ) + parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] + + # the second step to generate the task_parameter and task_statement + statement_generate_prompt = statement_template.format( + inputs={ + "TASK_DESCRIPTION": instruction, + "INPUT_TEXT": prompt_content.message.content, + }, + remove_template_variables=False, + ) + statement_messages = [UserPromptMessage(content=statement_generate_prompt)] + + try: + parameter_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + ), + ) + rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) + except InvokeError as e: + error = str(e) + error_step = "generate variables" + + try: + statement_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + ), + ) + rule_config["opening_statement"] = cast(str, statement_content.message.content) + except InvokeError as e: + error = str(e) + error_step = "generate conversation opener" + + except Exception as e: + logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}") + rule_config["error"] = str(e) + + rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" + + return rule_config + + @classmethod + def generate_code( + cls, + tenant_id: str, + instruction: str, + model_config: dict, + code_language: str = "javascript", + max_tokens: int = 1000, + ) -> dict: + if code_language == "python": + prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) + else: + prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) + + prompt = prompt_template.format( + inputs={ + "INSTRUCTION": instruction, + "CODE_LANGUAGE": code_language, + }, + remove_template_variables=False, + ) + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + + prompt_messages = [UserPromptMessage(content=prompt)] + model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), + ) + + generated_code = cast(str, response.message.content) + return {"code": generated_code, "language": code_language, "error": ""} + + except InvokeError as e: + error = str(e) + return {"code": "", "language": code_language, "error": f"Failed to generate code. 
Error: {error}"} + except Exception as e: + logging.exception( + f"Failed to invoke LLM model, model: {model_config.get('name')}, language: {code_language}" + ) + return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + + @classmethod + def generate_qa_document(cls, tenant_id: str, query, document_language: str): + prompt = GENERATOR_QA_PROMPT.format(language=document_language) + + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + + prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] + + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"temperature": 0.01, "max_tokens": 2000}, + stream=False, + ), + ) + + answer = cast(str, response.message.content) + return answer.strip() diff --git a/api/core/llm_generator/output_parser/__init__.py b/api/core/llm_generator/output_parser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..0922806ca88ce67268c663dbe97b3984f9498e64 --- /dev/null +++ b/api/core/llm_generator/output_parser/errors.py @@ -0,0 +1,2 @@ +class OutputParserError(ValueError): + pass diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7683b16d373ea6ae2837779badb30b3079c8e5 --- /dev/null +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -0,0 +1,32 @@ +from typing import Any + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import ( + 
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, +) +from libs.json_in_md_parser import parse_and_check_json_markdown + + +class RuleConfigGeneratorOutputParser: + def get_format_instructions(self) -> tuple[str, str, str]: + return ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) + + def parse(self, text: str) -> Any: + try: + expected_keys = ["prompt", "variables", "opening_statement"] + parsed = parse_and_check_json_markdown(text, expected_keys) + if not isinstance(parsed["prompt"], str): + raise ValueError("Expected 'prompt' to be a string.") + if not isinstance(parsed["variables"], list): + raise ValueError("Expected 'variables' to be a list.") + if not isinstance(parsed["opening_statement"], str): + raise ValueError("Expected 'opening_statement' to be a str.") + return parsed + except Exception as e: + raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py new file mode 100644 index 0000000000000000000000000000000000000000..c451bf514cbf28f6304bf3b96e65b64c5b16e124 --- /dev/null +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -0,0 +1,19 @@ +import json +import re +from typing import Any + +from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + + +class SuggestedQuestionsAfterAnswerOutputParser: + def get_format_instructions(self) -> str: + return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + + def parse(self, text: str) -> Any: + action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) + if action_match is not None: + json_obj = json.loads(action_match.group(0).strip()) + else: + json_obj = [] + + return json_obj 
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..f9411e9ec78c7ab71245f8608c285e118447d241 --- /dev/null +++ b/api/core/llm_generator/prompts.py @@ -0,0 +1,222 @@ +# Written by YORKI MINAKO🤡, Edited by Xiaoyi +CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is. +Notice: the language type user use could be diverse, which can be English, Chinese, Español, Arabic, Japanese, French, and etc. +MAKE SURE your output is the SAME language as the user's input! +Your output is restricted only to: (Input language) Intention + Subject(short as possible) +Your output MUST be a valid JSON. + +Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun. + + +example 1: +User Input: hi, yesterday i had some burgers. +{ + "Language Type": "The user's input is pure English", + "Your Reasoning": "The language of my output must be pure English.", + "Your Output": "sharing yesterday's food" +} + +example 2: +User Input: hello +{ + "Language Type": "The user's input is written in pure English", + "Your Reasoning": "The language of my output must be pure English.", + "Your Output": "Greeting myself☺️" +} + + +example 3: +User Input: why mmap file: oom +{ + "Language Type": "The user's input is written in pure English", + "Your Reasoning": "The language of my output must be pure English.", + "Your Output": "Asking about the reason for mmap file: oom" +} + + +example 4: +User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么? 
+{ + "Language Type": "The user's input English-Chinese mixed", + "Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.", + "Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv" +} + +example 5: +User Input: why小红的年龄is老than小明? +{ + "Language Type": "The user's input is English-Chinese mixed", + "Your Reasoning": "The English parts are subjective particles, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.", + "Your Output": "询问小红和小明的年龄" +} + +example 6: +User Input: yo, 你今天咋样? +{ + "Language Type": "The user's input is English-Chinese mixed", + "Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.", + "Your Output": "查询今日我的状态☺️" +} + +User Input: +""" # noqa: E501 + +PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = ( + "You are an expert programmer. Generate code based on the following instructions:\n\n" + "Instructions: {{INSTRUCTION}}\n\n" + "Write the code in {{CODE_LANGUAGE}}.\n\n" + "Please ensure that you meet the following requirements:\n" + "1. Define a function named 'main'.\n" + "2. The 'main' function must return a dictionary (dict).\n" + "3. You may modify the arguments of the 'main' function, but include appropriate type hints.\n" + "4. The returned dictionary should contain at least one key-value pair.\n\n" + "5. 
You may ONLY use the following libraries in your code: \n" + "- json\n" + "- datetime\n" + "- math\n" + "- random\n" + "- re\n" + "- string\n" + "- sys\n" + "- time\n" + "- traceback\n" + "- uuid\n" + "- os\n" + "- base64\n" + "- hashlib\n" + "- hmac\n" + "- binascii\n" + "- collections\n" + "- functools\n" + "- operator\n" + "- itertools\n\n" + "Example:\n" + "def main(arg1: str, arg2: int) -> dict:\n" + " return {\n" + ' "result": arg1 * arg2,\n' + " }\n\n" + "IMPORTANT:\n" + "- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n" + "- DO NOT use markdown code blocks (``` or ``` python). Return the raw code directly.\n" + "- The code should start immediately after this instruction, without any preceding newlines or spaces.\n" + "- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n" + "- Always use the format return {'result': ...} for the output.\n\n" + "Generated Code:\n" +) +JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = ( + "You are an expert programmer. Generate code based on the following instructions:\n\n" + "Instructions: {{INSTRUCTION}}\n\n" + "Write the code in {{CODE_LANGUAGE}}.\n\n" + "Please ensure that you meet the following requirements:\n" + "1. Define a function named 'main'.\n" + "2. The 'main' function must return an object.\n" + "3. You may modify the arguments of the 'main' function, but include appropriate JSDoc annotations.\n" + "4. The returned object should contain at least one key-value pair.\n\n" + "5. The returned object should always be in the format: {result: ...}\n\n" + "Example:\n" + "function main(arg1, arg2) {\n" + " return {\n" + " result: arg1 * arg2\n" + " };\n" + "}\n\n" + "IMPORTANT:\n" + "- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n" + "- DO NOT use markdown code blocks (``` or ``` javascript). 
Return the raw code directly.\n" + "- The code should start immediately after this instruction, without any preceding newlines or spaces.\n" + "- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n" + "Generated Code:\n" +) + + +SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( + "Please help me predict the three most likely questions that human would ask, " + "and keeping each question under 20 characters.\n" + "MAKE SURE your output is the SAME language as the Assistant's latest response. " + "The output must be an array in JSON format following the specified schema:\n" + '["question1","question2","question3"]\n' +) + +GENERATOR_QA_PROMPT = ( + " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" + " in the long text. Please think step by step." + "Step 1: Understand and summarize the main content of this text.\n" + "Step 2: What key information or concepts are mentioned in this text?\n" + "Step 3: Decompose or combine multiple pieces of information and concepts.\n" + "Step 4: Generate questions and answers based on these key information and concepts.\n" + " The questions should be clear and detailed, and the answers should be detailed and complete. " + "You must answer in {language}, in a style that is clear and detailed in {language}." + " No language other than {language} should be used. \n" + " Use the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "" +) + +WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ +Here is a task description for which I would like you to create a high-quality prompt template for: + +{{TASK_DESCRIPTION}} + +Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include: +- Do not include or section and variables in the prompt, assume user will add them at their own will. 
+- Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag. +- Relevant examples if needed to clarify the task further, demarcated with tags. Do not include variables in the prompt. Give three pairs of input and output examples. +- Include other relevant sections demarcated with appropriate XML tags like , . +- Use the same language as task description. +- Output in ``` xml ``` and start with +Please generate the full prompt template with at least 300 words and output only the prompt template. +""" # noqa: E501 + +RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ +Here is a task description for which I would like you to create a high-quality prompt template for: + +{{TASK_DESCRIPTION}} + +Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include: +- Descriptive variable names surrounded by {{ }} (two curly brackets) to indicate where the actual values will be substituted in. Choose variable names that clearly indicate the type of value expected. Variable names have to be composed of number, english alphabets and underline and nothing else. +- Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag. +- Relevant examples if needed to clarify the task further, demarcated with tags. Do not use curly brackets any other than in section. +- Any other relevant sections demarcated with appropriate XML tags like , , etc. +- Use the same language as task description. 
+- Output in ``` xml ``` and start with +Please generate the full prompt template and output only the prompt template. +""" # noqa: E501 + +RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """ +I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. + +variables name bounded two double curly brackets. Variable name has to be composed of number, english alphabets and underline and nothing else. + + +Step 1: Carefully read the input and understand the structure of the expected output. +Step 2: Extract relevant parameters from the provided text based on the name and description of object. +Step 3: Structure the extracted parameters to JSON object as specified in . +Step 4: Ensure that the list of variable_names is properly formatted and valid. The output should not contain any XML tags. Output an empty list if there is no valid variable name in input text. + +### Structure +Here is the structure of the expected output, I should always follow the output structure. +["variable_name_1", "variable_name_2"] + +### Input Text +Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. + +{{INPUT_TEXT}} + + +### Answer +I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text. +""" # noqa: E501 + +RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """ + +Step 1: Identify the purpose of the chatbot from the variable {{TASK_DESCRIPTION}} and infer chatbot's tone (e.g., friendly, professional, etc.) to add personality traits. +Step 2: Create a coherent and engaging opening statement. +Step 3: Ensure the output is welcoming and clearly explains what the chatbot is designed to do. Do not include any XML tags in the output. +Please use the same language as the user's input language. 
If user uses chinese then generate opening statement in chinese, if user uses english then generate opening statement in english. +Example Input: +Provide customer support for an e-commerce website +Example Output: +Welcome! I'm here to assist you with any questions or issues you might have with your shopping experience. Whether you're looking for product information, need help with your order, or have any other inquiries, feel free to ask. I'm friendly, helpful, and ready to support you in any way I can. + +Here is the task description: {{INPUT_TEXT}} + +You just need to generate the output +""" # noqa: E501 diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..003a0c85b1f12e4cb69cb4a41fbdb7b08fdb7551 --- /dev/null +++ b/api/core/memory/token_buffer_memory.py @@ -0,0 +1,171 @@ +from collections.abc import Sequence +from typing import Optional + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.file import file_manager +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, + UserPromptMessage, +) +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory +from models.model import AppMode, Conversation, Message, MessageFile +from models.workflow import WorkflowRun + + +class TokenBufferMemory: + def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None: + self.conversation = conversation + self.model_instance = model_instance + + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> Sequence[PromptMessage]: + """ + Get history prompt messages. 
+ :param max_token_limit: max token limit + :param message_limit: message limit + """ + app_record = self.conversation.app + + # fetch limited messages, and return reversed + query = ( + db.session.query( + Message.id, + Message.query, + Message.answer, + Message.created_at, + Message.workflow_run_id, + Message.parent_message_id, + ) + .filter( + Message.conversation_id == self.conversation.id, + ) + .order_by(Message.created_at.desc()) + ) + + if message_limit and message_limit > 0: + message_limit = min(message_limit, 500) + else: + message_limit = 500 + + messages = query.limit(message_limit).all() + + # instead of all messages from the conversation, we only need to extract messages + # that belong to the thread of last message + thread_messages = extract_thread_messages(messages) + + # for newly created message, its answer is temporarily empty, we don't need to add it to memory + if thread_messages and not thread_messages[0].answer: + thread_messages.pop(0) + + messages = list(reversed(thread_messages)) + + prompt_messages: list[PromptMessage] = [] + for message in messages: + files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() + if files: + file_extra_config = None + if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + else: + if message.workflow_run_id: + workflow_run = ( + db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() + ) + + if workflow_run and workflow_run.workflow: + file_extra_config = FileUploadConfigManager.convert( + workflow_run.workflow.features_dict, is_vision=False + ) + + detail = ImagePromptMessageContent.DETAIL.LOW + if file_extra_config and app_record: + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config + ) + if file_extra_config.image_config and 
file_extra_config.image_config.detail: + detail = file_extra_config.image_config.detail + else: + file_objs = [] + + if not file_objs: + prompt_messages.append(UserPromptMessage(content=message.query)) + else: + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + for file in file_objs: + prompt_message = file_manager.to_prompt_message_content( + file, + image_detail_config=detail, + ) + prompt_message_contents.append(prompt_message) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + + else: + prompt_messages.append(UserPromptMessage(content=message.query)) + + prompt_messages.append(AssistantPromptMessage(content=message.answer)) + + if not prompt_messages: + return [] + + # prune the chat message if it exceeds the max token limit + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) + + if curr_message_tokens > max_token_limit: + pruned_memory = [] + while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: + pruned_memory.append(prompt_messages.pop(0)) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) + + return prompt_messages + + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ) -> str: + """ + Get history prompt text. 
+ :param human_prefix: human prefix + :param ai_prefix: ai prefix + :param max_token_limit: max token limit + :param message_limit: message limit + :return: + """ + prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit) + + string_messages = [] + for m in prompt_messages: + if m.role == PromptMessageRole.USER: + role = human_prefix + elif m.role == PromptMessageRole.ASSISTANT: + role = ai_prefix + else: + continue + + if isinstance(m.content, list): + inner_msg = "" + for content in m.content: + if isinstance(content, TextPromptMessageContent): + inner_msg += f"{content.data}\n" + elif isinstance(content, ImagePromptMessageContent): + inner_msg += "[image]\n" + + string_messages.append(f"{role}: {inner_msg.strip()}") + else: + message = f"{role}: {m.content}" + string_messages.append(message) + + return "\n".join(string_messages)