# agent-chain / app.py
# jonathanjordan21's picture
# Upload folder using huggingface_hub
# 18f5c57 verified
from datetime import datetime
from typing import List

from fastapi import FastAPI, HTTPException
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import message_to_dict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END

from .chains import GeneralStates, build_chain, agent_builder
from .models import Agent, OutputCollector, APIEndpoint
# from chains import *
# FastAPI application; every route decorator in this module registers on it.
app = FastAPI()

# Checkpointer instance passed to build_chain when graphs are compiled.
checkpointer = MemorySaver()

# In-memory registries. Each maps project_id -> {item id -> stored value};
# nothing is persisted, so contents are lost when the process exits.
projects_chains = {}                 # compiled graphs returned by build_chain
projects_agents = {}                 # agent records
projects_output_collectors = {}      # output collector records
projects_input_api_endpoints = {}    # input API endpoint records
projects_output_api_endpoints = {}   # output API endpoint records
@app.post("/agents")
def create_agent(data: Agent):
    """Register an agent in the in-memory store of its project.

    The request model is converted to a dict and stamped with ``created_at``
    and ``updated_at`` (the same UTC instant) and ``deleted_at = None``.
    References held by the agent are then resolved against the per-project
    registries before the record is stored:

    * ``output_collector`` (a collector id) is replaced by a list of
      ``"key: data_type = description"`` strings built from the collector.
    * ``input_api_endpoints`` / ``output_api_endpoints`` (lists of ids) are
      replaced by the stored endpoint records; an id that is not registered
      resolves to ``None``.

    Returns:
        A confirmation message together with the stored agent record.

    Raises:
        HTTPException: 404 when the referenced output collector is not
            registered for the agent's project.
    """
    current_time = datetime.utcnow()
    data_dict = data.dict()
    data_dict["created_at"] = current_time
    data_dict["updated_at"] = current_time
    data_dict["deleted_at"] = None

    project_id = data_dict["project_id"]

    if data_dict["output_collector"]:
        # A project without any collectors has no registry entry yet, so fall
        # back to an empty mapping instead of calling .get on None.
        project_collectors = projects_output_collectors.get(project_id) or {}
        collector = project_collectors.get(data_dict["output_collector"])
        if collector is None:
            raise HTTPException(
                status_code=404,
                detail=(
                    f"Output collector '{data_dict['output_collector']}' "
                    f"not found in project '{project_id}'"
                ),
            )
        # The three lists are walked in lockstep; zip stops at the shortest.
        data_dict["output_collector"] = [
            f"{key}: {data_type} = {description}"
            for key, data_type, description in zip(
                collector["keys"],
                collector["data_types"],
                collector["descriptions"],
            )
        ]

    if data_dict["input_api_endpoints"]:
        project_input_api = projects_input_api_endpoints.get(project_id) or {}
        data_dict["input_api_endpoints"] = [
            project_input_api.get(inp) for inp in data_dict["input_api_endpoints"]
        ]

    if data_dict["output_api_endpoints"]:
        project_output_api = projects_output_api_endpoints.get(project_id) or {}
        data_dict["output_api_endpoints"] = [
            project_output_api.get(out) for out in data_dict["output_api_endpoints"]
        ]

    project_agents = projects_agents.get(project_id)
    if project_agents:
        project_agents[data_dict["id"]] = data_dict
    else:
        projects_agents[project_id] = {data_dict["id"]: data_dict}

    return {"message": "Agent created", "data": data_dict}
@app.post("/output_collectors")
def create_output_collector(data: OutputCollector):
    """Store an output collector in the in-memory registry of its project.

    The request model is converted to a dict, stamped with ``created_at`` and
    ``updated_at`` (the same UTC instant) and ``deleted_at = None``, then
    filed under ``projects_output_collectors[project_id][id]``. An existing
    entry with the same id is overwritten.

    Returns:
        A confirmation message together with the stored record.
    """
    now = datetime.utcnow()
    record = data.dict()
    record.update(created_at=now, updated_at=now, deleted_at=None)

    project_id, collector_id = record["project_id"], record["id"]
    bucket = projects_output_collectors.get(project_id)
    if bucket:
        bucket[collector_id] = record
    else:
        # First collector for this project (or an empty bucket): start fresh.
        projects_output_collectors[project_id] = {collector_id: record}

    return {"message": "Output Collector created", "data": record}
# The route must differ from the one used by create_output_api: when two
# handlers share the same method and path, only the first one registered is
# ever matched, which left the output-endpoint handler unreachable.
@app.post("/input_api_endpoints")
def create_input_api(data: APIEndpoint):
    """Store an input API endpoint in the in-memory registry of its project.

    The request model is converted to a dict, stamped with ``created_at`` and
    ``updated_at`` (the same UTC instant) and ``deleted_at = None``, then
    filed under ``projects_input_api_endpoints[project_id][id]``. An existing
    entry with the same id is overwritten.

    Returns:
        A confirmation message together with the stored record.
    """
    current_time = datetime.utcnow()
    data_dict = data.dict()
    data_dict["created_at"] = current_time
    data_dict["updated_at"] = current_time
    data_dict["deleted_at"] = None

    project_id = data_dict["project_id"]
    project_endpoints = projects_input_api_endpoints.get(project_id)
    if project_endpoints:
        project_endpoints[data_dict["id"]] = data_dict
    else:
        projects_input_api_endpoints[project_id] = {data_dict["id"]: data_dict}

    return {"message": "Input API Endpoint created", "data": data_dict}
@app.post("/output_api_endpoints")
def create_output_api(data: APIEndpoint):
    """Store an output API endpoint in the in-memory registry of its project.

    The request model is converted to a dict, stamped with ``created_at`` and
    ``updated_at`` (the same UTC instant) and ``deleted_at = None``, then
    filed under ``projects_output_api_endpoints[project_id][id]``. An existing
    entry with the same id is overwritten.

    Returns:
        A confirmation message together with the stored record.
    """
    now = datetime.utcnow()
    record = data.dict()
    record.update(created_at=now, updated_at=now, deleted_at=None)

    project_id, endpoint_id = record["project_id"], record["id"]
    bucket = projects_output_api_endpoints.get(project_id)
    if bucket:
        bucket[endpoint_id] = record
    else:
        # First endpoint for this project (or an empty bucket): start fresh.
        projects_output_api_endpoints[project_id] = {endpoint_id: record}

    return {"message": "Output API Endpoint created", "data": record}
@app.post("/build_chain")
def create_chain(project_id: str, id: str, chains: List[dict], welcome_message: str):
    """Compile a chain graph for a project and register it under ``id``.

    Every ``"agent"`` value found in ``chains`` (including nodes nested under
    ``"child"``) is treated as an agent id and replaced, in place, by the
    agent record stored for the project. The resolved structure is then
    passed to ``build_chain`` together with the module-level checkpointer and
    the result is stored in ``projects_chains[project_id][id]``.

    Args:
        project_id: Project whose agents are used to resolve the chain.
        id: Key under which the compiled graph is stored.
        chains: Chain nodes; each dict may carry an ``"agent"`` id and a
            ``"child"`` list of further nodes.
        welcome_message: Accepted by the endpoint but not used here.

    Returns:
        A confirmation message together with the resolved ``chains``.

    Raises:
        HTTPException: 404 when a node references an agent id that is not
            registered for the project.
    """
    # A project without agents has no registry entry; use an empty mapping so
    # the lookup below reports a clean 404 instead of failing on None.
    agents = projects_agents.get(project_id) or {}

    def update_key(data, key_field, agents):
        """Replace ``data[key_field]`` ids by agent records, recursively."""
        if isinstance(data, dict):
            if key_field in data:
                agent_id = data[key_field]
                if agent_id not in agents:
                    raise HTTPException(
                        status_code=404,
                        detail=(
                            f"Agent '{agent_id}' not found in "
                            f"project '{project_id}'"
                        ),
                    )
                data[key_field] = agents[agent_id]
            # Nodes without a "child" entry are leaves: .get yields None,
            # which matches neither branch and ends the recursion.
            update_key(data.get("child"), key_field, agents)
        elif isinstance(data, list):
            for item in data:
                update_key(item, key_field, agents)

    print("[CHAIN PREV]", chains)
    update_key(chains, "agent", agents)
    print("[CHAIN NEW]", chains)

    # NOTE(review): `run` calls build_chain with three arguments
    # (chains, agents, checkpointer) while this call passes two; confirm
    # which form matches the signature in .chains.
    graph = build_chain(chains, checkpointer)

    project_chains = projects_chains.get(project_id)
    if project_chains:
        project_chains[id] = graph
    else:
        projects_chains[project_id] = {id: graph}

    return {"message": "Chain created", "data": chains}
@app.get("/{project_id}/{id}/run")
def run(project_id: str, id: str, session_id: str, message: str = "Hello world!"):
    """Send ``message`` through a stored chain, or through a single agent.

    ``id`` is first looked up among the compiled chains of the project. When
    no chain is stored under it, ``id`` is looked up among the project's
    agents and a one-node chain wrapping that agent is built on the fly (it
    is not stored). The graph is invoked with the message as a
    ``HumanMessage`` and as ``variables["user_input"]``, using ``session_id``
    as the ``thread_id`` in the invocation config.

    Args:
        project_id: Project that owns the chain or agent.
        id: Chain id, or agent id used as a fallback.
        session_id: Passed as ``thread_id`` in the graph configuration.
        message: User input forwarded to the graph.

    Returns:
        ``{"results": <content of the last message>}`` on success, or the
        error payload ``{"Error": ..., "status_code": 401}`` when neither a
        chain nor an agent is registered under ``id``.
    """
    project_chains = projects_chains.get(project_id)
    graph = project_chains.get(id) if project_chains else None

    if graph is None:
        project_agents = projects_agents.get(project_id)
        runnable = project_agents.get(id) if project_agents else None
        if runnable is None:
            # Covers both an unknown project and an unknown id inside a known
            # project; the latter used to build a chain around a None agent.
            return {"Error": "Agent or Chain does not exists!", "status_code": 401}

        # Minimal single-node chain wrapping the agent.
        chains = [{
            "id": "123",
            "agent": runnable,
            "checkpoint": False,
            "condition": None,
            "condition_from": None,
            "child": [],
        }]
        # NOTE(review): create_chain calls build_chain with two arguments
        # (chains, checkpointer) while this call passes three; confirm which
        # form matches the signature in .chains.
        graph = build_chain(chains, project_agents, checkpointer)
        print("[GRAPH]", graph)

    out = graph.invoke(
        {
            "messages": [HumanMessage(message)],
            "variables": {"user_input": message},
        },
        {"configurable": {"thread_id": session_id}},
    )
    return {"results": out["messages"][-1].content}
@app.get("/agents")
def get_agents():
    """Return the whole in-memory agent registry, covering every project."""
    return dict(results=projects_agents)
@app.get("/")
def greet_json():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World!"}
    return greeting