# agent-chain / chains.py
# Hugging Face file-viewer header (page text, not part of the program):
#   jonathanjordan21's picture | Upload folder using huggingface_hub
#   18f5c57 verified | raw | history blame | 17 kB
from typing import Annotated, Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_community.document_transformers import BeautifulSoupTransformer, beautiful_soup_transformer
from langgraph.types import Command
from langchain_groq import ChatGroq
import operator
import pprint
import os
import requests
import html2text
# Groq API key read from the environment; os.getenv yields None when
# GROQ_API_KEY is unset. Passed to ChatGroq as api_key in agent_builder.
API_KEY = os.getenv("GROQ_API_KEY")
# Sentinel model output: agent_builder compares the stripped LLM response
# against this marker.
OUT_RES = "<|FINISHED|>"
# Shared html2text converter with its ignore_links / ignore_images options
# switched on; run_api uses it for endpoints flagged html_to_markdown.
HTML_TRANSFORMER = html2text.HTML2Text()
HTML_TRANSFORMER.ignore_links = True
HTML_TRANSFORMER.ignore_images = True
# Shared BeautifulSoup transformer; run_api calls its extract_tags for
# endpoints that list html_tags_to_extract.
BS_TRANSFORMER = BeautifulSoupTransformer()
def local_message_add(dict1, dict2):
    """Merge a per-agent message update into the accumulated local messages.

    This is the reducer attached to ``GeneralStates.local_messages``.

    Args:
        dict1: Accumulated mapping of agent id -> list of messages. It is
            updated in place and returned.
        dict2: Update mapping of agent id -> list of new messages. It may be
            empty or contain several agent ids.

    Returns:
        ``dict1`` with every list from ``dict2`` appended to the list stored
        under the same key (or stored as-is when the key is new).
    """
    # Every key must be merged: agent_builder emits updates that carry both
    # the current chain id and the target child id in a single dict.
    for agent_id, new_messages in dict2.items():
        if agent_id in dict1:
            # Concatenate into a new list so the previous list object is
            # not mutated.
            dict1[agent_id] = dict1[agent_id] + new_messages
        else:
            dict1[agent_id] = new_messages
    return dict1
def variable_state_update(dict1, dict2):
    """Overlay ``dict2`` onto ``dict1`` in place and return ``dict1``.

    This is the reducer attached to ``GeneralStates.variables``; keys present
    in both mappings take the value from ``dict2``.
    """
    dict1 |= dict2
    return dict1
class GeneralStates(TypedDict):
    """State schema handed to ``StateGraph`` in ``build_chain``."""

    # Conversation messages; an update list is concatenated onto the
    # existing list.
    messages: Annotated[list[dict[str, str]], operator.add]
    # No reducer attached. agent_builder reads it as a mapping keyed by
    # chain id and jumps to the stored value.
    checkpoints: dict[str, list]
    # Per-agent message lists, merged by local_message_add.
    local_messages: Annotated[dict, local_message_add]
    # Named values shared between agents, merged by variable_state_update.
    variables: Annotated[dict, variable_state_update]
def format_sequence(seq, nested=False):
    """Render a container as text, dispatching on its type.

    Dicts go to ``format_dict``; lists, tuples, sets and frozensets go to
    ``format_list_like``. Anything else (strings included) is returned
    unchanged rather than raising.
    """
    if isinstance(seq, dict):
        return format_dict(seq, nested=nested)
    if isinstance(seq, (list, tuple, set, frozenset)):
        return format_list_like(seq, nested=nested)
    # Not a supported container: hand the value back untouched.
    return seq
def format_dict(d, nested=False):
    """Render a dict as ``key: value`` lines joined by ``",\\n"``.

    At the top level (``nested=False``) each line is prefixed with its
    1-based position. Container values are rendered recursively through
    ``format_sequence`` in nested mode, so they carry no numbering.
    """
    container_types = (list, tuple, set, frozenset, dict)
    lines = []
    for position, (key, value) in enumerate(d.items(), start=1):
        if isinstance(value, container_types):
            value = format_sequence(value, nested=True)
        line = f"{key}: {value}" if nested else f"{position}. {key}: {value}"
        lines.append(line)
    return ",\n".join(lines)
def format_list_like(seq, nested=False):
    """Render a list-like object as lines joined by ``",\\n"``.

    At the top level (``nested=False``) each line is prefixed with its
    1-based position. Container items are rendered recursively through
    ``format_sequence`` in nested mode, so they carry no numbering.
    """
    container_types = (list, tuple, set, frozenset, dict)
    lines = []
    for position, item in enumerate(seq, start=1):
        if isinstance(item, container_types):
            item = format_sequence(item, nested=True)
        line = str(item) if nested else f"{position}. {item}"
        lines.append(line)
    return ",\n".join(lines)
def format_dict_api(input_dict, combined):
    """Return a copy of ``input_dict`` with its string values formatted.

    Each string value is treated as a ``str.format`` template and filled
    from ``combined``. Nested dicts are processed recursively; every other
    value (numbers, lists, None, ...) is copied through unchanged. A
    placeholder missing from ``combined`` raises ``KeyError``.
    """
    def render(value):
        if isinstance(value, dict):
            return format_dict_api(value, combined)
        if isinstance(value, str):
            return value.format(**combined)
        return value

    return {key: render(value) for key, value in input_dict.items()}
def _looks_like_html(text):
    """Return True when ``text`` opens with an HTML doctype declaration."""
    # The doctype keyword is matched case-insensitively and leading
    # whitespace is skipped, so "<!doctype html>" pages are detected too.
    return text[:256].lstrip().lower().startswith("<!doctype html")


def run_api(api_endpoints, variables, response, input_message, chain_id, timeout=None):
    """Call every configured API endpoint and collect results as variables.

    Args:
        api_endpoints: List of endpoint config dicts. Each needs ``name``,
            ``method``, ``url`` and ``input_variables``; ``headers``,
            ``params``, ``request_body``, ``response_type``,
            ``html_to_markdown`` and ``html_tags_to_extract`` are optional.
            A falsy value means "nothing to call".
        variables: Current variable mapping. It is updated in place with the
            collected results.
        response: Agent output. When truthy the endpoints are treated as
            "output" endpoints: a dict is merged into the template values,
            anything else is exposed as ``output_message``. When falsy they
            are treated as "input" endpoints and ``input_message`` is
            exposed instead.
        input_message: Text exposed as ``input_message`` for input endpoints.
        chain_id: Id of the calling chain, embedded in the result keys.
        timeout: Optional timeout forwarded to ``requests.request``. The
            default ``None`` keeps the previous behaviour of waiting
            indefinitely.

    Returns:
        Dict of ``"{type}_{name}_{chain_id}_success"`` -> response body and
        ``"{type}_{name}_{chain_id}_error"`` -> the exception raised for
        that endpoint. Empty when ``api_endpoints`` is falsy.
    """
    if not api_endpoints:
        return {}

    combined = variables.copy()
    if response:
        api_endpoint_type = "output"
        if isinstance(response, dict):
            combined = combined | response
        else:
            combined["output_message"] = response
    else:
        api_endpoint_type = "input"
        combined["input_message"] = input_message

    resp = []
    errors = []
    for endpoint in api_endpoints:
        name = endpoint["name"]
        try:
            # Only the declared input variables are visible to the templates.
            input_var = {inp: combined[inp] for inp in endpoint["input_variables"]}
            headers = endpoint.get("headers")
            params = endpoint.get("params")
            request_body = endpoint.get("request_body")
            res = requests.request(
                endpoint["method"],
                endpoint["url"],
                headers=format_dict_api(headers, input_var) if headers else None,
                params=format_dict_api(params, input_var) if params else None,
                json=format_dict_api(request_body, input_var) if request_body else None,
                timeout=timeout,
            )
            if endpoint.get("response_type") == "json":
                res = res.json()
            else:
                res = res.text
            # HTML post-processing only makes sense for text bodies; a
            # parsed JSON dict must not be sliced like a string.
            if isinstance(res, str) and _looks_like_html(res):
                if endpoint.get("html_to_markdown"):
                    res = HTML_TRANSFORMER.handle(res)
                elif endpoint.get("html_tags_to_extract"):
                    res = BS_TRANSFORMER.extract_tags(res, tags=endpoint["html_tags_to_extract"])
            resp.append([res, name])
        except Exception as e:
            # Deliberately best-effort: a failing endpoint is reported as an
            # "_error" variable instead of aborting the whole agent run.
            errors.append([e, name])

    api_dict = {}
    for body, name in resp:
        api_dict[f"{api_endpoint_type}_{name}_{chain_id}_success"] = body
    for error, name in errors:
        api_dict[f"{api_endpoint_type}_{name}_{chain_id}_error"] = error
    variables.update(api_dict)
    return api_dict
def agent_builder(states: GeneralStates, chain: dict, row:int, depth: int):
    """Execute one chain node of the graph.

    The node runs these steps in order:

    1. If ``states["checkpoints"]`` has an entry for this chain id, jump
       straight to the stored target.
    2. Call the agent's input API endpoints.
    3. If the last message equals the condition of an ``"input"`` child,
       jump to that child.
    4. If the agent is a template (``is_template``), format the prompt and
       return it as the response without calling a model.
    5. Otherwise call the Groq chat model via ``run_agent``, either once or
       once per index of the chain's ``loop_input_variables``.

    Args:
        states: Current graph state.
        chain: Chain config. Reads ``id``, ``agent``, ``child`` and
            ``loop_input_variables``.
        row: Not used inside this function.
        depth: Not used inside this function.

    Returns:
        A state-update dict, or a ``Command`` that redirects the graph to
        another node.
    """
    model_config = chain.get("agent")
    print("[MODEL CONFIG]", model_config)
    # A chain without a "child" entry has no children; fall back to an empty
    # list so the routing loops below do not iterate over None.
    child = chain.get("child") or []
    checkpoints = states.get("checkpoints", {})
    print("[STATES]", states)

    # A checkpoint recorded for this chain short-circuits to the saved target.
    for k, v in checkpoints.items():
        if k == chain["id"]:
            return Command(goto=v)

    api_dict = {"variables":{}}
    variables = states.get("variables", {})
    variables["input_message"] = states["messages"][-1].content

    api_res = run_api(model_config["input_api_endpoints"], variables, None, states["messages"][-1].content, chain["id"])
    api_dict["variables"].update(api_res)

    # Input-side routing: the raw user message selects a child directly.
    for c in child:
        if c["condition_from"] == "input" and states['messages'][-1].content.strip() == c["condition"]:
            redirect_agent_message = AIMessage(f"Switch to Agent {c['id']}")
            local_message = states["local_messages"].get(chain["id"])
            if local_message:
                update_dict = {
                    "local_messages": {
                        chain["id"]:[redirect_agent_message],
                        c["id"]:[states['messages'][-1]]
                    },
                }
            else:
                update_dict = {}
            if c.get("checkpoint"):
                update_dict["checkpoints"] = {chain["id"]:c["id"]}
            return Command(goto=c["id"], update=update_dict | api_dict)

    # Prefer this chain's own history when it has one; otherwise use the
    # global message list.
    messages = states["local_messages"].get(chain["id"])
    if messages:
        messages.append(states["messages"][-1])
    else:
        messages = states["messages"]

    input_var = model_config.get("input_variables")
    output_variables = model_config.get("output_variables")

    # Template agents answer with the (formatted) prompt text itself.
    if model_config.get("is_template"):
        response = model_config.get("prompt")
        if input_var:
            response = response.format(**{var: variables[var] for var in input_var})
        response = AIMessage(response)
        api_res = run_api(model_config["output_api_endpoints"], variables, response.content, messages[-1].content, chain["id"])
        api_dict["variables"].update(api_res)
        if output_variables:
            out = {out_var: response.content for out_var in output_variables}
            if "messages" not in output_variables:
                # Output goes to variables only; no message is emitted.
                api_dict["variables"].update(out)
                return api_dict
            else:
                out.pop("messages")
                api_dict["variables"].update(out)
        return {"messages":[response], "local_messages":{chain["id"]:[response]}} | api_dict

    def run_agent(i, loop_input_variables, variables):
        """Invoke the chat model once; ``i == -1`` means "not looping"."""
        if input_var:
            print("[AGENT ID]", chain['id'])
            print("[INPUT VARIABLES]", input_var)
            print("[VARIABLES]", variables)
            user_input = "\n".join([str(variables[var]) for var in input_var])
            if i == -1:
                prompt = model_config.get("prompt").format(**{var: variables[var] for var in input_var})
            else:
                # Looped variables contribute their i-th element; the rest
                # are passed whole.
                prompt = model_config.get("prompt").format(
                    **{var: variables[var][i] if var in loop_input_variables else variables[var] for var in input_var}
                )
        else:
            user_input = messages[-1].content
            prompt = model_config.get("prompt") + "\n\n" + messages[-1].content

        model = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=model_config.get("creativity"),
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=API_KEY
        )
        routes = model_config.get("routes")
        output_collector = model_config.get("output_collector")

        if routes:
            add_prompt = f"YOU MUST GENERATE OUTPUT STRICTLY one of the following list : [{', '.join(routes)}]\n\n"
            if model_config.get("routes_description"):
                add_prompt += "HERE IS THE CONDITIONS FOR EACH OUTPUT:\n"
                add_prompt += "\n".join([f"{x}: {y}" for x,y in zip(routes, model_config.get("routes_description"))])
                add_prompt += "\n\n"
            prompt = add_prompt + prompt
        elif output_collector:
            add_prompt = f"YOU MUST GENERATE OUTPUT STRICTLY IN THE FOLLOWING JSON FORMAT, REMEMBER TO ADD {{}} BEFORE AND AFTER JSON CODE:\n"
            add_prompt += "\n".join(output_collector)
            add_prompt += "\n\n"
            prompt = prompt +"\n\n"+ add_prompt
            response = (model | JsonOutputParser()).invoke(messages[:-1] + [HumanMessage(content=prompt)])
            if output_variables:
                # Build a filtered copy: deleting keys while iterating the
                # same dict raises "dictionary changed size during iteration".
                response = {k: v for k, v in response.items() if k in output_variables}
            api_res = run_api(model_config["output_api_endpoints"], variables, response, messages[-1].content, chain["id"])
            api_dict["variables"].update(api_res)
            return {"variables":response | api_dict["variables"]}

        response = model.invoke(messages[:-1] + [HumanMessage(content=prompt)])

        # Output-side routing: the model's answer selects a child.
        for c in child:
            if c["condition_from"] == "output" and response.content.strip() == c["condition"]:
                redirect_agent_message = AIMessage(f"Switch to Agent {c['id']}")
                local_message = states["local_messages"].get(chain["id"])
                if local_message:
                    update_dict = {
                        "local_messages": {
                            chain["id"]:[redirect_agent_message],
                            c["id"]:[HumanMessage(user_input)]
                        },
                    }
                else:
                    update_dict = {}
                if c.get("checkpoint"):
                    update_dict["checkpoints"] = {chain["id"]:c["id"]}
                api_res = run_api(model_config["output_api_endpoints"], variables, response.content, messages[-1].content, chain["id"])
                api_dict["variables"].update(api_res)
                if output_variables:
                    out = {out_var: response.content for out_var in output_variables}
                    if "messages" not in output_variables:
                        api_dict["variables"].update(out)
                        # NOTE(review): this returns a plain update and so
                        # skips the Command(goto=...) below even though the
                        # child condition matched -- confirm this is intended.
                        return api_dict
                    else:
                        # NOTE(review): this stores the response text (a
                        # str) under "messages", whose reducer concatenates
                        # lists -- confirm the expected shape.
                        api_dict["messages"] = out.pop("messages")
                        api_dict["variables"].update(out)
                return Command(goto=c["id"], update=update_dict | api_dict)
            elif response.content.strip() == OUT_RES:
                # The sentinel is only checked while iterating children, so
                # a chain without children never takes this branch.
                api_res = run_api(model_config["output_api_endpoints"], variables, None, messages[-1].content, chain["id"])
                api_dict["variables"].update(api_res)
                return {} | api_dict

        api_res = run_api(model_config["output_api_endpoints"], variables, response.content, messages[-1].content, chain["id"])
        api_dict["variables"].update(api_res)
        if output_variables:
            out = {out_var: response.content for out_var in output_variables}
            if "messages" not in output_variables:
                api_dict["variables"].update(out)
                return api_dict
            else:
                out.pop("messages")
                api_dict["variables"].update(out)
        return {"messages":[response], "local_messages":{chain["id"]:[response]}} | api_dict

    loop_input_variables = chain.get("loop_input_variables")
    if not loop_input_variables:
        return run_agent(-1, [], variables)

    # Loop mode: run once per index, bounded by the shortest looped variable.
    max_loop = min([len(states["variables"].get(x)) for x in loop_input_variables])
    updates = {"variables":{}}
    for i in range(max_loop):
        out_variables = run_agent(i, loop_input_variables, variables)
        if isinstance(out_variables, dict):
            if not out_variables.get("variables"):
                continue
            # Collect each iteration's variables into per-key lists.
            for k, v in out_variables["variables"].items():
                if k not in updates["variables"]:
                    updates["variables"][k] = []
                if isinstance(v, list):
                    updates["variables"][k] += v
                else:
                    updates["variables"][k].append(v)
        else:
            # NOTE(review): a Command from one iteration replaces the
            # accumulated updates while the loop keeps running -- confirm
            # whether it should return immediately instead.
            updates = out_variables
    return updates
def route(states, routes):
    """Pick the next node from the last message.

    Returns the stripped content of the last message when it is one of
    ``routes``; otherwise returns ``END``.
    """
    last_output = states["messages"][-1].content.strip()
    return last_output if last_output in routes else END
def build_chain(chains, checkpointer, parent_name=None, depth=0):
    """Build and compile the LangGraph graph for a tree of chain configs.

    The tree is walked iteratively with an explicit stack. Every chain
    becomes a node that runs ``agent_builder``. A chain with children gets a
    plain edge to each child whose ``condition`` is falsy, plus one
    conditional edge (decided by ``route``) covering the children that do
    have a condition. A chain without children gets an edge to ``END``.
    Every top-level chain is connected from ``START``.

    Args:
        chains: List of top-level chain config dicts. Reads ``id`` and the
            optional ``child`` list; each child needs ``id`` and
            ``condition``.
        checkpointer: Passed through to ``StateGraph.compile``.
        parent_name: Seed value for the parent slot of the traversal stack.
        depth: Depth assigned to the top-level chains.

    Returns:
        The compiled graph.
    """
    print("[BUILD CHAIN] START....")
    stack = [(chains, parent_name, depth, 0)]
    builder = StateGraph(GeneralStates)

    while stack:
        current_chains, current_parent, current_depth, i = stack.pop()
        print("STACK", i)
        if i >= len(current_chains):
            continue
        c = current_chains[i]
        c_id = c["id"]

        try:
            print("ADDED NODE!")
            # Bind loop values as defaults so each node keeps its own chain.
            builder.add_node(
                c_id,
                lambda states, c=c, i=i, depth=current_depth: agent_builder(states, c, i, depth)
            )
        except ValueError as e:
            # Best-effort: the failure is reported and the build continues.
            print("[ERROR]",e)

        # Push the next sibling first so the children (pushed below) are
        # popped and processed before it.
        if i + 1 < len(current_chains):
            stack.append((current_chains, current_parent, current_depth, i + 1))

        if c.get("child"):
            stack.append((
                c["child"],
                c_id,
                current_depth + 1,
                0
            ))
            condition_ids = []
            for x in c["child"]:
                if x["condition"]:
                    condition_ids.append(x["id"])
                else:
                    builder.add_edge(c_id, x["id"])
            if condition_ids:
                # condition_ids is rebound on every loop iteration; binding
                # it as a default freezes THIS parent's list. A bare closure
                # would make every router see the last parent's children.
                builder.add_conditional_edges(
                    c_id,
                    lambda states, condition_ids=condition_ids: route(states, condition_ids),
                    path_map=condition_ids + [END]
                )
        else:
            builder.add_edge(
                c_id,
                END
            )

    print("SET STARTING POINTS")
    for start_point in chains:
        builder.add_edge(START, start_point["id"])
    print("[NODES]", builder.nodes)
    print("[EDGES]", builder.edges)
    graph = builder.compile(checkpointer=checkpointer)
    return graph