|
""" |
|
Get num retries for an exception. |
|
|
|
- Account for retry policy by exception type. |
|
""" |
|
|
|
from typing import Dict, Optional, Union |
|
|
|
from litellm.exceptions import ( |
|
AuthenticationError, |
|
BadRequestError, |
|
ContentPolicyViolationError, |
|
RateLimitError, |
|
Timeout, |
|
) |
|
from litellm.types.router import RetryPolicy |
|
|
|
|
|
def get_num_retries_from_retry_policy(
    exception: Exception,
    retry_policy: Optional[Union[RetryPolicy, dict]] = None,
    model_group: Optional[str] = None,
    model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
):
    """
    Look up the configured retry count for ``exception`` in a retry policy.

    Policy selection:
    - If ``model_group`` is a key of ``model_group_retry_policy``, that entry
      replaces ``retry_policy`` (even when the entry itself is ``None``).
    - A policy supplied as a ``dict`` is converted with ``RetryPolicy(**dict)``.

    Policy fields consulted, in this order:
    - BadRequestErrorRetries              (for BadRequestError)
    - AuthenticationErrorRetries          (for AuthenticationError)
    - TimeoutErrorRetries                 (for Timeout)
    - RateLimitErrorRetries               (for RateLimitError)
    - ContentPolicyViolationErrorRetries  (for ContentPolicyViolationError)

    Args:
        exception: The exception that was raised.
        retry_policy: Default policy, as a ``RetryPolicy`` or a dict of its fields.
        model_group: Name of the model group the failing call belongs to.
        model_group_retry_policy: Per-model-group policies keyed by group name.

    Returns:
        The value of the first field (in the order above) whose exception type
        matches ``exception`` and whose value is not ``None``. ``None`` when no
        policy applies or no matching field is set.
    """
    # A per-group entry, when present, takes the place of the default policy.
    group_lookup_possible = (
        model_group_retry_policy is not None and model_group is not None
    )
    if group_lookup_possible and model_group in model_group_retry_policy:
        retry_policy = model_group_retry_policy.get(model_group)

    if retry_policy is None:
        return None

    policy = (
        RetryPolicy(**retry_policy)
        if isinstance(retry_policy, dict)
        else retry_policy
    )

    # Order matters: the first matching type with a configured value wins, so
    # an exception that is an instance of several listed types resolves to the
    # earliest configured entry.
    ordered_checks = (
        (BadRequestError, "BadRequestErrorRetries"),
        (AuthenticationError, "AuthenticationErrorRetries"),
        (Timeout, "TimeoutErrorRetries"),
        (RateLimitError, "RateLimitErrorRetries"),
        (ContentPolicyViolationError, "ContentPolicyViolationErrorRetries"),
    )
    for exception_type, field_name in ordered_checks:
        if not isinstance(exception, exception_type):
            continue
        # Read the field only for matching types, so non-matching fields are
        # never accessed.
        configured_retries = getattr(policy, field_name)
        if configured_retries is not None:
            return configured_retries

    return None
|
|
|
|
|
def reset_retry_policy() -> RetryPolicy:
    """Build and return a new ``RetryPolicy`` constructed with no arguments."""
    fresh_policy = RetryPolicy()
    return fresh_policy
|
|