File size: 2,839 Bytes
c7a96cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Dict


# Text Generation Inference Errors
class ValidationError(Exception):
    """Error produced by `parse_error` when the payload's `error_type` is `"validation"`."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class GenerationError(Exception):
    """Error produced by `parse_error` when the payload's `error_type` is `"generation"`."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class OverloadedError(Exception):
    """Error produced by `parse_error` when the payload's `error_type` is `"overloaded"`."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class IncompleteGenerationError(Exception):
    """Error produced by `parse_error` when the payload's `error_type` is `"incomplete_generation"`."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


# API Inference Errors
class BadRequestError(Exception):
    """Error produced by `parse_error` for HTTP status 400 when no known `error_type` matched."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class ShardNotReadyError(Exception):
    """Error produced by `parse_error` for HTTP status 403 or 424 when no known `error_type` matched."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class ShardTimeoutError(Exception):
    """Error produced by `parse_error` for HTTP status 504 when no known `error_type` matched."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class NotFoundError(Exception):
    """Error produced by `parse_error` for HTTP status 404 when no known `error_type` matched."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class RateLimitExceededError(Exception):
    """Error produced by `parse_error` for HTTP status 429 when no known `error_type` matched."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


class NotSupportedError(Exception):
    """Error stating that a model is not available for inference with this client.

    Unlike the other errors in this module, the constructor takes the model
    id and builds the message itself.
    """

    def __init__(self, model_id: str) -> None:
        # Two-line message: the problem first, the suggested alternative second.
        unavailable = f"Model `{model_id}` is not available for inference with this client. \n"
        suggestion = "Use `huggingface_hub.inference_api.InferenceApi` instead."
        super().__init__(unavailable + suggestion)


# Unknown error
class UnknownError(Exception):
    """Fallback error produced by `parse_error` when neither the `error_type` nor the status code is recognized."""

    def __init__(self, message: str) -> None:
        # The message becomes the sole entry of `args`, hence `str(error)`.
        super().__init__(message)


def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
    """
    Parse error given an HTTP status code and a json payload

    The payload's `error_type` (Text Generation Inference errors) is checked
    first. When it is absent or not one of the known values, the HTTP status
    code (API Inference errors) decides. Anything else becomes an
    `UnknownError`. The exception is returned, not raised.

    Args:
        status_code (`int`):
            HTTP status code
        payload (`Dict[str, str]`):
            Json payload. Its `error` entry is used as the exception message;
            if that entry is missing, a message built from the status code and
            the raw payload is used instead.

    Returns:
        Exception: parsed exception

    """
    # A payload without an "error" entry must not surface as a KeyError that
    # hides the actual HTTP failure, so fall back to a descriptive message.
    message = payload.get("error")
    if message is None:
        message = f"Request failed with status code {status_code}: {payload}"

    # Try to parse a Text Generation Inference error
    tgi_errors = {
        "generation": GenerationError,
        "incomplete_generation": IncompleteGenerationError,
        "overloaded": OverloadedError,
        "validation": ValidationError,
    }
    error_type = payload.get("error_type")
    # Only strings can name a known error type; the guard also keeps an
    # unhashable value from breaking the table lookup.
    if isinstance(error_type, str) and error_type in tgi_errors:
        return tgi_errors[error_type](message)

    # Try to parse an API Inference error
    api_errors = {
        400: BadRequestError,
        403: ShardNotReadyError,
        424: ShardNotReadyError,
        504: ShardTimeoutError,
        404: NotFoundError,
        429: RateLimitExceededError,
    }
    # Fallback to an unknown error
    return api_errors.get(status_code, UnknownError)(message)