"""LangGraph setup for the interactive fiction agent."""

import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, Optional, Set

from langgraph.graph import END, StateGraph

# Local imports keep their original relative order.
from agent.music_agent import generate_music_prompt
from agent.image_agent import generate_image_prompt
from agent.tools import (
    check_ending,
    generate_scene,
    generate_scene_image,
    generate_story_frame,
    update_state_with_choice,
)
from agent.redis_state import get_user_state
from audio.audio_generator import change_music_tone

logger = logging.getLogger(__name__)

# Strong references to fire-and-forget background tasks. The asyncio event
# loop keeps only weak references to tasks, so a task nobody references can
# be garbage-collected before it finishes. Each task stays here until done.
_background_tasks: Set["asyncio.Task[Any]"] = set()


def _on_background_task_done(task: "asyncio.Task[Any]") -> None:
    """Release a finished background task and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Without this, the failure would only surface as
        # "Task exception was never retrieved" at garbage-collection time.
        logger.error(
            "[Graph] background task %s failed", task.get_name(), exc_info=exc
        )


def _spawn_background(coro: Awaitable[Any]) -> "asyncio.Task[Any]":
    """Schedule *coro* as a background task that is kept alive until it completes."""
    task = asyncio.ensure_future(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


@dataclass
class GraphState:
    """Mutable state passed between graph nodes."""

    # Identifier used to look up / persist this player's story state.
    user_hash: Optional[str] = None
    # Routing key read by ``route_step``: "start" or "choose".
    step: Optional[str] = None
    # Story-frame inputs, forwarded to ``generate_story_frame`` on "start".
    setting: Optional[str] = None
    character: Optional[Dict[str, Any]] = None
    genre: Optional[str] = None
    # The player's chosen action on a "choose" step (may be empty/None).
    choice_text: Optional[str] = None
    # The most recently generated scene (dict returned by ``generate_scene``).
    scene: Optional[Dict[str, Any]] = None
    # Result of ``check_ending``; its "ending_reached" key drives ``route_ending``.
    ending: Optional[Dict[str, Any]] = None


async def node_entry(state: GraphState) -> GraphState:
    """Entry node: log the incoming state and pass it through unchanged."""
    logger.debug("[Graph] entry state: %s", state)
    return state


def route_step(state: GraphState) -> str:
    """Choose the node to run after ``entry`` based on ``state.step``.

    Returns "init_game" for "start", "player_step" for "choose". Any other
    value is logged as a warning and falls back to "init_game".
    """
    if state.step == "start":
        return "init_game"
    if state.step == "choose":
        return "player_step"
    logger.warning("route_step received unknown step '%s'", state.step)
    return "init_game"


async def node_init_game(state: GraphState) -> GraphState:
    """Create the story frame, generate the first scene and its image.

    Stores the first scene on ``state.scene`` and returns the state.
    """
    logger.debug("[Graph] node_init_game state: %s", state)
    await generate_story_frame.ainvoke(
        {
            "user_hash": state.user_hash,
            "setting": state.setting,
            "character": state.character,
            "genre": state.genre,
        }
    )
    first_scene = await generate_scene.ainvoke(
        {"user_hash": state.user_hash, "last_choice": "start"}
    )
    # Force the image-prompt agent to produce a fresh image for the opener,
    # since there is no previous image to modify.
    init_description = (
        f"{first_scene['description']}\n"
        "NOTE FOR THE ASSISTANT: YOU MUST GENERATE A NEW IMAGE FOR THE STARTING SCENE"
    )
    change_scene = await generate_image_prompt(state.user_hash, init_description)
    logger.info("Change scene: %s", change_scene)
    await generate_scene_image.ainvoke(
        {
            "user_hash": state.user_hash,
            "scene_id": first_scene["scene_id"],
            "change_scene": change_scene,
        }
    )
    state.scene = first_scene
    return state


async def node_player_step(state: GraphState) -> GraphState:
    """Apply the player's choice, check for an ending, and advance the story.

    If a choice was supplied it is recorded first. ``check_ending`` is then
    consulted and its result stored on ``state.ending``. When no ending has
    been reached, the next scene is generated. Its image and the music prompt
    are produced concurrently, and the music change runs in the background.
    """
    logger.debug("[Graph] node_player_step state: %s", state)
    user_state = await get_user_state(state.user_hash)
    scene_id = user_state.current_scene_id

    if state.choice_text:
        await update_state_with_choice.ainvoke(
            {
                "user_hash": state.user_hash,
                "scene_id": scene_id,
                "choice_text": state.choice_text,
            }
        )

    ending = await check_ending.ainvoke({"user_hash": state.user_hash})
    state.ending = ending

    if not ending.get("ending_reached", False):
        next_scene = await generate_scene.ainvoke(
            {
                "user_hash": state.user_hash,
                "last_choice": state.choice_text,
            }
        )
        change_scene = await generate_image_prompt(
            state.user_hash, next_scene["description"], state.choice_text
        )

        # Pass the current scene's image (if any) so the new image can be
        # derived from it rather than generated from scratch.
        current_image = None
        if scene_id and scene_id in user_state.scenes:
            current_image = user_state.scenes[scene_id].image

        image_task = generate_scene_image.ainvoke(
            {
                "user_hash": state.user_hash,
                "scene_id": next_scene["scene_id"],
                "current_image": current_image,
                "change_scene": change_scene,
            }
        )
        music_task = generate_music_prompt(
            state.user_hash, next_scene["description"], state.choice_text
        )
        _, music_prompt = await asyncio.gather(image_task, music_task)

        # Fire-and-forget, but keep a strong reference so the task is not
        # garbage-collected mid-flight and its errors get logged.
        _spawn_background(change_music_tone(state.user_hash, music_prompt))

        state.scene = next_scene
    return state


def route_ending(state: GraphState) -> str:
    """Route to "game_over" when the stored ending says it was reached, else "continue"."""
    # Guard against a missing ending result instead of raising AttributeError.
    ending = state.ending or {}
    return "game_over" if ending.get("ending_reached") else "continue"


async def node_game_over(state: GraphState) -> GraphState:
    """Terminal node: log that the game ended and return the state unchanged."""
    logger.info("[Graph] Game over for user %s", state.user_hash)
    return state


def build_llm_game_graph() -> Any:
    """Build and compile the game graph.

    Flow: entry -> (init_game | player_step). init_game ends the run.
    player_step routes to game_over or ends the run.

    Returns:
        The compiled graph, i.e. the result of ``StateGraph.compile()``,
        not the uncompiled ``StateGraph`` builder.
    """
    graph = StateGraph(GraphState)
    graph.add_node("entry", node_entry)
    graph.add_node("init_game", node_init_game)
    graph.add_node("player_step", node_player_step)
    graph.add_node("game_over", node_game_over)

    graph.set_entry_point("entry")
    graph.add_conditional_edges(
        "entry",
        route_step,
        {"init_game": "init_game", "player_step": "player_step"},
    )
    graph.add_edge("init_game", END)
    graph.add_conditional_edges(
        "player_step",
        route_ending,
        {"game_over": "game_over", "continue": END},
    )
    graph.add_edge("game_over", END)
    return graph.compile()


llm_game_graph = build_llm_game_graph()