# Source: LLMGameHub/src/agent/llm_graph.py
# Last commit (gsavin): "fix: parallel game generation" (0a18f7d)
"""LangGraph setup for the interactive fiction agent."""
from agent.music_agent import generate_music_prompt
import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional
import asyncio
from langgraph.graph import END, StateGraph
from agent.image_agent import generate_image_prompt
from agent.tools import (
check_ending,
generate_scene,
generate_scene_image,
generate_story_frame,
update_state_with_choice,
)
from agent.redis_state import get_user_state
from audio.audio_generator import change_music_tone
# Module-level logger named after this module; used by every node and router below.
logger = logging.getLogger(__name__)
@dataclass
class GraphState:
    """State object handed from one graph node to the next.

    Every field is optional and defaults to ``None``; nodes read the fields
    they need and assign to ``scene`` / ``ending`` before returning the same
    instance.
    """

    # --- request routing -------------------------------------------------
    # Forwarded to the tools and state lookups under the "user_hash" key.
    user_hash: Optional[str] = None
    # Compared against "start" and "choose" by the entry router.
    step: Optional[str] = None

    # --- inputs used when a new game is initialised ----------------------
    setting: Optional[str] = None
    character: Optional[Dict[str, Any]] = None
    genre: Optional[str] = None

    # --- input used when the player picks an option ----------------------
    choice_text: Optional[str] = None

    # --- results filled in by the nodes ----------------------------------
    # Scene dict produced by the scene generation tool.
    scene: Optional[Dict[str, Any]] = None
    # Result of the ending check; its "ending_reached" key drives routing.
    ending: Optional[Dict[str, Any]] = None
async def node_entry(state: GraphState) -> GraphState:
    """Entry node: log the incoming state at debug level and pass it on.

    The state is returned untouched; the conditional edge attached to this
    node decides where execution continues.
    """
    logger.debug("[Graph] entry state: %s", state)
    return state
def route_step(state: GraphState) -> str:
    """Pick the node that follows the entry node based on ``state.step``.

    Returns:
        ``"init_game"`` for step ``"start"``, ``"player_step"`` for step
        ``"choose"``. Any other value is logged as a warning and treated
        like ``"start"``.
    """
    known_routes = (
        ("start", "init_game"),
        ("choose", "player_step"),
    )
    for step_name, node_name in known_routes:
        if state.step == step_name:
            return node_name

    # Unrecognised step: warn and fall back to starting a new game.
    logger.warning("route_step received unknown step '%s'", state.step)
    return "init_game"
async def node_init_game(state: GraphState) -> GraphState:
    """Start a new game: create the story frame, the first scene and its image.

    The steps run strictly in sequence because each one feeds the next:

    1. ``generate_story_frame`` is invoked with the setting, character and
       genre taken from ``state``.
    2. ``generate_scene`` is invoked with ``last_choice="start"``.
    3. ``generate_image_prompt`` receives the scene description with an
       appended note asking for a new image for the starting scene.
    4. ``generate_scene_image`` is invoked with that result for the new scene.

    Args:
        state: Graph state carrying ``user_hash``, ``setting``, ``character``
            and ``genre``.

    Returns:
        The same ``state`` object with ``scene`` set to the first scene.
    """
    logger.debug("[Graph] node_init_game state: %s", state)
    user_hash = state.user_hash

    await generate_story_frame.ainvoke(
        {
            "user_hash": user_hash,
            "setting": state.setting,
            "character": state.character,
            "genre": state.genre,
        }
    )

    first_scene = await generate_scene.ainvoke(
        {"user_hash": user_hash, "last_choice": "start"}
    )

    # The trailing note is part of the text handed to the image-prompt agent.
    init_description = (
        f"{first_scene['description']}\n"
        "NOTE FOR THE ASSISTANT: YOU MUST GENERATE A NEW IMAGE FOR THE STARTING SCENE"
    )
    change_scene = await generate_image_prompt(user_hash, init_description)
    # Lazy %-style arguments, matching the other log calls in this module.
    logger.info("Change scene: %s", change_scene)

    await generate_scene_image.ainvoke(
        {
            "user_hash": user_hash,
            "scene_id": first_scene["scene_id"],
            "change_scene": change_scene,
        }
    )

    state.scene = first_scene
    return state
# Strong references to in-flight fire-and-forget tasks. The asyncio event loop
# keeps only weak references to tasks, so a task nobody else references may be
# garbage collected before it finishes.
_background_tasks: set = set()


def _finish_background_task(task: "asyncio.Task") -> None:
    """Release a finished background task and log its failure, if any."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        logger.error("[Graph] background task failed", exc_info=error)


async def node_player_step(state: GraphState) -> GraphState:
    """Apply the player's choice and, unless the game ended, build the next scene.

    Flow:

    1. Load the user state and remember the current scene id.
    2. If ``state.choice_text`` is non-empty, record it via
       ``update_state_with_choice``.
    3. Run ``check_ending`` and store its result in ``state.ending``.
    4. When no ending was reached: generate the next scene and its image
       prompt, then run image generation and music-prompt generation
       concurrently, schedule the music tone change in the background and
       store the new scene in ``state.scene``.

    Args:
        state: Graph state carrying ``user_hash`` and ``choice_text``.

    Returns:
        The same ``state`` object with ``ending`` always set and ``scene``
        set only when the game continues.
    """
    logger.debug("[Graph] node_player_step state: %s", state)
    user_state = await get_user_state(state.user_hash)
    scene_id = user_state.current_scene_id

    if state.choice_text:
        await update_state_with_choice.ainvoke(
            {
                "user_hash": state.user_hash,
                "scene_id": scene_id,
                "choice_text": state.choice_text,
            }
        )

    ending = await check_ending.ainvoke({"user_hash": state.user_hash})
    state.ending = ending
    if ending.get("ending_reached", False):
        # Game is over: no new scene, image or music is produced.
        return state

    next_scene = await generate_scene.ainvoke(
        {
            "user_hash": state.user_hash,
            "last_choice": state.choice_text,
        }
    )
    change_scene = await generate_image_prompt(
        state.user_hash, next_scene["description"], state.choice_text
    )

    # Image of the scene the player is leaving, taken from the user state
    # snapshot loaded at the top of this function.
    current_image = None
    if scene_id and scene_id in user_state.scenes:
        current_image = user_state.scenes[scene_id].image

    image_task = generate_scene_image.ainvoke(
        {
            "user_hash": state.user_hash,
            "scene_id": next_scene["scene_id"],
            "current_image": current_image,
            "change_scene": change_scene,
        }
    )
    music_task = generate_music_prompt(
        state.user_hash, next_scene["description"], state.choice_text
    )
    # Image generation and music-prompt generation run concurrently.
    _, music_prompt = await asyncio.gather(image_task, music_task)

    # Fire-and-forget: the node does not wait for the music change. Keep a
    # reference so the task cannot be collected mid-flight, and surface any
    # exception through the done-callback.
    music_change = asyncio.create_task(
        change_music_tone(state.user_hash, music_prompt)
    )
    _background_tasks.add(music_change)
    music_change.add_done_callback(_finish_background_task)

    state.scene = next_scene
    return state
def route_ending(state: GraphState) -> str:
    """Choose the edge taken after the player step.

    Args:
        state: Graph state whose ``ending`` holds the ending-check result.

    Returns:
        ``"game_over"`` when ``ending["ending_reached"]`` is truthy,
        otherwise ``"continue"``. A missing ending (``None``) is treated as
        "not reached" instead of raising ``AttributeError``.
    """
    # ``ending`` is Optional on the state; fall back to an empty result.
    ending = state.ending or {}
    return "game_over" if ending.get("ending_reached") else "continue"
async def node_game_over(state: GraphState) -> GraphState:
    """Terminal node: log that the game ended for this user.

    The state is returned as received.
    """
    logger.info("[Graph] Game over for user %s", state.user_hash)
    return state
def build_llm_game_graph() -> StateGraph:
    """Assemble the game graph over ``GraphState`` and return it compiled.

    Topology:

    * ``entry`` -> ``init_game`` or ``player_step`` (decided by ``route_step``)
    * ``init_game`` -> END
    * ``player_step`` -> ``game_over`` or END (decided by ``route_ending``)
    * ``game_over`` -> END

    NOTE(review): the returned value is whatever ``StateGraph.compile()``
    produces, which may not literally be a ``StateGraph`` - confirm the
    return annotation against the installed LangGraph version.
    """
    builder = StateGraph(GraphState)

    # Register the nodes in a fixed order.
    node_handlers = (
        ("entry", node_entry),
        ("init_game", node_init_game),
        ("player_step", node_player_step),
        ("game_over", node_game_over),
    )
    for node_name, handler in node_handlers:
        builder.add_node(node_name, handler)

    builder.set_entry_point("entry")

    # Router return value -> destination node.
    step_routes = {"init_game": "init_game", "player_step": "player_step"}
    builder.add_conditional_edges("entry", route_step, step_routes)
    builder.add_edge("init_game", END)

    ending_routes = {"game_over": "game_over", "continue": END}
    builder.add_conditional_edges("player_step", route_ending, ending_routes)
    builder.add_edge("game_over", END)

    return builder.compile()
# Compiled graph built once when this module is imported.
llm_game_graph = build_llm_game_graph()