Spaces:
Running
Running
File size: 4,997 Bytes
d8e2b36 dca13c2 d8e2b36 2999669 d8e2b36 2999669 86b351a d8e2b36 ee968a7 2999669 d8e2b36 86b351a d8e2b36 86b351a d8e2b36 86b351a d8e2b36 86b351a d8e2b36 0a18f7d 2999669 d8e2b36 2999669 d8e2b36 70bca69 d8e2b36 0e8d8ac 0db7319 2999669 d8e2b36 0db7319 2999669 d8e2b36 0e8d8ac 60e195a d8e2b36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""LangGraph setup for the interactive fiction agent."""
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional

from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph

from agent.image_agent import generate_image_prompt
from agent.music_agent import generate_music_prompt
from agent.redis_state import get_user_state
from agent.tools import (
    check_ending,
    generate_scene,
    generate_scene_image,
    generate_story_frame,
    update_state_with_choice,
)
from audio.audio_generator import change_music_tone
logger = logging.getLogger(__name__)


@dataclass
class GraphState:
    """State object threaded through every node of the game graph.

    Nodes receive this instance, write their results onto it (``scene``,
    ``ending``) and return it. Every field is optional and defaults to
    ``None``.
    """

    # Key passed to every tool/agent call and to ``get_user_state``.
    user_hash: Optional[str] = None
    # Routing key inspected by ``route_step`` ("start" or "choose").
    step: Optional[str] = None
    # Story-frame inputs; only forwarded when a new game is initialised.
    setting: Optional[str] = None
    character: Optional[Dict[str, Any]] = None
    genre: Optional[str] = None
    # The choice the player just made; may be empty/None.
    choice_text: Optional[str] = None
    # Output: scene dict produced by the most recent node.
    scene: Optional[Dict[str, Any]] = None
    # Output: result of the ending check (read via its "ending_reached" key).
    ending: Optional[Dict[str, Any]] = None
async def node_entry(state: GraphState) -> GraphState:
    """Graph entry node: log the incoming state and forward it unchanged.

    The branch taken afterwards is decided by the conditional edge that
    leaves this node, not by the node itself.
    """
    incoming = state
    logger.debug("[Graph] entry state: %s", incoming)
    return incoming
def route_step(state: GraphState) -> str:
    """Map ``state.step`` to the name of the node that should run next.

    ``"start"`` routes to ``"init_game"`` and ``"choose"`` to
    ``"player_step"``. Any other value is logged as a warning and falls back
    to ``"init_game"``.
    """
    targets = {"start": "init_game", "choose": "player_step"}
    target = targets.get(state.step)
    if target is not None:
        return target
    logger.warning("route_step received unknown step '%s'", state.step)
    return "init_game"
async def node_init_game(state: GraphState) -> GraphState:
    """Start a new story for ``state.user_hash`` and render its opening scene.

    Steps:
      1. Generate the story frame from ``setting``, ``character`` and ``genre``.
      2. Generate the first scene (with ``last_choice="start"``).
      3. Derive an image prompt for that scene and render the scene image.

    Args:
        state: Graph state; ``user_hash``, ``setting``, ``character`` and
            ``genre`` are read.

    Returns:
        The same ``state`` instance with ``scene`` set to the first scene.
    """
    logger.debug("[Graph] node_init_game state: %s", state)
    await generate_story_frame.ainvoke(
        {
            "user_hash": state.user_hash,
            "setting": state.setting,
            "character": state.character,
            "genre": state.genre,
        }
    )
    first_scene = await generate_scene.ainvoke(
        {"user_hash": state.user_hash, "last_choice": "start"}
    )
    # Explicitly instruct the image-prompt agent to produce a new image for
    # the opening scene (no ``current_image`` is passed to the renderer below).
    init_description = (
        f"{first_scene['description']}\n"
        "NOTE FOR THE ASSISTANT: YOU MUST GENERATE A NEW IMAGE FOR THE STARTING SCENE"
    )
    change_scene = await generate_image_prompt(state.user_hash, init_description)
    # Lazy %-formatting, consistent with the other log calls in this module.
    logger.info("Change scene: %s", change_scene)
    await generate_scene_image.ainvoke(
        {
            "user_hash": state.user_hash,
            "scene_id": first_scene["scene_id"],
            "change_scene": change_scene,
        }
    )
    state.scene = first_scene
    return state
# Strong references to fire-and-forget tasks. The asyncio event loop keeps only
# weak references to tasks, so an otherwise-unreferenced task can be garbage
# collected before it finishes.
_BACKGROUND_TASKS: "set[asyncio.Task[Any]]" = set()


def _on_background_task_done(task: "asyncio.Task[Any]") -> None:
    """Release a finished background task and log any exception it raised."""
    _BACKGROUND_TASKS.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("[Graph] background task failed", exc_info=exc)


def _spawn_background_task(coro: Any) -> "asyncio.Task[Any]":
    """Schedule *coro* without awaiting it, keeping it referenced until done."""
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


async def node_player_step(state: GraphState) -> GraphState:
    """Advance the story by one player choice.

    1. Apply ``state.choice_text`` (when non-empty) to the stored user state.
    2. Run the ending check and store its result in ``state.ending``.
    3. If no ending was reached, do the following:
       - Generate the next scene.
       - Derive its image prompt.
       - Render the scene image and build a music prompt concurrently.
       - Schedule the music-tone change in the background (not awaited).
       - Store the new scene in ``state.scene``.

    Args:
        state: Graph state; ``user_hash`` and ``choice_text`` are read.

    Returns:
        The same ``state`` instance with ``ending`` set, and ``scene`` updated
        only when the story continues.
    """
    logger.debug("[Graph] node_player_step state: %s", state)
    user_state = await get_user_state(state.user_hash)
    # Captured before the choice is applied; used below to look up the image
    # of the scene the player was on.
    scene_id = user_state.current_scene_id
    if state.choice_text:
        await update_state_with_choice.ainvoke(
            {
                "user_hash": state.user_hash,
                "scene_id": scene_id,
                "choice_text": state.choice_text,
            }
        )
    ending = await check_ending.ainvoke({"user_hash": state.user_hash})
    state.ending = ending
    if not ending.get("ending_reached", False):
        next_scene = await generate_scene.ainvoke(
            {
                "user_hash": state.user_hash,
                "last_choice": state.choice_text,
            }
        )
        change_scene = await generate_image_prompt(
            state.user_hash, next_scene["description"], state.choice_text
        )
        current_image = None
        if scene_id and scene_id in user_state.scenes:
            current_image = user_state.scenes[scene_id].image
        image_task = generate_scene_image.ainvoke(
            {
                "user_hash": state.user_hash,
                "scene_id": next_scene["scene_id"],
                "current_image": current_image,
                "change_scene": change_scene,
            }
        )
        music_task = generate_music_prompt(
            state.user_hash, next_scene["description"], state.choice_text
        )
        _, music_prompt = await asyncio.gather(image_task, music_task)
        # Fire-and-forget, but keep a reference so the task is not GC'd
        # mid-run and its failures are logged instead of silently dropped.
        _spawn_background_task(change_music_tone(state.user_hash, music_prompt))
        state.scene = next_scene
    return state
def route_ending(state: GraphState) -> str:
    """Choose the edge taken after ``player_step``.

    Returns ``"game_over"`` when ``state.ending`` reports a truthy
    ``"ending_reached"``; otherwise ``"continue"``. ``state.ending`` is
    ``Optional``, so a missing result (``None``) is treated as "not reached"
    instead of raising ``AttributeError``.
    """
    ending = state.ending or {}
    return "game_over" if ending.get("ending_reached") else "continue"
async def node_game_over(state: GraphState) -> GraphState:
    """Terminal node for a finished story: log the event and return the state as-is."""
    player = state.user_hash
    logger.info("[Graph] Game over for user %s", player)
    return state
def build_llm_game_graph() -> CompiledStateGraph:
    """Build and compile the interactive-fiction state graph.

    Flow:
      - ``entry`` routes via ``route_step`` to ``init_game`` or ``player_step``.
      - ``init_game`` goes straight to ``END``.
      - ``player_step`` routes via ``route_ending`` to ``game_over`` (then
        ``END``) or directly to ``END``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()`` (not the
        uncompiled ``StateGraph`` builder).
    """
    graph = StateGraph(GraphState)
    graph.add_node("entry", node_entry)
    graph.add_node("init_game", node_init_game)
    graph.add_node("player_step", node_player_step)
    graph.add_node("game_over", node_game_over)

    graph.set_entry_point("entry")
    graph.add_conditional_edges(
        "entry",
        route_step,
        {"init_game": "init_game", "player_step": "player_step"},
    )
    # A new game's invocation ends once the opening scene has been produced.
    graph.add_edge("init_game", END)
    graph.add_conditional_edges(
        "player_step",
        route_ending,
        {"game_over": "game_over", "continue": END},
    )
    graph.add_edge("game_over", END)
    return graph.compile()


# Module-level compiled graph, built once at import time.
llm_game_graph = build_llm_game_graph()
|