update
Browse files
server/api/models.py
CHANGED
@@ -9,11 +9,10 @@ class Choice(BaseModel):
|
|
9 |
class StorySegmentResponse(BaseModel):
|
10 |
story_text: str = Field(description="The story text. No more than 30 words.")
|
11 |
|
12 |
-
@validator('story_text')
|
13 |
def validate_story_text_length(cls, v):
|
14 |
words = v.split()
|
15 |
-
if len(words) >
|
16 |
-
raise ValueError('Story text must not exceed
|
17 |
return v
|
18 |
|
19 |
class StoryPromptsResponse(BaseModel):
|
|
|
9 |
class StorySegmentResponse(BaseModel):
|
10 |
story_text: str = Field(description="The story text. No more than 30 words.")
|
11 |
|
|
|
12 |
def validate_story_text_length(cls, v):
|
13 |
words = v.split()
|
14 |
+
if len(words) > 40:
|
15 |
+
raise ValueError('Story text must not exceed 30 words')
|
16 |
return v
|
17 |
|
18 |
class StoryPromptsResponse(BaseModel):
|
server/core/generators/story_segment_generator.py
CHANGED
@@ -17,7 +17,6 @@ class StorySegmentGenerator(BaseGenerator):
|
|
17 |
self.universe_epoch = universe_epoch
|
18 |
self.universe_story = universe_story
|
19 |
self.universe_macguffin = universe_macguffin
|
20 |
-
self.max_retries = 5
|
21 |
# Then call parent constructor which will create the prompt
|
22 |
super().__init__(mistral_client, hero_name=hero_name, hero_desc=hero_desc)
|
23 |
|
@@ -90,7 +89,6 @@ Your task is to generate the next segment of the story, following these rules:
|
|
90 |
|
91 |
Hero Description: {self.hero_desc}
|
92 |
|
93 |
-
- MANDATORY: Each segment must be close to 15 words, no exceptions.
|
94 |
"""
|
95 |
|
96 |
human_template = """
|
@@ -112,15 +110,12 @@ Story history:
|
|
112 |
|
113 |
{what_to_represent}
|
114 |
|
115 |
-
|
116 |
-
Be short. Never describes game variables.
|
117 |
|
118 |
IT MUST BE THE DIRECT CONTINUATION OF THE CURRENT STORY.
|
119 |
You MUST mention the previous situation and what is happening now with the new choice.
|
120 |
Never propose choices or options. Never describe the game variables.
|
121 |
-
|
122 |
-
MANDATORY: Each segment must be close to 15 words, keep it concise.
|
123 |
-
Be short. Never describes game variables.
|
124 |
"""
|
125 |
return ChatPromptTemplate(
|
126 |
messages=[
|
@@ -191,10 +186,7 @@ Be short. Never describes game variables.
|
|
191 |
return 0 <= word_count <= 30
|
192 |
|
193 |
async def generate(self, story_beat: int, current_time: str, current_location: str, previous_choice: str, story_history: str = "", turn_before_end: int = 0, is_winning_story: bool = False) -> StorySegmentResponse:
|
194 |
-
"""Generate the next story segment
|
195 |
-
retry_count = 0
|
196 |
-
last_attempt = None
|
197 |
-
|
198 |
is_end = True if story_beat == turn_before_end else False
|
199 |
is_death = True if is_end and is_winning_story else False
|
200 |
is_victory = True if is_end and not is_winning_story else False
|
@@ -211,10 +203,9 @@ Write a story segment that:
|
|
211 |
2. Maintains consistency with the universe and story
|
212 |
3. Respects all previous rules about length and style
|
213 |
4. Naturally integrates the custom elements while staying true to the plot
|
214 |
-
Close to 15 words.
|
215 |
"""
|
216 |
|
217 |
-
# Créer les messages
|
218 |
messages = self.prompt.format_messages(
|
219 |
hero_description=self.hero_desc,
|
220 |
FORMATTING_RULES=FORMATTING_RULES,
|
@@ -230,32 +221,6 @@ Close to 15 words.
|
|
230 |
universe_macguffin=self.universe_macguffin
|
231 |
)
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
try:
|
237 |
-
story_text = await self.mistral_client.generate_text(current_messages)
|
238 |
-
word_count = len(story_text.split())
|
239 |
-
|
240 |
-
if self._is_valid_length(story_text):
|
241 |
-
return StorySegmentResponse(story_text=story_text)
|
242 |
-
|
243 |
-
retry_count += 1
|
244 |
-
if retry_count < self.max_retries:
|
245 |
-
# Créer un nouveau message avec le feedback sur la longueur
|
246 |
-
if word_count > 15:
|
247 |
-
feedback = f"The previous response was too long ({word_count} words). Here was your last attempt:\n\n{story_text}\n\nPlease generate a MUCH SHORTER story segment close to 15 words that continues from: {story_history}"
|
248 |
-
|
249 |
-
# Réinitialiser les messages avec les messages de base
|
250 |
-
current_messages = messages.copy()
|
251 |
-
# Ajouter le feedback
|
252 |
-
current_messages.append(HumanMessage(content=feedback))
|
253 |
-
last_attempt = story_text
|
254 |
-
continue
|
255 |
-
|
256 |
-
raise ValueError(f"Failed to generate text of valid length after {self.max_retries} attempts. Last attempt had {word_count} words.")
|
257 |
-
|
258 |
-
except Exception as e:
|
259 |
-
retry_count += 1
|
260 |
-
if retry_count >= self.max_retries:
|
261 |
-
raise e
|
|
|
17 |
self.universe_epoch = universe_epoch
|
18 |
self.universe_story = universe_story
|
19 |
self.universe_macguffin = universe_macguffin
|
|
|
20 |
# Then call parent constructor which will create the prompt
|
21 |
super().__init__(mistral_client, hero_name=hero_name, hero_desc=hero_desc)
|
22 |
|
|
|
89 |
|
90 |
Hero Description: {self.hero_desc}
|
91 |
|
|
|
92 |
"""
|
93 |
|
94 |
human_template = """
|
|
|
110 |
|
111 |
{what_to_represent}
|
112 |
|
113 |
+
Never describes game variables.
|
|
|
114 |
|
115 |
IT MUST BE THE DIRECT CONTINUATION OF THE CURRENT STORY.
|
116 |
You MUST mention the previous situation and what is happening now with the new choice.
|
117 |
Never propose choices or options. Never describe the game variables.
|
118 |
+
LIMIT: 15 words.
|
|
|
|
|
119 |
"""
|
120 |
return ChatPromptTemplate(
|
121 |
messages=[
|
|
|
186 |
return 0 <= word_count <= 30
|
187 |
|
188 |
async def generate(self, story_beat: int, current_time: str, current_location: str, previous_choice: str, story_history: str = "", turn_before_end: int = 0, is_winning_story: bool = False) -> StorySegmentResponse:
|
189 |
+
"""Generate the next story segment."""
|
|
|
|
|
|
|
190 |
is_end = True if story_beat == turn_before_end else False
|
191 |
is_death = True if is_end and is_winning_story else False
|
192 |
is_victory = True if is_end and not is_winning_story else False
|
|
|
203 |
2. Maintains consistency with the universe and story
|
204 |
3. Respects all previous rules about length and style
|
205 |
4. Naturally integrates the custom elements while staying true to the plot
|
|
|
206 |
"""
|
207 |
|
208 |
+
# Créer les messages
|
209 |
messages = self.prompt.format_messages(
|
210 |
hero_description=self.hero_desc,
|
211 |
FORMATTING_RULES=FORMATTING_RULES,
|
|
|
221 |
universe_macguffin=self.universe_macguffin
|
222 |
)
|
223 |
|
224 |
+
# Générer le texte
|
225 |
+
story_text = await self.mistral_client.generate_text(messages)
|
226 |
+
return StorySegmentResponse(story_text=story_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/core/story_generator.py
CHANGED
@@ -26,7 +26,13 @@ class StoryGenerator:
|
|
26 |
self.model_name = model_name
|
27 |
self.turn_before_end = random.randint(GameConfig.MIN_SEGMENTS_BEFORE_END, GameConfig.MAX_SEGMENTS_BEFORE_END)
|
28 |
self.is_winning_story = random.random() < GameConfig.WINNING_STORY_CHANCE
|
|
|
|
|
29 |
self.mistral_client = MistralClient(api_key=api_key, model_name=model_name)
|
|
|
|
|
|
|
|
|
30 |
self.image_prompt_generator = None # Will be initialized with the first universe style
|
31 |
self.metadata_generator = None # Will be initialized with hero description
|
32 |
self.segment_generators: Dict[str, StorySegmentGenerator] = {}
|
@@ -65,7 +71,7 @@ class StoryGenerator:
|
|
65 |
|
66 |
# Create a new StorySegmentGenerator with all universe parameters
|
67 |
self.segment_generators[session_id] = StorySegmentGenerator(
|
68 |
-
self.
|
69 |
universe_style=style["name"],
|
70 |
universe_genre=genre,
|
71 |
universe_epoch=epoch,
|
|
|
26 |
self.model_name = model_name
|
27 |
self.turn_before_end = random.randint(GameConfig.MIN_SEGMENTS_BEFORE_END, GameConfig.MAX_SEGMENTS_BEFORE_END)
|
28 |
self.is_winning_story = random.random() < GameConfig.WINNING_STORY_CHANCE
|
29 |
+
|
30 |
+
# Client principal avec limite standard
|
31 |
self.mistral_client = MistralClient(api_key=api_key, model_name=model_name)
|
32 |
+
|
33 |
+
# Client spécifique pour les segments d'histoire avec limite plus basse
|
34 |
+
self.story_segment_client = MistralClient(api_key=api_key, model_name=model_name, max_tokens=50)
|
35 |
+
|
36 |
self.image_prompt_generator = None # Will be initialized with the first universe style
|
37 |
self.metadata_generator = None # Will be initialized with hero description
|
38 |
self.segment_generators: Dict[str, StorySegmentGenerator] = {}
|
|
|
71 |
|
72 |
# Create a new StorySegmentGenerator with all universe parameters
|
73 |
self.segment_generators[session_id] = StorySegmentGenerator(
|
74 |
+
self.story_segment_client,
|
75 |
universe_style=style["name"],
|
76 |
universe_genre=genre,
|
77 |
universe_epoch=epoch,
|
server/services/mistral_client.py
CHANGED
@@ -31,17 +31,17 @@ logger = logging.getLogger(__name__)
|
|
31 |
# Pricing: https://docs.mistral.ai/platform/pricing/
|
32 |
|
33 |
class MistralClient:
|
34 |
-
def __init__(self, api_key: str, model_name: str = "mistral-large-latest"):
|
35 |
-
logger.info(f"Initializing MistralClient with model: {model_name}")
|
36 |
self.model = ChatMistralAI(
|
37 |
mistral_api_key=api_key,
|
38 |
model=model_name,
|
39 |
-
max_tokens=
|
40 |
)
|
41 |
self.fixing_model = ChatMistralAI(
|
42 |
mistral_api_key=api_key,
|
43 |
model=model_name,
|
44 |
-
max_tokens=
|
45 |
)
|
46 |
|
47 |
# Pour gérer le rate limit
|
|
|
31 |
# Pricing: https://docs.mistral.ai/platform/pricing/
|
32 |
|
33 |
class MistralClient:
|
34 |
+
def __init__(self, api_key: str, model_name: str = "mistral-large-latest", max_tokens: int = 1000):
|
35 |
+
logger.info(f"Initializing MistralClient with model: {model_name}, max_tokens: {max_tokens}")
|
36 |
self.model = ChatMistralAI(
|
37 |
mistral_api_key=api_key,
|
38 |
model=model_name,
|
39 |
+
max_tokens=max_tokens
|
40 |
)
|
41 |
self.fixing_model = ChatMistralAI(
|
42 |
mistral_api_key=api_key,
|
43 |
model=model_name,
|
44 |
+
max_tokens=max_tokens
|
45 |
)
|
46 |
|
47 |
# Pour gérer le rate limit
|