|
from fastapi import FastAPI |
|
from langgraph.graph import StateGraph |
|
from typing import TypedDict, Annotated, List |
|
from langgraph.graph.message import add_messages |
|
from pydantic import BaseModel |
|
|
|
|
|
# ASGI application object: the route decorators further down register their
# handlers on it, and the ``__main__`` guard at the bottom serves it.
_API_TITLE = "LangGraph Agent API"
app = FastAPI(title=_API_TITLE)
|
|
|
class State(TypedDict):
    """Shared state threaded through the LangGraph nodes.

    Each node in this file receives the current ``State`` and returns a
    partial dict with these keys, which the graph merges into the state.
    """

    # Message history. The ``add_messages`` entry in the ``Annotated``
    # metadata is the reducer the graph uses to combine a node's returned
    # list with the stored one.
    # NOTE(review): add_messages is believed to coerce plain strings into
    # message objects, so at runtime this key may not actually hold ``str``
    # values — confirm before relying on the ``list[str]`` annotation.
    messages: Annotated[list[str], add_messages]

    # Step label; this file sets it to "collect" (initial state in
    # run_agent), "process" (collect_info) and "end" (process_info).
    current_step: str
|
|
|
class AgentInput(BaseModel):
    """Request body for ``POST /run-agent``.

    Attributes:
        messages: Initial messages used to seed the graph's ``messages``
            state before the agent runs.
    """

    # Builtin ``list[str]`` for consistency with the ``State`` schema in this
    # module; pydantic validates it exactly like ``typing.List[str]``.
    messages: list[str]
|
|
|
def collect_info(state: State) -> dict:
    """First graph node: append the "Information collected" marker.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A state update containing the existing messages plus the new marker,
        with ``current_step`` set to ``"process"``.
    """
    print("\n--> In collect_info")
    before = state["messages"]
    print(f"Messages before: {before}")

    # NOTE(review): the whole history (not only the new entry) is returned;
    # this relies on the add_messages reducer on State.messages not
    # duplicating entries it already holds — confirm against LangGraph docs.
    after = before + ["Information collected"]
    print(f"Messages after: {after}")

    return dict(messages=after, current_step="process")
|
|
|
def process_info(state: State) -> dict:
    """Second graph node: append the "Information processed" marker.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A state update containing the existing messages plus the new marker,
        with ``current_step`` set to ``"end"``.
    """
    print("\n--> In process_info")
    before = state["messages"]
    print(f"Messages before: {before}")

    # NOTE(review): as in collect_info, the full history is returned and the
    # add_messages reducer on State.messages is expected to merge it without
    # duplicates — confirm against LangGraph docs.
    after = before + ["Information processed"]
    print(f"Messages after: {after}")

    return dict(messages=after, current_step="end")
|
|
|
|
|
def _build_agent():
    """Assemble the linear ``collect -> process`` graph and compile it.

    Returns:
        A ``(builder, compiled)`` pair: the ``StateGraph`` that was configured
        and the object produced by its ``compile()`` call.
    """
    builder = StateGraph(State)

    # Nodes, then the single edge between them.
    builder.add_node("collect", collect_info)
    builder.add_node("process", process_info)
    builder.add_edge("collect", "process")

    # Execution starts at "collect" and finishes after "process".
    builder.set_entry_point("collect")
    builder.set_finish_point("process")

    return builder, builder.compile()


# Both stay module-level names: ``workflow`` is the builder and ``agent`` is the
# compiled graph that run_agent invokes.
workflow, agent = _build_agent()
|
|
|
|
|
@app.post("/run-agent")
async def run_agent(input_data: AgentInput):
    """
    Run the agent with the provided input messages.

    Seeds the graph with ``input_data.messages`` and ``current_step="collect"``,
    runs it to completion, and returns ``{"messages": <final messages>}``.
    """
    initial_state = State(messages=input_data.messages, current_step="collect")
    # Use the async entry point: calling the synchronous ``invoke`` inside an
    # ``async def`` handler blocks the event loop for the entire graph run,
    # stalling every other request handled by this worker.
    final_state = await agent.ainvoke(initial_state)
    return {"messages": final_state["messages"]}
|
|
|
@app.get("/")
async def root():
    """
    Health/info endpoint: report that the API is up and where to try it.
    """
    docs_hint = "Navigate to https://jstoppa-langgraph-basic-example-api.hf.space/docs#/default/run_agent_run_agent_post to run the example"
    payload = {"message": "LangGraph Agent API is running"}
    payload["endpoints"] = [docs_hint]
    return payload
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve the ASGI app with uvicorn on all interfaces.
    # NOTE(review): 7860 presumably matches the Hugging Face Spaces deployment
    # linked from the root endpoint — confirm before changing it.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")