|
from typing import Annotated, Literal |
|
from typing_extensions import TypedDict |
|
|
|
from langgraph.graph import StateGraph, MessagesState, START, END |
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage |
|
from langchain_core.output_parsers import JsonOutputParser |
|
from langchain_community.document_transformers import BeautifulSoupTransformer, beautiful_soup_transformer |
|
|
|
from langgraph.types import Command |
|
|
|
from langchain_groq import ChatGroq |
|
|
|
import operator |
|
import pprint |
|
import os |
|
import requests |
|
import html2text |
|
|
|
# Groq credential read once at import time; None when the variable is unset.
API_KEY = os.environ.get("GROQ_API_KEY")

# Sentinel reply: agent_builder compares the stripped model output against it.
OUT_RES = "<|FINISHED|>"

# Shared HTML -> text converter used by run_api; links and images are dropped.
HTML_TRANSFORMER = html2text.HTML2Text()
HTML_TRANSFORMER.ignore_links = True
HTML_TRANSFORMER.ignore_images = True

# Shared tag extractor used by run_api when specific HTML tags are requested.
BS_TRANSFORMER = BeautifulSoupTransformer()
|
|
|
|
|
def local_message_add(dict1, dict2):
    """Reducer for ``local_messages``: merge per-agent message lists.

    Every key of ``dict2`` is merged into ``dict1``. A key that is new to
    ``dict1`` is stored as-is; an existing key has the new value concatenated
    onto the current one with ``+``.

    Args:
        dict1: Accumulated mapping; it is updated in place and returned.
        dict2: Update mapping. May be empty, in which case nothing changes.

    Returns:
        ``dict1`` after merging.
    """
    # All keys are merged, not just the first one: a single update may carry
    # entries for several agents at once.
    for key, value in dict2.items():
        if key in dict1:
            dict1[key] = dict1[key] + value
        else:
            dict1[key] = value
    return dict1
|
|
|
def variable_state_update(dict1, dict2):
    """Reducer for ``variables``: overlay ``dict2`` onto ``dict1`` in place.

    Keys present in both take the value from ``dict2``. The same ``dict1``
    object is returned.
    """
    dict1 |= dict2
    return dict1
|
|
|
class GeneralStates(TypedDict):
    """State schema handed to ``StateGraph``.

    Fields declared with ``Annotated`` carry, as their second argument, the
    two-argument function used to combine the current value with an update.
    """

    # Conversation history; an update is concatenated onto the current list.
    messages: Annotated[list[dict[str, str]], operator.add]
    # Looked up by agent id in agent_builder to jump to a stored node.
    # No reducer is attached to this field.
    checkpoints: dict[str, list]
    # Per-agent message lists keyed by agent id, merged by local_message_add.
    local_messages: Annotated[dict, local_message_add]
    # Named values shared between agents, merged by variable_state_update.
    variables: Annotated[dict, variable_state_update]
|
|
|
|
|
def format_sequence(seq, nested=False):
    """Render a container as text, dispatching on its type.

    Dicts go to ``format_dict``; lists, tuples, sets and frozensets go to
    ``format_list_like``. Any other value is returned unchanged.
    """
    if isinstance(seq, dict):
        return format_dict(seq, nested=nested)
    if isinstance(seq, (list, tuple, set, frozenset)):
        return format_list_like(seq, nested=nested)
    return seq
|
|
|
|
|
|
|
def format_dict(d, nested=False):
    """Render a dict as ``key: value`` entries joined by ``",\\n"``.

    Top-level entries (``nested=False``) are numbered from 1; nested entries
    are not. Container values are rendered through ``format_sequence`` with
    ``nested=True``.
    """
    containers = (list, tuple, set, frozenset, dict)
    lines = []
    for position, (key, value) in enumerate(d.items(), start=1):
        rendered = format_sequence(value, nested=True) if isinstance(value, containers) else value
        prefix = "" if nested else f"{position}. "
        lines.append(f"{prefix}{key}: {rendered}")
    return ",\n".join(lines)
|
|
|
def format_list_like(seq, nested=False):
    """Render an iterable as entries joined by ``",\\n"``.

    Top-level entries (``nested=False``) are numbered from 1; nested entries
    are plain ``str(item)``. Container items are rendered through
    ``format_sequence`` with ``nested=True`` first.
    """
    containers = (list, tuple, set, frozenset, dict)
    lines = []
    for position, item in enumerate(seq, start=1):
        if isinstance(item, containers):
            item = format_sequence(item, nested=True)
        lines.append(str(item) if nested else f"{position}. {item}")
    return ",\n".join(lines)
|
|
|
|
|
def format_dict_api(input_dict, combined):
    """Return a copy of ``input_dict`` with string templates filled in.

    String values are passed through ``str.format(**combined)``; dict values
    are processed recursively; every other value (including lists) is copied
    over untouched.
    """

    def render(value):
        if isinstance(value, dict):
            return format_dict_api(value, combined)
        if isinstance(value, str):
            return value.format(**combined)
        return value

    return {key: render(value) for key, value in input_dict.items()}
|
|
|
|
|
def run_api(api_endpoints, variables, response, input_message, chain_id, timeout=None):
    """Call every configured HTTP endpoint and collect the results.

    Args:
        api_endpoints: List of endpoint config dicts. Each needs ``name``,
            ``method``, ``url``, ``input_variables`` and ``response_type``;
            ``headers``, ``params``, ``request_body``, ``html_to_markdown``
            and ``html_tags_to_extract`` are optional. Falsy means "nothing
            to call".
        variables: State variables. Used as the template context and updated
            in place with the results.
        response: Agent output. Truthy marks this as an "output" call: a dict
            is merged into the template context, anything else is exposed as
            ``output_message``. Falsy marks an "input" call.
        input_message: Exposed to templates as ``input_message``.
        chain_id: Agent id, embedded in the result keys.
        timeout: Forwarded to ``requests.request``. ``None`` (the default)
            waits indefinitely.

    Returns:
        Dict keyed ``{input|output}_{name}_{chain_id}_success`` (payload) or
        ``..._error`` (the raised exception object), successes first.
    """
    if not api_endpoints:
        return {}

    combined = variables.copy()
    if response:
        api_endpoint_type = "output"
        if isinstance(response, dict):
            combined = combined | response
        else:
            combined["output_message"] = response
    else:
        api_endpoint_type = "input"
    combined["input_message"] = input_message

    successes = {}
    failures = {}
    for endpoint in api_endpoints:
        try:
            # Only the variables the endpoint declares are visible to its templates.
            input_var = {name: combined[name] for name in endpoint["input_variables"]}
            headers = endpoint.get("headers")
            params = endpoint.get("params")
            body = endpoint.get("request_body")

            http_response = requests.request(
                endpoint["method"],
                endpoint["url"],
                headers=format_dict_api(headers, input_var) if headers else None,
                params=format_dict_api(params, input_var) if params else None,
                json=format_dict_api(body, input_var) if body else None,
                timeout=timeout,
            )

            if endpoint["response_type"] == "json":
                result = http_response.json()
            else:
                result = http_response.text

            # HTML post-processing only makes sense for text payloads; the
            # doctype is matched case-insensitively, ignoring leading whitespace.
            if isinstance(result, str) and result.lstrip()[:15].lower() == "<!doctype html>":
                if endpoint.get("html_to_markdown"):
                    result = HTML_TRANSFORMER.handle(result)
                elif endpoint.get("html_tags_to_extract"):
                    result = BS_TRANSFORMER.extract_tags(result, tags=endpoint["html_tags_to_extract"])
        except Exception as exc:
            # Best effort: one failing endpoint must not stop the others.
            # NOTE(review): the raw exception object is stored in state -
            # confirm the checkpointer in use can serialise it.
            failures[f"{api_endpoint_type}_{endpoint['name']}_{chain_id}_error"] = exc
        else:
            successes[f"{api_endpoint_type}_{endpoint['name']}_{chain_id}_success"] = result

    api_dict = {**successes, **failures}
    variables.update(api_dict)
    return api_dict
|
|
|
|
|
def agent_builder(states: GeneralStates, chain: dict, row: int, depth: int):
    """Run one agent node and return its state update or a routing ``Command``.

    Steps, in order:
      1. If ``states["checkpoints"]`` has an entry for this agent, jump there.
      2. Call the agent's input API endpoints.
      3. If the latest message matches a child's "input" condition, hand over
         to that child.
      4. If the agent is a template, format the prompt and return it as the
         reply without calling a model.
      5. Otherwise call the Groq model, once or once per index of the
         ``loop_input_variables`` lists, and route on "output" conditions.

    Args:
        states: Current graph state.
        chain: Agent config. Reads ``id``, ``agent``, ``child`` and
            ``loop_input_variables``.
        row: Unused here; kept for the caller's signature.
        depth: Unused here; kept for the caller's signature.

    Returns:
        A ``Command`` when redirecting to another node, otherwise a dict of
        state updates (``variables`` and optionally ``messages`` /
        ``local_messages``).
    """
    model_config = chain.get("agent")
    print("[MODEL CONFIG]", model_config)
    # Leaf agents may carry no "child" entry; normalise so iteration is safe.
    child = chain.get("child") or []
    checkpoints = states.get("checkpoints", {})

    print("[STATES]", states)

    if chain["id"] in checkpoints:
        return Command(goto=checkpoints[chain["id"]])

    api_dict = {"variables": {}}

    variables = states.get("variables", {})
    last_content = states["messages"][-1].content
    variables["input_message"] = last_content

    api_res = run_api(model_config["input_api_endpoints"], variables, None, last_content, chain["id"])
    api_dict["variables"].update(api_res)

    input_var = model_config.get("input_variables")
    output_variables = model_config.get("output_variables")

    def redirect_update(target, handoff):
        """Build the state update that accompanies a jump to ``target``."""
        update = {}
        # Local histories are only touched when this agent already has one.
        if states["local_messages"].get(chain["id"]):
            update["local_messages"] = {
                chain["id"]: [AIMessage(f"Switch to Agent {target['id']}")],
                target["id"]: [handoff],
            }
        if target.get("checkpoint"):
            update["checkpoints"] = {chain["id"]: target["id"]}
        return update

    def store_outputs(content):
        """Copy ``content`` into the output variables.

        Returns True when the reply should also be published as a message,
        i.e. no output variables are configured or "messages" is among them.
        """
        if not output_variables:
            return True
        api_dict["variables"].update(
            {name: content for name in output_variables if name != "messages"}
        )
        return "messages" in output_variables

    for c in child:
        if c["condition_from"] == "input" and last_content.strip() == c["condition"]:
            update_dict = redirect_update(c, states["messages"][-1])
            return Command(goto=c["id"], update=update_dict | api_dict)

    messages = states["local_messages"].get(chain["id"])
    if messages:
        messages.append(states["messages"][-1])
    else:
        messages = states["messages"]

    if model_config.get("is_template"):
        # Template agents reply with the formatted prompt; no model call.
        response = model_config.get("prompt")
        if input_var:
            response = response.format(**{var: variables[var] for var in input_var})
        response = AIMessage(response)

        api_res = run_api(model_config["output_api_endpoints"], variables, response.content, messages[-1].content, chain["id"])
        api_dict["variables"].update(api_res)

        if not store_outputs(response.content):
            return api_dict
        return {"messages": [response], "local_messages": {chain["id"]: [response]}} | api_dict

    def run_agent(i, loop_input_variables, variables):
        """Invoke the model once; ``i == -1`` means "not looping"."""
        if input_var:
            print("[AGENT ID]", chain['id'])
            print("[INPUT VARIABLES]", input_var)
            print("[VARIABLES]", variables)
            user_input = "\n".join([str(variables[var]) for var in input_var])
            if i == -1:
                prompt = model_config.get("prompt").format(**{var: variables[var] for var in input_var})
            else:
                # Loop variables contribute their i-th element, others their whole value.
                prompt = model_config.get("prompt").format(
                    **{var: variables[var][i] if var in loop_input_variables else variables[var] for var in input_var}
                )
        else:
            user_input = messages[-1].content
            prompt = model_config.get("prompt") + "\n\n" + messages[-1].content

        model = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=model_config.get("creativity"),
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=API_KEY
        )

        routes = model_config.get("routes")
        output_collector = model_config.get("output_collector")

        if routes:
            add_prompt = f"YOU MUST GENERATE OUTPUT STRICTLY one of the following list : [{', '.join(routes)}]\n\n"
            if model_config.get("routes_description"):
                add_prompt += "HERE IS THE CONDITIONS FOR EACH OUTPUT:\n"
                add_prompt += "\n".join([f"{x}: {y}" for x, y in zip(routes, model_config.get("routes_description"))])
                add_prompt += "\n\n"
            prompt = add_prompt + prompt
        elif output_collector:
            add_prompt = "YOU MUST GENERATE OUTPUT STRICTLY IN THE FOLLOWING JSON FORMAT, REMEMBER TO ADD {} BEFORE AND AFTER JSON CODE:\n"
            add_prompt += "\n".join(output_collector)
            add_prompt += "\n\n"
            prompt = prompt + "\n\n" + add_prompt

            response = (model | JsonOutputParser()).invoke(messages[:-1] + [HumanMessage(content=prompt)])

            if output_variables:
                # Build a filtered copy: deleting keys while iterating the
                # dict raises RuntimeError.
                response = {k: v for k, v in response.items() if k in output_variables}

            api_res = run_api(model_config["output_api_endpoints"], variables, response, messages[-1].content, chain["id"])
            api_dict["variables"].update(api_res)

            return {"variables": response | api_dict["variables"]}

        response = model.invoke(messages[:-1] + [HumanMessage(content=prompt)])

        for c in child:
            if c["condition_from"] == "output" and response.content.strip() == c["condition"]:
                update_dict = redirect_update(c, HumanMessage(user_input))

                api_res = run_api(model_config["output_api_endpoints"], variables, response.content, messages[-1].content, chain["id"])
                api_dict["variables"].update(api_res)

                if not store_outputs(response.content):
                    # NOTE(review): this returns plain variables and drops the
                    # goto - confirm that is intended.
                    return api_dict
                if output_variables:
                    # The messages reducer concatenates lists, so the reply
                    # must be wrapped in a list.
                    api_dict["messages"] = [response]

                return Command(goto=c["id"], update=update_dict | api_dict)
            elif response.content.strip() == OUT_RES:
                # The model signalled completion: run the output endpoints
                # without a response and emit no message.
                api_res = run_api(model_config["output_api_endpoints"], variables, None, messages[-1].content, chain["id"])
                api_dict["variables"].update(api_res)
                return dict(api_dict)

        api_res = run_api(model_config["output_api_endpoints"], variables, response.content, messages[-1].content, chain["id"])
        api_dict["variables"].update(api_res)

        if not store_outputs(response.content):
            return api_dict
        return {"messages": [response], "local_messages": {chain["id"]: [response]}} | api_dict

    loop_input_variables = chain.get("loop_input_variables")
    if not loop_input_variables:
        return run_agent(-1, [], variables)

    # Iterate as far as the shortest loop variable allows.
    max_loop = min(len(states["variables"].get(x)) for x in loop_input_variables)

    updates = {"variables": {}}
    for i in range(max_loop):
        out_variables = run_agent(i, loop_input_variables, variables)

        if isinstance(out_variables, dict):
            if not out_variables.get("variables"):
                continue
            # Collect every iteration's values into one list per variable.
            for k, v in out_variables["variables"].items():
                if k not in updates["variables"]:
                    updates["variables"][k] = []
                if isinstance(v, list):
                    updates["variables"][k] += v
                else:
                    updates["variables"][k].append(v)
        else:
            # A Command replaces whatever has been collected so far.
            # NOTE(review): the loop keeps running after this - confirm.
            updates = out_variables
    return updates
|
|
|
|
|
def route(states, routes):
    """Pick the next node from the latest message.

    Returns the stripped content of the last message when it is one of
    ``routes``; otherwise returns ``END``.
    """
    target = states["messages"][-1].content.strip()
    return target if target in routes else END
|
|
|
def build_chain(chains, checkpointer, parent_name=None, depth=0):
    """Build and compile a ``StateGraph`` from a nested agent configuration.

    The tree is walked iteratively with an explicit stack. Each agent becomes
    a node that runs ``agent_builder``. Children with a ``condition`` are
    reached through conditional edges (via ``route``); children without one
    get a direct edge; agents without children are wired to ``END``. Every
    top-level agent is connected from ``START``.

    Args:
        chains: List of top-level agent config dicts; children live under
            each agent's ``child`` key.
        checkpointer: Passed to ``StateGraph.compile``.
        parent_name: Seed parent for the walk; not used for wiring.
        depth: Depth assigned to the top-level agents.

    Returns:
        The compiled graph.
    """
    print("[BUILD CHAIN] START....")

    # Stack entries: (sibling list, parent id, depth, index into siblings).
    stack = [(chains, parent_name, depth, 0)]

    builder = StateGraph(GeneralStates)

    while stack:
        current_chains, current_parent, current_depth, i = stack.pop()
        print("STACK", i)

        if i >= len(current_chains):
            continue

        c = current_chains[i]
        c_id = c["id"]

        try:
            print("ADDED NODE!")
            # Defaults freeze the per-iteration values inside the lambda.
            builder.add_node(
                c_id,
                lambda states, c=c, i=i, depth=current_depth: agent_builder(states, c, i, depth)
            )
        except ValueError as e:
            # add_node raised (e.g. for an id that was already added);
            # report it and keep wiring.
            print("[ERROR]", e)

        if i + 1 < len(current_chains):
            stack.append((current_chains, current_parent, current_depth, i + 1))

        if c.get("child"):
            stack.append((c["child"], c_id, current_depth + 1, 0))

            condition_ids = []
            for x in c["child"]:
                if x["condition"]:
                    condition_ids.append(x["id"])
                else:
                    builder.add_edge(c_id, x["id"])

            if condition_ids:
                # Bind this node's ids as a default: a plain closure would see
                # whichever list the loop assigned last.
                builder.add_conditional_edges(
                    c_id,
                    lambda states, ids=tuple(condition_ids): route(states, ids),
                    path_map=condition_ids + [END]
                )
        else:
            # NOTE(review): leaf agents are assumed to terminate the graph;
            # confirm this pairing of the else branch.
            builder.add_edge(c_id, END)

    print("SET STARTING POINTS")

    for start_point in chains:
        builder.add_edge(START, start_point["id"])
    print("[NODES]", builder.nodes)
    print("[EDGES]", builder.edges)
    graph = builder.compile(checkpointer=checkpointer)
    return graph
|
|