File size: 5,926 Bytes
d8e2b36
 
86b351a
d8e2b36
 
 
86b351a
d8e2b36
 
 
 
 
 
 
 
 
 
 
 
ee968a7
85d7f84
2999669
 
86b351a
 
 
d8e2b36
86b351a
 
d8e2b36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85d7f84
d8e2b36
 
 
 
 
 
 
 
 
70bca69
d8e2b36
70bca69
d8e2b36
 
 
 
 
 
 
 
 
70bca69
d8e2b36
 
 
 
 
 
2999669
 
 
d8e2b36
 
85d7f84
d8e2b36
85d7f84
 
d8e2b36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70bca69
d8e2b36
86b351a
 
 
 
d8e2b36
 
2999669
 
d8e2b36
 
86b351a
2999669
 
 
 
 
 
 
 
70bca69
d8e2b36
 
70bca69
d8e2b36
 
 
 
 
 
 
 
 
 
 
 
 
 
70bca69
d8e2b36
 
 
 
 
 
 
70bca69
d8e2b36
 
 
 
 
 
 
 
70bca69
d8e2b36
 
 
2999669
d8e2b36
 
2999669
d8e2b36
85d7f84
d8e2b36
 
70bca69
d8e2b36
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""LLM tools used by the game graph."""

import logging
import uuid
from typing import Annotated, Dict

from langchain_core.tools import tool

from agent.llm import create_llm
from agent.models import (
    EndingCheckResult,
    Scene,
    SceneChoice,
    SceneLLM,
    StoryFrame,
    StoryFrameLLM,
    UserChoice,
)
from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT
from agent.redis_state import get_user_state, set_user_state
from agent.utils import with_retries
from images.image_generator import modify_image, generate_image
from agent.image_agent import ChangeScene

logger = logging.getLogger(__name__)


def _err(msg: str) -> str:
    """Log ``msg`` at error level and return it wrapped in an error payload string.

    The payload is the ``repr`` of a one-key dict, i.e. ``{'error': '<msg>'}``.
    The tools in this module return it in place of their normal result.

    Args:
        msg: Human-readable error description.

    Returns:
        ``str({"error": msg})``. For plain messages this is byte-identical to
        the previous hand-built f-string. Unlike that f-string, it stays a
        well-formed Python literal when ``msg`` contains quotes, backslashes
        or newlines, because ``repr`` escapes them.
    """
    logger.error(msg)
    return str({"error": msg})


@tool
async def generate_story_frame(
    user_hash: Annotated[str, "User session ID"],
    setting: Annotated[str, "Game world setting"],
    character: Annotated[Dict[str, str], "Character info"],
    genre: Annotated[str, "Genre"],
) -> Annotated[Dict, "Generated story frame"]:
    """Create the initial story frame and store it in user state."""
    # The docstring above is kept verbatim: langchain_core's ``@tool`` derives
    # the tool description from it. Implementation notes live in comments.

    # Step 1: ask the LLM (with retries) for the generated parts of the frame.
    structured_llm = create_llm().with_structured_output(StoryFrameLLM)
    frame_prompt = STORY_FRAME_PROMPT.format(
        setting=setting, character=character, genre=genre
    )
    llm_frame: StoryFrameLLM = await with_retries(
        lambda: structured_llm.ainvoke(frame_prompt)
    )

    # Step 2: merge the LLM output with the caller-supplied inputs.
    generated_fields = {
        "lore": llm_frame.lore,
        "goal": llm_frame.goal,
        "milestones": llm_frame.milestones,
        "endings": llm_frame.endings,
    }
    frame = StoryFrame(
        **generated_fields, setting=setting, character=character, genre=genre
    )

    # Step 3: persist the frame on the user's state and return it as a dict.
    user_state = await get_user_state(user_hash)
    user_state.story_frame = frame
    await set_user_state(user_hash, user_state)
    return frame.dict()


@tool
async def generate_scene(
    user_hash: Annotated[str, "User session ID"],
    last_choice: Annotated[str, "Last user choice"],
) -> Annotated[Dict, "Generated scene"]:
    """Generate a new scene based on the current user state."""
    # The docstring above is kept verbatim: langchain_core's ``@tool`` derives
    # the tool description from it.
    #
    # On success the new scene becomes the user's current scene, is stored in
    # ``state.scenes`` and is returned as ``scene.dict()``. If no story frame
    # exists yet, the ``_err`` payload string is returned instead.
    state = await get_user_state(user_hash)
    if not state.story_frame:
        return _err("Story frame not initialized")
    llm = create_llm().with_structured_output(SceneLLM)
    prompt = SCENE_PROMPT.format(
        lore=state.story_frame.lore,
        goal=state.story_frame.goal,
        # Only the ids of milestones and endings go into the prompt.
        milestones=",".join(m.id for m in state.story_frame.milestones),
        endings=",".join(e.id for e in state.story_frame.endings),
        # Full choice history as "scene_id:choice_text" pairs.
        history="; ".join(f"{c.scene_id}:{c.choice_text}" for c in state.user_choices),
        last_choice=last_choice,
    )
    resp: SceneLLM = await with_retries(lambda: llm.ainvoke(prompt))
    if len(resp.choices) < 2:
        # One more attempt, with an explicit instruction about the choice count.
        resp = await with_retries(
            lambda: llm.ainvoke(prompt + "\nThe scene must contain exactly two choices.")
        )
        if len(resp.choices) < 2:
            # Still short after the corrective retry. Proceed (best effort),
            # but make the degraded scene visible instead of storing it silently.
            logger.warning(
                "Scene for user %s has %d choice(s) after retry; expected 2",
                user_hash,
                len(resp.choices),
            )
    scene_id = str(uuid.uuid4())
    # Keep at most the first two choices. Entries are either objects exposing
    # ``model_dump()`` or plain mappings that can be unpacked directly.
    choices = [
        SceneChoice(**ch.model_dump())
        if hasattr(ch, "model_dump")
        else SceneChoice(**ch)
        for ch in resp.choices[:2]
    ]
    scene = Scene(
        scene_id=scene_id,
        description=resp.description,
        choices=choices,
        image=None,
        music=None,
    )
    state.current_scene_id = scene_id
    state.scenes[scene_id] = scene
    await set_user_state(user_hash, state)
    return scene.dict()


@tool
async def generate_scene_image(
    user_hash: Annotated[str, "User session ID"],
    scene_id: Annotated[str, "Scene ID"],
    change_scene: Annotated[ChangeScene, "Prompt for image generation"],
    current_image: Annotated[str, "Current image"] | None = None,
) -> Annotated[str, "Path to generated image"]:
    """Generate an image for a scene and save the path in the state."""
    # The docstring above is kept verbatim: langchain_core's ``@tool`` derives
    # the tool description from it.
    #
    # Behaviour:
    # - change_scene.change_scene is "change_completely" or "modify":
    #   - no current image: generate a fresh one;
    #   - otherwise: edit the current image.
    # - Any other value: the current image (possibly None) is reused as-is.
    # The resulting path is stored on the scene when ``scene_id`` exists in the
    # user state, and is then returned. Any exception is turned into the
    # ``_err`` payload string.
    try:
        image_path = current_image
        if change_scene.change_scene in ("change_completely", "modify"):
            if current_image is None:
                image_path, _ = await generate_image(change_scene.scene_description)
            else:
                # For now, even "change_completely" edits the existing image
                # instead of regenerating it, so the update does not come out in
                # a completely different visual style.
                image_path, _ = await modify_image(
                    current_image, change_scene.scene_description
                )
        state = await get_user_state(user_hash)
        if scene_id in state.scenes:
            state.scenes[scene_id].image = image_path
        await set_user_state(user_hash, state)
        return image_path
    except Exception as exc:  # noqa: BLE001
        # Some exceptions (e.g. a bare TimeoutError()) have an empty message.
        # Fall back to the class name so the error payload is never blank.
        return _err(str(exc) or type(exc).__name__)


@tool
async def update_state_with_choice(
    user_hash: Annotated[str, "User session ID"],
    scene_id: Annotated[str, "Scene ID"],
    choice_text: Annotated[str, "Chosen option"],
) -> Annotated[Dict, "Updated state"]:
    """Record the player's choice in the state."""
    import datetime

    # The docstring above is kept verbatim: langchain_core's ``@tool`` derives
    # the tool description from it.
    #
    # Appends a UserChoice to the user's choice history, persists the state,
    # and returns the full updated state as a dict.
    state = await get_user_state(user_hash)
    state.user_choices.append(
        UserChoice(
            scene_id=scene_id,
            choice_text=choice_text,
            # ``datetime.utcnow()`` is deprecated since Python 3.12. Take an
            # aware UTC "now" and drop tzinfo, so the ISO string keeps the same
            # naive format (no "+00:00" suffix) as previously stored choices.
            timestamp=datetime.datetime.now(datetime.timezone.utc)
            .replace(tzinfo=None)
            .isoformat(),
        )
    )
    await set_user_state(user_hash, state)
    return state.dict()


@tool
async def check_ending(
    user_hash: Annotated[str, "User session ID"],
) -> Annotated[Dict, "Ending check result"]:
    """Check whether an ending has been reached."""
    # The docstring above is kept verbatim: langchain_core's ``@tool`` derives
    # the tool description from it.
    #
    # The LLM judges the choice history against each ending's condition.
    # - Verdict is positive and names an ending: the ending is saved on the
    #   user state and returned.
    # - Otherwise: {"ending_reached": False} is returned.
    # - No story frame: the ``_err`` payload string is returned.
    user_state = await get_user_state(user_hash)
    frame = user_state.story_frame
    if not frame:
        return _err("No story frame")

    judge = create_llm().with_structured_output(EndingCheckResult)
    choice_log = "; ".join(
        f"{choice.scene_id}:{choice.choice_text}" for choice in user_state.user_choices
    )
    ending_conditions = ",".join(
        f"{ending.id}:{ending.condition}" for ending in frame.endings
    )
    check_prompt = ENDING_CHECK_PROMPT.format(
        history=choice_log,
        endings=ending_conditions,
    )
    verdict: EndingCheckResult = await with_retries(lambda: judge.ainvoke(check_prompt))

    # Guard clause: no ending reached, or the LLM did not name one.
    if not (verdict.ending_reached and verdict.ending):
        return {"ending_reached": False}
    user_state.ending = verdict.ending
    await set_user_state(user_hash, user_state)
    return {"ending_reached": True, "ending": verdict.ending.dict()}