import asyncio
import json
import logging
from typing import TypeVar, Type, Optional, Callable
from pydantic import BaseModel
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.schema import SystemMessage, HumanMessage
from langchain.schema.messages import BaseMessage

T = TypeVar('T', bound=BaseModel)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Available Mistral models:
# - mistral-tiny     : Fastest, cheapest, good for testing
# - mistral-small    : Good balance of speed and quality
# - mistral-medium   : Better quality, slower than small
# - mistral-large    : Best quality, slowest and most expensive
#
# mistral-large-latest: currently points to mistral-large-2411.
# pixtral-large-latest: currently points to pixtral-large-2411.
# mistral-moderation-latest: currently points to mistral-moderation-2411.
# ministral-3b-latest: currently points to ministral-3b-2410.
# ministral-8b-latest: currently points to ministral-8b-2410.
# open-mistral-nemo: currently points to open-mistral-nemo-2407.
# mistral-small-latest: currently points to mistral-small-2409.
# codestral-latest: currently points to codestral-2501.
#
# Pricing: https://docs.mistral.ai/platform/pricing/

class MistralAPIError(Exception):
    """Base class for Mistral API errors"""
    pass

class MistralRateLimitError(MistralAPIError):
    """Raised when hitting rate limits"""
    pass

class MistralParsingError(MistralAPIError):
    """Raised when response parsing fails"""
    pass

class MistralValidationError(MistralAPIError):
    """Raised when response validation fails"""
    pass

class MistralClient:
    def __init__(self, api_key: str, model_name: str = "mistral-small-latest", max_tokens: int = 1000):
        logger.info(f"Initializing MistralClient with model: {model_name}, max_tokens: {max_tokens}")
        self.model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )
        # Secondary model instance (same settings), reserved for response-fixing passes
        self.fixing_model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )
        
        # Rate limit handling
        self.last_call_time = 0
        self.min_delay = 1  # minimum delay of 1 second between calls
        self.max_retries = 5
        self.backoff_factor = 2  # For exponential backoff
        self.max_backoff = 30  # Maximum backoff time in seconds
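        # With these defaults the retry delay grows as
        # min(backoff_factor ** retry_count, max_backoff),
        # e.g. 1s, 2s, 4s, 8s, 16s across successive attempts.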
    
    async def _wait_for_rate_limit(self):
        """Attend le temps nécessaire pour respecter le rate limit."""
        current_time = asyncio.get_event_loop().time()
        time_since_last_call = current_time - self.last_call_time
        
        if time_since_last_call < self.min_delay:
            delay = self.min_delay - time_since_last_call
            logger.debug(f"Rate limit: waiting for {delay:.2f} seconds")
            await asyncio.sleep(delay)
        
        self.last_call_time = asyncio.get_event_loop().time()

    async def _handle_api_error(self, error: Exception, retry_count: int) -> float:
        """Handle API errors: raise on fatal errors, return the backoff wait time for retryable ones."""
        wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
        
        if "rate limit" in str(error).lower():
            # Rate limits are retryable: log and let the caller back off
            logger.warning(f"Rate limit hit, waiting {wait_time}s before retry")
        elif "403" in str(error):
            logger.error("Authentication error - invalid API key or quota exceeded")
            raise MistralAPIError("Authentication failed")
        
        return wait_time

    async def _generate_with_retry(
        self,
        messages: list[BaseMessage],
        response_model: Optional[Type[T]] = None,
        custom_parser: Optional[Callable[[str], T]] = None,
        error_feedback: Optional[str] = None
    ) -> T | str:
        retry_count = 0
        last_error = None
        
        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")
                
                current_messages = messages.copy()
                if last_error and retry_count > 0:
                    if isinstance(last_error, MistralParsingError):
                        # For parsing errors, remind the model of the required format
                        current_messages.append(HumanMessage(content="Please ensure your response is in valid JSON format."))
                    elif isinstance(last_error, MistralValidationError):
                        # For validation errors, feed back the specific error
                        feedback = error_feedback or str(last_error)
                        current_messages.append(HumanMessage(content=f"Previous error: {feedback}. Please try again."))
                
                await self._wait_for_rate_limit()
                try:
                    response = await self.model.ainvoke(current_messages)
                    content = response.content
                    logger.debug(f"Raw response: {content[:100]}...")
                except Exception as api_error:
                    wait_time = await self._handle_api_error(api_error, retry_count)
                    retry_count += 1
                    if retry_count < self.max_retries:
                        await asyncio.sleep(wait_time)
                        continue
                    raise

                # If no parsing is required, return the raw content
                if not response_model and not custom_parser:
                    return content

                # Parse the response
                try:
                    if custom_parser:
                        return custom_parser(content)
                    
                    # Try to parse with the Pydantic model
                    data = json.loads(content)
                    return response_model(**data)
                except json.JSONDecodeError as e:
                    last_error = MistralParsingError(f"Invalid JSON format: {str(e)}")
                    logger.error(f"JSON parsing error: {str(e)}")
                    raise last_error
                except Exception as e:
                    last_error = MistralValidationError(str(e))
                    logger.error(f"Validation error: {str(e)}")
                    raise last_error

            except (MistralParsingError, MistralValidationError) as e:
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                last_error = e
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
                    continue
                
                logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
                raise Exception(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
    
    async def generate(self, messages: list[BaseMessage], response_model: Optional[Type[T]] = None, custom_parser: Optional[Callable[[str], T]] = None) -> T | str:
        """Génère une réponse à partir d'une liste de messages avec parsing optionnel."""
        return await self._generate_with_retry(messages, response_model, custom_parser)

    async def transform_prompt(self, story_text: str, art_prompt: str) -> str:
        """Transform a story text into an artistic image prompt."""
        messages = [
            SystemMessage(content=art_prompt),
            HumanMessage(content=f"Transform this story text into a comic panel description:\n{story_text}")
        ]
        try:
            return await self._generate_with_retry(messages)
        except Exception as e:
            logger.error(f"Error transforming prompt: {str(e)}")
            return story_text

    async def generate_text(self, messages: list[BaseMessage]) -> str:
        """
        Génère une réponse textuelle simple sans structure JSON.
        Utile pour la génération de texte narratif ou descriptif.
        
        Args:
            messages: Liste des messages pour le modèle
            
        Returns:
            str: Le texte généré
        """
        retry_count = 0
        last_error = None
        
        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")
                
                await self._wait_for_rate_limit()
                response = await self.model.ainvoke(messages)
                return response.content.strip()
                
            except Exception as e:
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                last_error = str(e)
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = 2 * retry_count
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
                    continue
                
                logger.error(f"Failed after {self.max_retries} attempts. Last error: {last_error}")
                raise MistralAPIError(f"Failed after {self.max_retries} attempts. Last error: {last_error}") from e

    async def check_health(self) -> bool:
        """
        Check the availability of the Mistral service with a single call, no retries.
        
        Returns:
            bool: True if the service is available, False otherwise
        """
        try:
            await self.model.ainvoke([SystemMessage(content="Hi")])
            return True
        except Exception as e:
            logger.error(f"Health check failed: {str(e)}")
            return False
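

# --- Usage sketch (illustrative only) ---
# A minimal, hypothetical example of driving MistralClient. It assumes a valid
# key in a MISTRAL_API_KEY environment variable; PanelDescription is an
# invented response model for demonstration, not part of this module's API.
if __name__ == "__main__":
    import os

    class PanelDescription(BaseModel):
        scene: str
        mood: str

    async def _demo():
        client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

        if not await client.check_health():
            logger.error("Mistral service unavailable")
            return

        # Plain text generation
        text = await client.generate_text(
            [HumanMessage(content="Say hello in one sentence.")]
        )
        logger.info(f"Text: {text}")

        # Structured generation validated against a Pydantic model
        panel = await client.generate(
            [
                SystemMessage(content='Reply with JSON only, matching {"scene": str, "mood": str}.'),
                HumanMessage(content="A detective enters a rainy alley at night.")
            ],
            response_model=PanelDescription
        )
        logger.info(f"Panel: {panel}")

    asyncio.run(_demo())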