"""FastAPI service that registers per-project agents, output collectors,
input/output API endpoints and LangGraph chains, and runs them on demand.

All state is held in module-level dicts and an in-memory ``MemorySaver``
checkpointer, so it lives only as long as this process.
"""
from datetime import datetime
from typing import List

from fastapi import FastAPI, HTTPException
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import message_to_dict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph

from .models import Agent, OutputCollector, APIEndpoint
from .chains import GeneralStates, build_chain, agent_builder

app = FastAPI()

# Shared LangGraph checkpointer; graphs are compiled against it and invoked
# with a per-session ``thread_id``.
checkpointer = MemorySaver()

# In-memory registries, each shaped {project_id: {record_id: record}}.
projects_chains = {}
projects_agents = {}
projects_output_collectors = {}
projects_input_api_endpoints = {}
projects_output_api_endpoints = {}


def _lookup(registry, project_id, record_id, kind):
    """Return ``registry[project_id][record_id]`` or raise a 404.

    Args:
        registry: One of the ``projects_*`` registries.
        project_id: Project bucket to search.
        record_id: ID of the record inside that bucket.
        kind: Human-readable record type used in the error message.

    Raises:
        HTTPException: 404 when the project or the record is not registered.
    """
    record = registry.get(project_id, {}).get(record_id)
    if record is None:
        raise HTTPException(
            status_code=404,
            detail=f"{kind} '{record_id}' not found in project '{project_id}'",
        )
    return record


@app.post("/agents")
def create_agent(data: Agent):
    """Register an agent under its project, resolving the IDs it references.

    ``output_collector`` (a collector ID) is replaced by one
    ``"key: data_type = description"`` string per collector field.
    ``input_api_endpoints`` / ``output_api_endpoints`` (lists of endpoint IDs)
    are replaced by the corresponding stored endpoint records.

    Returns:
        ``{"message": "Agent created", "data": <stored agent record>}``.

    Raises:
        HTTPException: 404 if a referenced collector or endpoint is not
            registered for the agent's project.
    """
    current_time = datetime.utcnow()
    data_dict = data.dict()
    data_dict["created_at"] = current_time
    data_dict["updated_at"] = current_time
    data_dict["deleted_at"] = None
    project_id = data_dict["project_id"]

    if data_dict["output_collector"]:
        collector = _lookup(
            projects_output_collectors, project_id,
            data_dict["output_collector"], "Output collector",
        )
        data_dict["output_collector"] = [
            f"{key}: {data_type} = {description}"
            for key, data_type, description in zip(
                collector["keys"], collector["data_types"], collector["descriptions"]
            )
        ]

    if data_dict["input_api_endpoints"]:
        data_dict["input_api_endpoints"] = [
            _lookup(projects_input_api_endpoints, project_id, endpoint_id, "Input API endpoint")
            for endpoint_id in data_dict["input_api_endpoints"]
        ]

    if data_dict["output_api_endpoints"]:
        data_dict["output_api_endpoints"] = [
            _lookup(projects_output_api_endpoints, project_id, endpoint_id, "Output API endpoint")
            for endpoint_id in data_dict["output_api_endpoints"]
        ]

    projects_agents.setdefault(project_id, {})[data_dict["id"]] = data_dict
    return {"message": "Agent created", "data": data_dict}
def _timestamped_dict(model):
    """Return ``model.dict()`` stamped with creation metadata.

    ``created_at`` and ``updated_at`` are set to the same naive UTC timestamp
    (``datetime.utcnow()``) and ``deleted_at`` to ``None``.
    """
    current_time = datetime.utcnow()
    record = model.dict()
    record["created_at"] = current_time
    record["updated_at"] = current_time
    record["deleted_at"] = None
    return record


def _register(registry, record):
    """Store *record* as ``registry[record["project_id"]][record["id"]]`` and return it."""
    registry.setdefault(record["project_id"], {})[record["id"]] = record
    return record


@app.post("/output_collectors")
def create_output_collector(data: OutputCollector):
    """Register an output collector under its project."""
    data_dict = _register(projects_output_collectors, _timestamped_dict(data))
    return {"message": "Output Collector created", "data": data_dict}


# Must not share a path with create_output_api: FastAPI dispatches to the
# first matching route, so a duplicate path silently shadows the other handler.
@app.post("/input_api_endpoints")
def create_input_api(data: APIEndpoint):
    """Register an input API endpoint under its project."""
    data_dict = _register(projects_input_api_endpoints, _timestamped_dict(data))
    return {"message": "Input API Endpoint created", "data": data_dict}


@app.post("/output_api_endpoints")
def create_output_api(data: APIEndpoint):
    """Register an output API endpoint under its project."""
    data_dict = _register(projects_output_api_endpoints, _timestamped_dict(data))
    return {"message": "Output API Endpoint created", "data": data_dict}


@app.post("/build_chain")
def create_chain(project_id: str, id: str, chains: List[dict], welcome_message: str):
    """Compile *chains* into a graph and store it as chain *id* of *project_id*.

    Every node's ``"agent"`` value is an agent ID. It is replaced in place by
    the registered agent record, recursing through each node's ``"child"``
    list, and the resolved structure is passed to ``build_chain``.
    ``welcome_message`` is accepted but not used by this handler.

    Returns:
        ``{"message": "Chain created", "data": <resolved chains>}``.

    Raises:
        HTTPException: 404 if a node references an agent that is not
            registered for *project_id*.
    """
    # Local import keeps this handler self-contained.
    from fastapi import HTTPException

    agents = projects_agents.get(project_id, {})

    def update_key(data, key_field, agents):
        """Swap ``data[key_field]`` agent IDs for agent records, recursing into children."""
        if isinstance(data, list):
            for item in data:
                update_key(item, key_field, agents)
        elif isinstance(data, dict) and key_field in data:
            agent_id = data[key_field]
            if agent_id not in agents:
                raise HTTPException(
                    status_code=404,
                    detail=f"Agent '{agent_id}' not found in project '{project_id}'",
                )
            data[key_field] = agents[agent_id]
            update_key(data.get("child", []), key_field, agents)

    print("[CHAIN PREV]", chains)
    update_key(chains, "agent", agents)
    print("[CHAIN NEW]", chains)
    graph = build_chain(chains, checkpointer)

    projects_chains.setdefault(project_id, {})[id] = graph
    return {"message": "Chain created", "data": chains}


@app.get("/{project_id}/{id}/run")
def run(project_id: str, id: str, session_id: str, message: str = "Hello world!"):
    """Run chain or agent *id* of *project_id* on *message*.

    A chain registered via ``/build_chain`` takes precedence. Otherwise an
    agent with that ID is wrapped in a one-node chain and compiled on the fly.
    ``session_id`` is passed as the LangGraph ``thread_id`` in the invocation
    config.

    Returns:
        ``{"results": <content of the last output message>}``, or a 404 JSON
        response when neither a chain nor an agent with that ID exists.
    """
    from fastapi.responses import JSONResponse

    graph = projects_chains.get(project_id, {}).get(id)
    if graph is None:
        project_agents = projects_agents.get(project_id, {})
        runnable = project_agents.get(id)
        if runnable is None:
            # A missing resource is 404, and it must be the real HTTP status,
            # not just a field in a 200 body.
            return JSONResponse(
                status_code=404,
                content={"Error": "Agent or Chain does not exists!", "status_code": 404},
            )
        chains = [{
            "id": "123",
            "agent": runnable,
            "checkpoint": False,
            "condition": None,
            "condition_from": None,
            "child": [],
        }]
        # NOTE(review): create_chain calls build_chain(chains, checkpointer),
        # while this call also passes the project's agent map. Confirm which
        # signature .chains.build_chain actually has.
        graph = build_chain(chains, project_agents, checkpointer)
        print("[GRAPH]", graph)

    out = graph.invoke(
        {
            "messages": [HumanMessage(message)],
            "variables": {"user_input": message},
        },
        {"configurable": {"thread_id": session_id}},
    )
    return {"results": out["messages"][-1].content}


@app.get("/agents")
def get_agents():
    """Return every registered agent, grouped by project ID."""
    return {"results": projects_agents}


@app.get("/")
def greet_json():
    """Health-check style root endpoint."""
    return {"Hello": "World!"}