"""LLM tools used by the game graph.""" import logging import uuid from typing import Annotated, Dict from langchain_core.tools import tool from agent.llm import create_llm from agent.models import ( EndingCheckResult, Scene, SceneChoice, SceneLLM, StoryFrame, StoryFrameLLM, UserChoice, ) from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT from agent.redis_state import get_user_state, set_user_state from agent.utils import with_retries from images.image_generator import modify_image, generate_image from agent.image_agent import ChangeScene logger = logging.getLogger(__name__) def _err(msg: str) -> str: logger.error(msg) return f"{{'error': '{msg}'}}" @tool async def generate_story_frame( user_hash: Annotated[str, "User session ID"], setting: Annotated[str, "Game world setting"], character: Annotated[Dict[str, str], "Character info"], genre: Annotated[str, "Genre"], ) -> Annotated[Dict, "Generated story frame"]: """Create the initial story frame and store it in user state.""" llm = create_llm().with_structured_output(StoryFrameLLM) prompt = STORY_FRAME_PROMPT.format( setting=setting, character=character, genre=genre, ) resp: StoryFrameLLM = await with_retries(lambda: llm.ainvoke(prompt)) story_frame = StoryFrame( lore=resp.lore, goal=resp.goal, milestones=resp.milestones, endings=resp.endings, setting=setting, character=character, genre=genre, ) state = await get_user_state(user_hash) state.story_frame = story_frame await set_user_state(user_hash, state) return story_frame.dict() @tool async def generate_scene( user_hash: Annotated[str, "User session ID"], last_choice: Annotated[str, "Last user choice"], ) -> Annotated[Dict, "Generated scene"]: """Generate a new scene based on the current user state.""" state = await get_user_state(user_hash) if not state.story_frame: return _err("Story frame not initialized") llm = create_llm().with_structured_output(SceneLLM) prompt = SCENE_PROMPT.format( lore=state.story_frame.lore, goal=state.story_frame.goal, milestones=",".join(m.id for m in state.story_frame.milestones), endings=",".join(e.id for e in state.story_frame.endings), history="; ".join(f"{c.scene_id}:{c.choice_text}" for c in state.user_choices), last_choice=last_choice, ) resp: SceneLLM = await with_retries(lambda: llm.ainvoke(prompt)) if len(resp.choices) < 2: resp = await with_retries( lambda: llm.ainvoke(prompt + "\nThe scene must contain exactly two choices.") ) scene_id = str(uuid.uuid4()) choices = [ SceneChoice(**ch.model_dump()) if hasattr(ch, "model_dump") else SceneChoice(**ch) for ch in resp.choices[:2] ] scene = Scene( scene_id=scene_id, description=resp.description, choices=choices, image=None, music=None, ) state.current_scene_id = scene_id state.scenes[scene_id] = scene await set_user_state(user_hash, state) return scene.dict() @tool async def generate_scene_image( user_hash: Annotated[str, "User session ID"], scene_id: Annotated[str, "Scene ID"], change_scene: Annotated[ChangeScene, "Prompt for image generation"], current_image: Annotated[str, "Current image"] | None = None, ) -> Annotated[str, "Path to generated image"]: """Generate an image for a scene and save the path in the state.""" try: image_path = current_image if change_scene.change_scene == "change_completely" or change_scene.change_scene == "modify": image_path, _ = await ( generate_image(change_scene.scene_description) if current_image is None # for now always modify the image to avoid the generating an update in a completely wrong style else modify_image(current_image, change_scene.scene_description) ) state = await get_user_state(user_hash) if scene_id in state.scenes: state.scenes[scene_id].image = image_path await set_user_state(user_hash, state) return image_path except Exception as exc: # noqa: BLE001 return _err(str(exc)) @tool async def update_state_with_choice( user_hash: Annotated[str, "User session ID"], scene_id: Annotated[str, "Scene ID"], choice_text: Annotated[str, "Chosen option"], ) -> Annotated[Dict, "Updated state"]: """Record the player's choice in the state.""" import datetime state = await get_user_state(user_hash) state.user_choices.append( UserChoice( scene_id=scene_id, choice_text=choice_text, timestamp=datetime.datetime.utcnow().isoformat(), ) ) await set_user_state(user_hash, state) return state.dict() @tool async def check_ending( user_hash: Annotated[str, "User session ID"], ) -> Annotated[Dict, "Ending check result"]: """Check whether an ending has been reached.""" state = await get_user_state(user_hash) if not state.story_frame: return _err("No story frame") llm = create_llm().with_structured_output(EndingCheckResult) history = "; ".join(f"{c.scene_id}:{c.choice_text}" for c in state.user_choices) prompt = ENDING_CHECK_PROMPT.format( history=history, endings=",".join(f"{e.id}:{e.condition}" for e in state.story_frame.endings), ) resp: EndingCheckResult = await with_retries(lambda: llm.ainvoke(prompt)) if resp.ending_reached and resp.ending: state.ending = resp.ending await set_user_state(user_hash, state) return {"ending_reached": True, "ending": resp.ending.dict()} return {"ending_reached": False}