Spaces:
Running
Running
File size: 5,926 Bytes
d8e2b36 86b351a d8e2b36 86b351a d8e2b36 ee968a7 85d7f84 2999669 86b351a d8e2b36 86b351a d8e2b36 85d7f84 d8e2b36 70bca69 d8e2b36 70bca69 d8e2b36 70bca69 d8e2b36 2999669 d8e2b36 85d7f84 d8e2b36 85d7f84 d8e2b36 70bca69 d8e2b36 86b351a d8e2b36 2999669 d8e2b36 86b351a 2999669 70bca69 d8e2b36 70bca69 d8e2b36 70bca69 d8e2b36 70bca69 d8e2b36 70bca69 d8e2b36 2999669 d8e2b36 2999669 d8e2b36 85d7f84 d8e2b36 70bca69 d8e2b36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
"""LLM tools used by the game graph."""
import logging
import uuid
from typing import Annotated, Dict
from langchain_core.tools import tool
from agent.llm import create_llm
from agent.models import (
EndingCheckResult,
Scene,
SceneChoice,
SceneLLM,
StoryFrame,
StoryFrameLLM,
UserChoice,
)
from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT
from agent.redis_state import get_user_state, set_user_state
from agent.utils import with_retries
from images.image_generator import modify_image, generate_image
from agent.image_agent import ChangeScene
# Module-level logger, named after this module; used by ``_err`` below.
logger = logging.getLogger(__name__)
def _err(msg: str) -> str:
    """Log ``msg`` at error level and return it wrapped in an error string.

    The returned value is a plain string shaped like ``{'error': '<msg>'}``.
    ``msg`` is interpolated verbatim: no quoting or escaping is applied, so
    the result is not guaranteed to be parseable if ``msg`` contains quotes.
    """
    logger.error(msg)
    payload = f"{{'error': '{msg}'}}"
    return payload
# NOTE(review): the docstring below is kept verbatim on purpose; langchain's
# @tool normally exposes it as the tool description, so it is runtime text.
@tool
async def generate_story_frame(
    user_hash: Annotated[str, "User session ID"],
    setting: Annotated[str, "Game world setting"],
    character: Annotated[Dict[str, str], "Character info"],
    genre: Annotated[str, "Genre"],
) -> Annotated[Dict, "Generated story frame"]:
    """Create the initial story frame and store it in user state."""
    # The caller-supplied inputs are used twice: to fill the prompt template
    # and to complete the stored frame next to the LLM-generated fields.
    user_inputs = {"setting": setting, "character": character, "genre": genre}

    structured_llm = create_llm().with_structured_output(StoryFrameLLM)
    prompt = STORY_FRAME_PROMPT.format(**user_inputs)

    def _request():
        # Zero-argument factory handed to with_retries; each call produces a
        # fresh awaitable for the same prompt.
        return structured_llm.ainvoke(prompt)

    generated: StoryFrameLLM = await with_retries(_request)

    story_frame = StoryFrame(
        lore=generated.lore,
        goal=generated.goal,
        milestones=generated.milestones,
        endings=generated.endings,
        **user_inputs,
    )

    # State is loaded only after the LLM call has finished, then written back
    # with the new frame attached.
    user_state = await get_user_state(user_hash)
    user_state.story_frame = story_frame
    await set_user_state(user_hash, user_state)

    return story_frame.dict()
# NOTE(review): the docstring below is kept verbatim on purpose; langchain's
# @tool normally exposes it as the tool description, so it is runtime text.
@tool
async def generate_scene(
    user_hash: Annotated[str, "User session ID"],
    last_choice: Annotated[str, "Last user choice"],
) -> Annotated[Dict, "Generated scene"]:
    """Generate a new scene based on the current user state."""
    user_state = await get_user_state(user_hash)
    frame = user_state.story_frame
    if not frame:
        # Without a story frame there is nothing to build the prompt from;
        # the error comes back as a string rather than an exception.
        return _err("Story frame not initialized")

    structured_llm = create_llm().with_structured_output(SceneLLM)

    # Flatten the frame and the choice history into the compact text forms
    # the prompt template receives.
    milestone_ids = ",".join(m.id for m in frame.milestones)
    ending_ids = ",".join(e.id for e in frame.endings)
    choice_log = "; ".join(
        f"{c.scene_id}:{c.choice_text}" for c in user_state.user_choices
    )
    prompt = SCENE_PROMPT.format(
        lore=frame.lore,
        goal=frame.goal,
        milestones=milestone_ids,
        endings=ending_ids,
        history=choice_log,
        last_choice=last_choice,
    )

    generated: SceneLLM = await with_retries(lambda: structured_llm.ainvoke(prompt))
    if len(generated.choices) < 2:
        # One extra attempt with an explicit instruction appended. The second
        # response is used as-is; it is not re-validated.
        strict_prompt = prompt + "\nThe scene must contain exactly two choices."
        generated = await with_retries(
            lambda: structured_llm.ainvoke(strict_prompt)
        )

    scene_id = str(uuid.uuid4())

    # Keep at most the first two choices. Items exposing ``model_dump`` are
    # converted through it; anything else is unpacked directly as a mapping.
    choices = []
    for raw_choice in generated.choices[:2]:
        if hasattr(raw_choice, "model_dump"):
            fields = raw_choice.model_dump()
        else:
            fields = raw_choice
        choices.append(SceneChoice(**fields))

    scene = Scene(
        scene_id=scene_id,
        description=generated.description,
        choices=choices,
        image=None,
        music=None,
    )

    # Register the scene as the current one and persist the updated state.
    user_state.current_scene_id = scene_id
    user_state.scenes[scene_id] = scene
    await set_user_state(user_hash, user_state)

    return scene.dict()
# NOTE(review): the docstring below is kept verbatim on purpose; langchain's
# @tool normally exposes it as the tool description, so it is runtime text.
@tool
async def generate_scene_image(
    user_hash: Annotated[str, "User session ID"],
    scene_id: Annotated[str, "Scene ID"],
    change_scene: Annotated[ChangeScene, "Prompt for image generation"],
    current_image: Annotated[str, "Current image"] | None = None,
) -> Annotated[str, "Path to generated image"]:
    """Generate an image for a scene and save the path in the state."""
    try:
        wants_update = change_scene.change_scene in ("change_completely", "modify")

        if not wants_update:
            # Nothing requested: the current image (possibly None) is reused.
            image_path = current_image
        elif current_image is None:
            # No image to start from, so a brand-new one is generated.
            image_path, _ = await generate_image(change_scene.scene_description)
        else:
            # For now always modify the existing image, to avoid generating an
            # update in a completely different style.
            image_path, _ = await modify_image(
                current_image, change_scene.scene_description
            )

        # The path is stored only when the scene is known to the state;
        # otherwise the state is left untouched.
        user_state = await get_user_state(user_hash)
        if scene_id in user_state.scenes:
            user_state.scenes[scene_id].image = image_path
            await set_user_state(user_hash, user_state)

        return image_path
    except Exception as exc:  # noqa: BLE001
        # Any failure is logged and returned as an error string, not raised.
        return _err(str(exc))
# NOTE(review): the docstring below is kept verbatim on purpose; langchain's
# @tool normally exposes it as the tool description, so it is runtime text.
@tool
async def update_state_with_choice(
    user_hash: Annotated[str, "User session ID"],
    scene_id: Annotated[str, "Scene ID"],
    choice_text: Annotated[str, "Chosen option"],
) -> Annotated[Dict, "Updated state"]:
    """Record the player's choice in the state."""
    # Appends a UserChoice (scene id, choice text, UTC timestamp) to the
    # user's choice history, persists the state, and returns it as a dict.
    import datetime

    state = await get_user_state(user_hash)

    # ``datetime.utcnow()`` is deprecated since Python 3.12. Take an aware UTC
    # "now" instead and strip tzinfo, so the stored ISO string keeps the same
    # naive format (no "+00:00" suffix) as timestamps already in the state.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    state.user_choices.append(
        UserChoice(
            scene_id=scene_id,
            choice_text=choice_text,
            timestamp=now_utc.isoformat(),
        )
    )
    await set_user_state(user_hash, state)
    return state.dict()
# NOTE(review): the docstring below is kept verbatim on purpose; langchain's
# @tool normally exposes it as the tool description, so it is runtime text.
@tool
async def check_ending(
    user_hash: Annotated[str, "User session ID"],
) -> Annotated[Dict, "Ending check result"]:
    """Check whether an ending has been reached."""
    user_state = await get_user_state(user_hash)
    frame = user_state.story_frame
    if not frame:
        # Endings are defined on the story frame; without one the check
        # cannot run and an error string is returned instead.
        return _err("No story frame")

    structured_llm = create_llm().with_structured_output(EndingCheckResult)

    # Condense the player's choices and the candidate endings into the text
    # forms passed to the prompt template.
    choice_log = "; ".join(
        f"{c.scene_id}:{c.choice_text}" for c in user_state.user_choices
    )
    ending_rules = ",".join(f"{e.id}:{e.condition}" for e in frame.endings)
    prompt = ENDING_CHECK_PROMPT.format(history=choice_log, endings=ending_rules)

    verdict: EndingCheckResult = await with_retries(
        lambda: structured_llm.ainvoke(prompt)
    )

    # Both the flag and an ending object are required; otherwise the state
    # is left untouched and a negative result is reported.
    if not (verdict.ending_reached and verdict.ending):
        return {"ending_reached": False}

    user_state.ending = verdict.ending
    await set_user_state(user_hash, user_state)
    return {"ending_reached": True, "ending": verdict.ending.dict()}
|