Spaces:
Running
Running
"""LangGraph setup for the interactive fiction agent.""" | |
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional

from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph

from agent.image_agent import generate_image_prompt
from agent.music_agent import generate_music_prompt
from agent.redis_state import get_user_state
from agent.tools import (
    check_ending,
    generate_scene,
    generate_scene_image,
    generate_story_frame,
    update_state_with_choice,
)
from audio.audio_generator import change_music_tone
logger = logging.getLogger(__name__)


@dataclass
class GraphState:
    """Mutable state passed between graph nodes.

    Declared as a ``@dataclass`` so instances can be constructed from keyword
    arguments (e.g. ``GraphState(user_hash=..., step="start")``) and get a
    readable ``__repr__`` for the debug logs emitted by the nodes. LangGraph
    documents dataclasses as a supported ``StateGraph`` state schema; a plain
    class with only class-level annotations cannot be built from keyword
    arguments.

    Nodes read and assign attributes directly (``state.scene = ...``), so
    the instance is intentionally not frozen.
    """

    # Identifies the player; passed to every tool call and to get_user_state.
    user_hash: Optional[str] = None
    # Routing key read by route_step: "start" -> init_game, "choose" -> player_step.
    step: Optional[str] = None
    # Story setup, forwarded to generate_story_frame when a game starts.
    setting: Optional[str] = None
    character: Optional[Dict[str, Any]] = None
    genre: Optional[str] = None
    # Text of the option the player picked on a "choose" step.
    choice_text: Optional[str] = None
    # Latest scene dict from generate_scene (the nodes read "scene_id" and "description").
    scene: Optional[Dict[str, Any]] = None
    # Result dict from check_ending; route_ending reads its "ending_reached" key.
    ending: Optional[Dict[str, Any]] = None
async def node_entry(state: GraphState) -> GraphState:
    """Pass-through entry node.

    Logs the incoming state at DEBUG level and hands the same object on
    unchanged. The branch taken next is chosen by ``route_step``.
    """
    incoming = state
    logger.debug("[Graph] entry state: %s", incoming)
    return incoming
def route_step(state: GraphState) -> str:
    """Pick the node that follows ``entry``.

    Returns ``"init_game"`` when ``state.step == "start"`` and
    ``"player_step"`` when ``state.step == "choose"``. Any other value is
    logged as a warning and falls back to ``"init_game"``.
    """
    current = state.step
    for expected, destination in (("start", "init_game"), ("choose", "player_step")):
        if current == expected:
            return destination
    logger.warning("route_step received unknown step '%s'", current)
    return "init_game"
async def node_init_game(state: GraphState) -> GraphState:
    """Start a new story for ``state.user_hash``.

    Steps:
      1. ``generate_story_frame`` with the player's setting, character and genre.
      2. ``generate_scene`` for the opening scene (``last_choice="start"``).
      3. ``generate_image_prompt`` followed by ``generate_scene_image`` for
         that opening scene.

    The opening scene dict is stored on ``state.scene``, and the same state
    object is returned.
    """
    logger.debug("[Graph] node_init_game state: %s", state)
    await generate_story_frame.ainvoke(
        {
            "user_hash": state.user_hash,
            "setting": state.setting,
            "character": state.character,
            "genre": state.genre,
        }
    )
    first_scene = await generate_scene.ainvoke(
        {"user_hash": state.user_hash, "last_choice": "start"}
    )
    # Append an explicit instruction so the prompt generator always asks for
    # a fresh image on the opening scene. Unlike node_player_step, no
    # current_image is passed to the image tool here.
    init_description = (
        f"{first_scene['description']}\n"
        "NOTE FOR THE ASSISTANT: YOU MUST GENERATE A NEW IMAGE FOR THE STARTING SCENE"
    )
    change_scene = await generate_image_prompt(state.user_hash, init_description)
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Change scene: %s", change_scene)
    await generate_scene_image.ainvoke(
        {
            "user_hash": state.user_hash,
            "scene_id": first_scene["scene_id"],
            "change_scene": change_scene,
        }
    )
    state.scene = first_scene
    return state
# Strong references to fire-and-forget tasks. The event loop keeps only a
# weak reference to tasks, so an unreferenced task may be garbage-collected
# before it finishes (see the asyncio.create_task documentation).
_background_tasks = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Drop *task* from the registry and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("[Graph] background task failed", exc_info=exc)


def _spawn_background(coro) -> asyncio.Task:
    """Schedule *coro* without awaiting it, keeping it alive until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


async def node_player_step(state: GraphState) -> GraphState:
    """Advance the story by one player choice.

    Records ``state.choice_text`` (when non-empty) against the user's current
    scene, then asks ``check_ending`` whether the story is over and stores
    the result on ``state.ending``.

    If no ending was reached:
      * the next scene is generated;
      * its image and music prompt are produced concurrently;
      * the music change is scheduled in the background, not awaited;
      * the new scene is stored on ``state.scene``.

    Returns the same (mutated) state object.
    """
    logger.debug("[Graph] node_player_step state: %s", state)
    user_state = await get_user_state(state.user_hash)
    scene_id = user_state.current_scene_id

    if state.choice_text:
        await update_state_with_choice.ainvoke(
            {
                "user_hash": state.user_hash,
                "scene_id": scene_id,
                "choice_text": state.choice_text,
            }
        )

    ending = await check_ending.ainvoke({"user_hash": state.user_hash})
    state.ending = ending
    if ending.get("ending_reached", False):
        return state

    next_scene = await generate_scene.ainvoke(
        {
            "user_hash": state.user_hash,
            "last_choice": state.choice_text,
        }
    )
    change_scene = await generate_image_prompt(
        state.user_hash, next_scene["description"], state.choice_text
    )

    # Image of the scene recorded as current in user_state (read at the top
    # of this step); the image tool receives it together with change_scene.
    current_image = None
    if scene_id and scene_id in user_state.scenes:
        current_image = user_state.scenes[scene_id].image

    image_task = generate_scene_image.ainvoke(
        {
            "user_hash": state.user_hash,
            "scene_id": next_scene["scene_id"],
            "current_image": current_image,
            "change_scene": change_scene,
        }
    )
    music_task = generate_music_prompt(
        state.user_hash, next_scene["description"], state.choice_text
    )
    _, music_prompt = await asyncio.gather(image_task, music_task)

    # Don't block the graph on the music change. Keep a reference so the task
    # cannot be garbage-collected mid-flight, and log its failures.
    _spawn_background(change_music_tone(state.user_hash, music_prompt))

    state.scene = next_scene
    return state
def route_ending(state: GraphState) -> str:
    """Pick the edge that follows ``player_step``.

    Returns ``"game_over"`` when ``state.ending`` reports a truthy
    ``"ending_reached"``, otherwise ``"continue"``.

    A missing ending (``None``, the ``GraphState`` default) is treated as
    "not reached" instead of raising ``AttributeError``. This matches the
    ``.get("ending_reached", False)`` default used in ``node_player_step``.
    """
    ending = state.ending or {}
    return "game_over" if ending.get("ending_reached") else "continue"
async def node_game_over(state: GraphState) -> GraphState:
    """Terminal node, entered once ``route_ending`` reports an ending.

    Logs the player's hash at INFO level and returns the state untouched.
    """
    player = state.user_hash
    logger.info("[Graph] Game over for user %s", player)
    return state
def build_llm_game_graph() -> CompiledStateGraph:
    """Build and compile the interactive-fiction game graph.

    Topology:
      * ``entry`` -> ``route_step`` -> ``init_game`` or ``player_step``
      * ``init_game`` -> END
      * ``player_step`` -> ``route_ending`` -> ``game_over`` (then END),
        or straight to END on ``"continue"``

    Returns:
        The runnable graph produced by ``StateGraph.compile()``, which is a
        ``CompiledStateGraph`` rather than the ``StateGraph`` builder itself.
    """
    graph = StateGraph(GraphState)

    graph.add_node("entry", node_entry)
    graph.add_node("init_game", node_init_game)
    graph.add_node("player_step", node_player_step)
    graph.add_node("game_over", node_game_over)

    graph.set_entry_point("entry")
    graph.add_conditional_edges(
        "entry",
        route_step,
        {"init_game": "init_game", "player_step": "player_step"},
    )
    graph.add_edge("init_game", END)
    graph.add_conditional_edges(
        "player_step",
        route_ending,
        {"game_over": "game_over", "continue": END},
    )
    graph.add_edge("game_over", END)
    return graph.compile()


# Compiled once, at import time of this module.
llm_game_graph = build_llm_game_graph()