import json
import uuid
from datetime import datetime
from typing import Optional

from fastapi import APIRouter
from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

router = APIRouter(
    prefix="/presentation",
    tags=["presentation"]
)


@tool
def plan(input: dict) -> str:
    """Create a presentation plan with numbered slides and their descriptions.

    Args:
        input: Dictionary containing presentation details

    Returns:
        Confirmation that the plan was created
    """
    return "plan created"


@tool
def create_slide(slideno: int, content: str) -> str:
    """Create a single presentation slide.

    Args:
        slideno: The slide number to create
        content: The content for the slide

    Returns:
        Confirmation of slide creation
    """
    return f"slide {slideno} created"


memory = MemorySaver()
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)

# The date is embedded with an f-string; inside a ChatPromptTemplate string it
# would otherwise be parsed as a template variable instead of evaluated Python.
prompt = ChatPromptTemplate.from_messages([
    ("system", f"""You are a Presentation Creation Assistant. Your task is to help users create effective presentations.

Follow these steps:
1. First use the plan tool to create an outline of the presentation
2. Then use create_slide tool for each slide in sequence
3. Guide the user through the presentation creation process

Today's date is {datetime.now().strftime('%Y-%m-%d')}"""),
    ("placeholder", "{messages}"),
])


def state_modifier(state) -> list[BaseMessage]:
    """Format the prompt and trim the conversation before each model call."""
    try:
        formatted_prompt = prompt.invoke({
            "messages": state["messages"]
        })
        # prompt.invoke returns a ChatPromptValue; trim_messages expects a
        # list of messages, so convert it first.
        return trim_messages(
            formatted_prompt.to_messages(),
            token_counter=len,  # len counts messages, not tokens
            max_tokens=16000,
            strategy="last",
            start_on="human",
            include_system=True,
            allow_partial=False,
        )
    except Exception as e:
        print(f"Error in state modifier: {str(e)}")
        return state["messages"]


# Create the agent with presentation tools
agent = create_react_agent(
    model,
    tools=[plan, create_slide],
    checkpointer=memory,
    state_modifier=state_modifier,
)


class ChatInput(BaseModel):
    message: str
    thread_id: Optional[str] = None


@router.post("/chat")
async def chat(input_data: ChatInput):
    # Reuse the caller's thread_id so conversation state persists across turns
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id
        }
    }
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield json.dumps({"type": "token", "content": content})
            elif kind == "on_tool_start":
                tool_input = event["data"].get("input", "")
                yield json.dumps({
                    "type": "tool_start",
                    "tool": event["name"],
                    "input": str(tool_input),
                })
            elif kind == "on_tool_end":
                # Tool outputs may not be JSON-serializable objects, so
                # stringify them to keep json.dumps from failing mid-stream.
                tool_output = event["data"].get("output", "")
                yield json.dumps({
                    "type": "tool_end",
                    "tool": event["name"],
                    "output": str(tool_output),
                })

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )


@router.get("/health")
async def health_check():
    return {"status": "healthy"}
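

# --- Example client (a minimal sketch, not part of the router) ---------------
# Assumes the router above is mounted under /presentation on an app served at
# http://localhost:8000 and that the `requests` package is installed; the base
# URL and function name are illustrative assumptions, not part of the API.
# Each SSE frame carries one of the JSON payloads emitted by generate() above.
def stream_chat_example(message: str, base_url: str = "http://localhost:8000") -> None:
    import requests  # local import so the server module itself does not need it

    resp = requests.post(
        f"{base_url}/presentation/chat",
        json={"message": message},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        # EventSourceResponse frames each yielded chunk as a "data: ..." line;
        # skip blank separators and anything that is not a data line.
        if not line or not line.startswith("data:"):
            continue
        event = json.loads(line[len("data:"):].strip())
        if event["type"] == "token":
            print(event["content"], end="", flush=True)
        else:
            print(f"\n[{event['type']}] {event.get('tool', '')}")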