import copy
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import hydra
from pydantic import root_validator

from langchain import LLMChain, PromptTemplate
from langchain.agents import AgentExecutor, BaseMultiActionAgent, ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    AgentAction,
    AgentFinish,
    OutputParserException,
)

from flows.base_flows import Flow, CompositeFlow, GenericLCTool
from flows.messages import OutputMessage, UpdateMessage_Generic
from flows.utils.caching_utils import flow_run_cache

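# ZeroShotAgent variant whose tools are given as a mapping from tool name to
# Flow, with each tool's description read from its `flow_config`.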
class GenericZeroShotAgent(ZeroShotAgent):
    @classmethod
    def create_prompt(
        cls,
        tools: Dict[str, Flow],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Create a prompt in the style of the zero-shot agent.

        Args:
            tools: Mapping from tool name to the Flow the agent will have
                access to, used to format the prompt.
            prefix: String to put before the list of tools.
            suffix: String to put after the list of tools.
            format_instructions: Instructions describing the expected output
                format; must contain a `{tool_names}` placeholder.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        tool_strings = "\n".join(
            f"{tool_name}: {tool.flow_config['description']}"
            for tool_name, tool in tools.items()
        )
        tool_names = ", ".join(tools.keys())
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)


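# AgentExecutor variant that accepts tools as a Dict[str, Flow]. The validators
# mirror the base class' checks, adapted to read `return_direct` from each
# flow's `flow_config` instead of from a LangChain Tool object.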
class GenericAgentExecutor(AgentExecutor):
    tools: Dict[str, Flow]

    @root_validator()
    def validate_tools(cls, values: Dict) -> Dict:
        """Validate that the agent's allowed tools match the provided tools."""
        agent = values["agent"]
        tools = values["tools"]
        allowed_tools = agent.get_allowed_tools()
        if allowed_tools is not None:
            if set(allowed_tools) != set(tools.keys()):
                raise ValueError(
                    f"Allowed tools ({allowed_tools}) different than "
                    f"provided tools ({list(tools.keys())})"
                )
        return values

    @root_validator()
    def validate_return_direct_tool(cls, values: Dict) -> Dict:
        """Validate that no `return_direct` tool is used with a multi-action agent."""
        agent = values["agent"]
        tools = values["tools"]
        if isinstance(agent, BaseMultiActionAgent):
            for tool in tools.values():
                if tool.flow_config["return_direct"]:
                    raise ValueError(
                        "Tools that have `return_direct=True` are not allowed "
                        "in multi-action agents"
                    )
        return values

    def _get_tool_return(
        self, next_step_output: Tuple[AgentAction, str]
    ) -> Optional[AgentFinish]:
        """Check if the tool is a returning tool."""
        agent_action, observation = next_step_output
        if agent_action.tool in self.tools:
            if self.tools[agent_action.tool].flow_config["return_direct"]:
                return AgentFinish(
                    {self.agent.return_values[0]: observation},
                    "",
                )
        return None


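# CompositeFlow that runs a ReAct-style thought-action-observation loop,
# dispatching each agent action to the subflow registered under the tool name.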
class ReActFlow(CompositeFlow):
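    # Config for the fallback tool invoked when the LLM output cannot be
    # parsed into a valid action.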
    EXCEPTION_FLOW_CONFIG = {
        "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
        "config": {
            "name": "_Exception",
            "description": "Exception tool",

            "tool_type": "exception",
            "input_keys": ["query"],
            "output_keys": ["raw_response"],

            "verbose": False,
            "clear_flow_namespace_on_run_end": False,

            "input_data_transformations": [],
            "output_data_transformations": [],
            "keep_raw_response": True
        }
    }

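    # Config for the fallback tool invoked when the agent names a tool that
    # does not exist among the subflows.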
    INVALID_FLOW_CONFIG = {
        "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
        "config": {
            "name": "invalid_tool",
            "description": "Called when tool name is invalid.",

            "tool_type": "invalid",
            "input_keys": ["tool_name"],
            "output_keys": ["raw_response"],

            "verbose": False,
            "clear_flow_namespace_on_run_end": False,

            "input_data_transformations": [],
            "output_data_transformations": [],
            "keep_raw_response": True
        }
    }

    SUPPORTS_CACHING: bool = True

    api_keys: Dict[str, str]

    backend: GenericAgentExecutor
    react_prompt_template: PromptTemplate

    exception_flow: GenericLCTool
    invalid_flow: GenericLCTool

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.api_keys = None
        self.backend = None
        self.react_prompt_template = GenericZeroShotAgent.create_prompt(
            tools=self.subflows,
            **self.flow_config.get("prompt_config", {})
        )

        self._set_up_necessary_subflows()

    def set_up_flow_state(self):
        super().set_up_flow_state()
        self.flow_state["intermediate_steps"]: List[Tuple[AgentAction, str]] = []

    def _set_up_necessary_subflows(self):
        self.exception_flow = hydra.utils.instantiate(
            self.EXCEPTION_FLOW_CONFIG, _convert_="partial", _recursive_=False
        )
        self.invalid_flow = hydra.utils.instantiate(
            self.INVALID_FLOW_CONFIG, _convert_="partial", _recursive_=False
        )

    def _get_prompt_message(self, input_data: Dict[str, Any]) -> str:
        data = copy.deepcopy(input_data)
        data["agent_scratchpad"] = "{agent_scratchpad}"
        return self.react_prompt_template.format(**data)

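    # Subflow outputs arrive as OutputMessage objects; the raw tool response
    # sits under "raw_response", keyed by the flow's first output key.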
    @staticmethod
    def get_raw_response(output: OutputMessage) -> str:
        key = output.data["output_keys"][0]
        return output.data["output_data"]["raw_response"][key]

    def _take_next_step(
        self,
        inputs: Dict[str, str],
        intermediate_steps: List[Tuple[AgentAction, str]],
        private_keys: Optional[List[str]] = [],
        keys_to_ignore_for_hash: Optional[List[str]] = []
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """Take a single step in the thought-action-observation loop.

        Override this to take control of how the agent makes and acts on choices.
        """
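        # Ask the agent to plan the next action given the intermediate
        # (action, observation) steps gathered so far.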
        try:
            output = self.backend.agent.plan(
                intermediate_steps,
                **inputs,
            )
        except OutputParserException as e:
            if isinstance(self.backend.handle_parsing_errors, bool):
                raise_error = not self.backend.handle_parsing_errors
            else:
                raise_error = False
            if raise_error:
                raise e
            text = str(e)

            if isinstance(self.backend.handle_parsing_errors, bool):
                if e.send_to_llm:
                    observation = str(e.observation)
                    text = str(e.llm_output)
                else:
                    observation = "Invalid or incomplete response"
            elif isinstance(self.backend.handle_parsing_errors, str):
                observation = self.backend.handle_parsing_errors
            elif callable(self.backend.handle_parsing_errors):
                observation = self.backend.handle_parsing_errors(e)
            else:
                raise ValueError("Got unexpected type of `handle_parsing_errors`")

            output = AgentAction("_Exception", observation, text)
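            # Route the parse failure through the exception subflow so its
            # output goes back to the agent as a regular observation.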
            self._state_update_dict({"query": output.tool_input})
            tool_output = self._call_flow_from_state(
                self.exception_flow,
                private_keys=private_keys,
                keys_to_ignore_for_hash=keys_to_ignore_for_hash,
                search_class_namespace_for_inputs=False
            )
            observation = self.get_raw_response(tool_output)
            return [(output, observation)]

        if isinstance(output, AgentFinish):
            return output

        actions: List[AgentAction]
        if isinstance(output, AgentAction):
            actions = [output]
        else:
            actions = output
        result = []
        for agent_action in actions:
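            # Known tool: push the tool input into the flow state and delegate
            # execution to the corresponding subflow.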
            if agent_action.tool in self.subflows:
                tool = self.subflows[agent_action.tool]

                if isinstance(agent_action.tool_input, dict):
                    self._state_update_dict(agent_action.tool_input)
                else:
                    self._state_update_dict(
                        {tool.flow_config["input_keys"][0]: agent_action.tool_input}
                    )

                tool_output = self._call_flow_from_state(
                    tool,
                    private_keys=private_keys,
                    keys_to_ignore_for_hash=keys_to_ignore_for_hash,
                    search_class_namespace_for_inputs=False
                )
                observation = self.get_raw_response(tool_output)
            else:
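                # Unknown tool name: call the invalid-tool subflow, which
                # produces an observation telling the agent the name is invalid.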
                self._state_update_dict({"tool_name": agent_action.tool})
                tool_output = self._call_flow_from_state(
                    self.invalid_flow,
                    private_keys=private_keys,
                    keys_to_ignore_for_hash=keys_to_ignore_for_hash,
                    search_class_namespace_for_inputs=False
                )
                observation = self.get_raw_response(tool_output)
            result.append((agent_action, observation))
        return result

    def _run(
        self,
        input_data: Dict[str, Any],
        private_keys: Optional[List[str]] = [],
        keys_to_ignore_for_hash: Optional[List[str]] = []
    ) -> str:
        """Run the input through the agent loop and return the final response."""
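        # Start a fresh reasoning trace; intermediate steps live in the flow
        # state so they persist across calls to _take_next_step.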
        self.flow_state["intermediate_steps"] = []
        intermediate_steps = self.flow_state["intermediate_steps"]

        iterations = 0
        time_elapsed = 0.0
        start_time = time.time()

        while self.backend._should_continue(iterations, time_elapsed):
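            # Plan and execute one thought-action-observation step.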
            next_step_output = self._take_next_step(
                input_data,
                intermediate_steps,
                private_keys,
                keys_to_ignore_for_hash
            )
            if isinstance(next_step_output, AgentFinish):
                return next_step_output.return_values["output"]

            intermediate_steps.extend(next_step_output)

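            # A single-step result may come from a `return_direct` tool, in
            # which case its observation is the final answer.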
            if len(next_step_output) == 1:
                next_step_action = next_step_output[0]
                tool_return = self.backend._get_tool_return(next_step_action)
                if tool_return is not None:
                    return tool_return.return_values["output"]

            iterations += 1
            time_elapsed = time.time() - start_time

        output = self.backend.agent.return_stopped_response(
            self.backend.early_stopping_method, intermediate_steps, **input_data
        )
        return output.return_values["output"]

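    # Entry point. The backend (LLM, agent, and executor) is rebuilt on each
    # run from the API keys supplied with the input data.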
    @flow_run_cache()
    def run(
        self,
        input_data: Dict[str, Any],
        private_keys: Optional[List[str]] = [],
        keys_to_ignore_for_hash: Optional[List[str]] = []
    ) -> Dict[str, Any]:
        self.api_keys = input_data["api_keys"]
        del input_data["api_keys"]

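        # Build the LLM, the LLMChain over the ReAct prompt, and an agent
        # restricted to the registered subflow names.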
        llm = ChatOpenAI(
            model_name=self.flow_config["model_name"],
            openai_api_key=self.api_keys["openai"],
            **self.flow_config["generation_parameters"],
        )
        llm_chain = LLMChain(llm=llm, prompt=self.react_prompt_template)
        agent = GenericZeroShotAgent(
            llm_chain=llm_chain, allowed_tools=list(self.subflows.keys())
        )

        self.backend = GenericAgentExecutor.from_agent_and_tools(
            agent=agent,
            tools=self.subflows,
            max_iterations=self.flow_config.get("max_iterations", 15),
            max_execution_time=self.flow_config.get("max_execution_time")
        )

        data = {k: input_data[k] for k in self.get_input_keys(input_data)}

        output = self._run(data, private_keys, keys_to_ignore_for_hash)

        return {input_data["output_keys"][0]: output}