File size: 9,916 Bytes
eb3e391 37757ef a61ba58 37757ef eb3e391 37757ef a61ba58 eb3e391 eca7f7a eb3e391 78b81a5 eb3e391 d9ce58c 4c8df65 eb3e391 4c8df65 eb3e391 4c8df65 eb3e391 8201795 78b81a5 eb3e391 a61ba58 eb3e391 37757ef 78b81a5 37757ef a61ba58 37757ef 78b81a5 a61ba58 37757ef a61ba58 78b81a5 a61ba58 78b81a5 37757ef 78b81a5 37757ef 78b81a5 37757ef 78b81a5 37757ef 78b81a5 a61ba58 78b81a5 37757ef 78b81a5 a61ba58 37757ef 78b81a5 eb3e391 37757ef eb3e391 37757ef eb3e391 37757ef eb3e391 1e8f4c6 d9ce58c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import asyncio
import json
import logging
from typing import TypeVar, Type, Optional, Callable
from pydantic import BaseModel
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.schema import SystemMessage, HumanMessage
from langchain.schema.messages import BaseMessage
# Generic type variable for structured responses parsed into Pydantic models.
T = TypeVar('T', bound=BaseModel)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Available Mistral models:
# - mistral-tiny : Fastest, cheapest, good for testing
# - mistral-small : Good balance of speed and quality
# - mistral-medium : Better quality, slower than small
# - mistral-large : Best quality, slowest and most expensive
#
# mistral-large-latest: currently points to mistral-large-2411.
# pixtral-large-latest: currently points to pixtral-large-2411.
# mistral-moderation-latest: currently points to mistral-moderation-2411.
# ministral-3b-latest: currently points to ministral-3b-2410.
# ministral-8b-latest: currently points to ministral-8b-2410.
# open-mistral-nemo: currently points to open-mistral-nemo-2407.
# mistral-small-latest: currently points to mistral-small-2409.
# codestral-latest: currently points to codestral-2501.
#
# Pricing: https://docs.mistral.ai/platform/pricing/
class MistralAPIError(Exception):
    """Root of the Mistral client exception hierarchy; all client-specific
    errors derive from this class."""
class MistralRateLimitError(MistralAPIError):
    """Signals that the Mistral API rejected a call due to rate limiting."""
class MistralParsingError(MistralAPIError):
    """Signals that a model response could not be parsed (e.g. invalid JSON)."""
class MistralValidationError(MistralAPIError):
    """Signals that a parsed response failed Pydantic model validation."""
class MistralClient:
    """Async client for the Mistral chat API (via LangChain) with built-in
    rate limiting, exponential-backoff retries, and optional parsing of
    responses into Pydantic models.
    """

    def __init__(self, api_key: str, model_name: str = "mistral-small-latest", max_tokens: int = 1000):
        """Initialize the underlying chat models and retry configuration.

        Args:
            api_key: Mistral API key.
            model_name: Mistral model identifier (see module header for options).
            max_tokens: Maximum tokens per completion.
        """
        logger.info(f"Initializing MistralClient with model: {model_name}, max_tokens: {max_tokens}")
        self.model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )
        # NOTE(review): configured identically to `self.model` and unused by the
        # methods in this file; kept as a public attribute for compatibility.
        self.fixing_model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )
        # Rate-limit / retry configuration
        self.last_call_time = 0   # event-loop timestamp of the last API call
        self.min_delay = 1        # minimum seconds between consecutive calls
        self.max_retries = 5
        self.backoff_factor = 2   # base for exponential backoff
        self.max_backoff = 30     # cap (seconds) on a single backoff wait

    async def _wait_for_rate_limit(self):
        """Sleep just long enough to keep at least `min_delay` seconds between calls."""
        # get_running_loop() is the modern form inside a coroutine
        # (get_event_loop() is deprecated for this use).
        current_time = asyncio.get_running_loop().time()
        time_since_last_call = current_time - self.last_call_time
        if time_since_last_call < self.min_delay:
            delay = self.min_delay - time_since_last_call
            logger.debug(f"Rate limit: waiting for {delay:.2f} seconds")
            await asyncio.sleep(delay)
        self.last_call_time = asyncio.get_running_loop().time()

    async def _handle_api_error(self, error: Exception, retry_count: int) -> float:
        """Classify an API error and decide how to proceed.

        Args:
            error: The exception raised by the API call.
            retry_count: Zero-based index of the attempt that failed.

        Returns:
            Backoff delay in seconds before the caller should retry.

        Raises:
            MistralAPIError: For non-retryable authentication (403) failures.
        """
        wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
        if "rate limit" in str(error).lower():
            # BUG FIX: previously raised MistralRateLimitError here, which
            # escaped the caller's retry loop even though the log message
            # promised a retry. Returning the delay lets the caller back off.
            logger.warning(f"Rate limit hit, waiting {wait_time}s before retry")
            return wait_time
        elif "403" in str(error):
            logger.error("Authentication error - invalid API key or quota exceeded")
            # Retrying cannot fix a bad key / exhausted quota.
            raise MistralAPIError("Authentication failed") from error
        return wait_time

    async def _generate_with_retry(
        self,
        messages: list[BaseMessage],
        response_model: Optional[Type[T]] = None,
        custom_parser: Optional[Callable[[str], T]] = None,
        error_feedback: Optional[str] = None
    ) -> T | str:
        """Invoke the model with retries on API, parsing and validation errors.

        Args:
            messages: Conversation to send to the model.
            response_model: Optional Pydantic model to validate JSON output.
            custom_parser: Optional callable that parses the raw text instead.
            error_feedback: Optional hint appended to the prompt on retries
                after a validation failure.

        Returns:
            The parsed object when a parser/model is given, else the raw text.

        Raises:
            MistralAPIError: When all retries are exhausted or auth fails.
        """
        retry_count = 0
        last_error: Optional[Exception] = None
        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")
                current_messages = messages.copy()
                if error_feedback and retry_count > 0:
                    if isinstance(last_error, MistralParsingError):
                        # For parsing errors, add a structured-format reminder
                        current_messages.append(HumanMessage(content="Please ensure your response is in valid JSON format."))
                    elif isinstance(last_error, MistralValidationError):
                        # For validation errors, add the specific feedback
                        current_messages.append(HumanMessage(content=f"Previous error: {error_feedback}. Please try again."))
                await self._wait_for_rate_limit()
                try:
                    response = await self.model.ainvoke(current_messages)
                    content = response.content
                    logger.debug(f"Raw response: {content[:100]}...")
                except Exception as api_error:
                    wait_time = await self._handle_api_error(api_error, retry_count)
                    retry_count += 1
                    if retry_count < self.max_retries:
                        await asyncio.sleep(wait_time)
                        continue
                    raise
                # No parsing requested: return the raw content.
                if not response_model and not custom_parser:
                    return content
                # Parse the response.
                try:
                    if custom_parser:
                        return custom_parser(content)
                    # Try to parse with the Pydantic model.
                    data = json.loads(content)
                    return response_model(**data)
                except json.JSONDecodeError as e:
                    last_error = MistralParsingError(f"Invalid JSON format: {str(e)}")
                    logger.error(f"JSON parsing error: {str(e)}")
                    raise last_error
                except Exception as e:
                    last_error = MistralValidationError(str(e))
                    logger.error(f"Validation error: {str(e)}")
                    raise last_error
            except (MistralParsingError, MistralValidationError) as e:
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                last_error = e
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
                    continue
        logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
        # MistralAPIError subclasses Exception, so existing callers still catch it.
        raise MistralAPIError(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}") from last_error

    async def generate(self, messages: list[BaseMessage], response_model: Optional[Type[T]] = None, custom_parser: Optional[Callable[[str], T]] = None) -> T | str:
        """Generate a response from a list of messages, with optional parsing."""
        return await self._generate_with_retry(messages, response_model, custom_parser)

    async def transform_prompt(self, story_text: str, art_prompt: str) -> str:
        """Transform a story passage into an artistic (comic-panel) prompt.

        Best-effort: on any failure the original story text is returned so the
        caller's pipeline keeps moving.
        """
        # Use message objects, consistent with the list[BaseMessage] convention
        # used by every other method (previously raw role/content dicts).
        messages = [
            SystemMessage(content=art_prompt),
            HumanMessage(content=f"Transform this story text into a comic panel description:\n{story_text}"),
        ]
        try:
            return await self._generate_with_retry(messages)
        except Exception as e:
            # Use the module logger instead of print for consistency.
            logger.error(f"Error transforming prompt: {str(e)}")
            return story_text

    async def generate_text(self, messages: list[BaseMessage]) -> str:
        """
        Generate a simple text response without JSON structure.
        Useful for narrative or descriptive text generation.

        Args:
            messages: List of messages for the model.

        Returns:
            str: The generated text, stripped of surrounding whitespace.

        Raises:
            MistralAPIError: When all retries are exhausted.
        """
        retry_count = 0
        last_error: Optional[Exception] = None
        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")
                await self._wait_for_rate_limit()
                response = await self.model.ainvoke(messages)
                return response.content.strip()
            except Exception as e:
                # BUG FIX: `last_error` was never assigned and the post-loop
                # message referenced `e`, which Python unbinds after the except
                # clause — exhausting retries raised NameError, not the intended error.
                last_error = e
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = 2 * retry_count  # linear backoff for plain text calls
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
        logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
        raise MistralAPIError(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}") from last_error

    async def check_health(self) -> bool:
        """
        Check availability of the Mistral service with a single call, no retries.

        Returns:
            bool: True if the service responded.

        Raises:
            Exception: Propagates the underlying error when the service is
            unreachable (logged first). Callers must catch, not test the return.
        """
        try:
            await self.model.ainvoke([SystemMessage(content="Hi")])
            return True
        except Exception as e:
            logger.error(f"Health check failed: {str(e)}")
            raise