File size: 12,317 Bytes
15369ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
''' Lower-level implementation details of the gateway.
Hosts should not need to review this file before writing their competition-specific gateway.
'''

import enum
import json
import os
import pathlib
import re
import subprocess
import tempfile

from socket import gaierror
from typing import Any, List, Optional, Tuple, Union

import grpc
import numpy as np
import pandas as pd
import polars as pl

import kaggle_evaluation.core.relay


_FILE_SHARE_DIR = '/kaggle/shared/'
IS_RERUN = os.getenv('KAGGLE_IS_COMPETITION_RERUN') is not None


class GatewayRuntimeErrorType(enum.Enum):
    ''' Allow-listed error types that Gateways can raise, which map to canned error messages to show users.

    Values are stable identifiers persisted in `result.json`; do not renumber.
    '''
    UNSPECIFIED = 0
    # The inference server process never came up at all.
    SERVER_NEVER_STARTED = 1
    # The server started but the gateway could not establish a connection.
    SERVER_CONNECTION_FAILED = 2
    # The server's endpoint raised while handling a request.
    SERVER_RAISED_EXCEPTION = 3
    # The server never registered a listener for the endpoint the gateway called.
    SERVER_MISSING_ENDPOINT = 4
    # Default error type if an exception was raised that was not explicitly handled by the Gateway
    GATEWAY_RAISED_EXCEPTION = 5
    # The user's predictions failed validation (wrong type, wrong row count, etc).
    INVALID_SUBMISSION = 6


class GatewayRuntimeError(Exception):
    ''' Gateways can raise this error to capture a user-visible error enum from above and host-visible error details.

    Args:
        error_type: One of the allow-listed GatewayRuntimeErrorType values.
        error_details: Optional host-visible detail string (truncated to 8000 chars when written to result.json).
    '''
    def __init__(self, error_type: 'GatewayRuntimeErrorType', error_details: Optional[str]=None):
        # Bug fix: previously super().__init__ was never called, so `args` was empty and
        # str(error) rendered as '' in logs and tracebacks, hiding the details.
        if error_details is None:
            super().__init__()
        else:
            super().__init__(error_details)
        self.error_type = error_type
        self.error_details = error_details


class BaseGateway():
    ''' Common gateway functionality: connecting to the user's inference server, validating
    prediction batches, sharing files with the user's container, and writing the submission
    and result files. Competition-specific gateways subclass this.
    '''
    def __init__(self, target_column_name: Optional[str]=None):
        self.client = kaggle_evaluation.core.relay.Client('inference_server' if IS_RERUN else 'localhost')
        self.server = None  # The gateway can have a server but it isn't typically necessary.
        self.target_column_name = target_column_name  # Only used if the predictions are made as a primitive type (int, bool, etc) rather than a dataframe.
        # Bug fix: `self.file_share_dir` is read by `_standardize_and_validate_paths` but was never
        # initialized, so `share_files` raised AttributeError. Subclasses may still override it.
        self.file_share_dir = _FILE_SHARE_DIR

    def validate_prediction_batch(
            self,
            prediction_batch: Any,
            row_ids: Union[pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]
        ):
        ''' Validate that a single prediction batch matches the expected number of rows.

        If competitors can submit fewer rows than expected they can save all predictions for the last batch and
        bypass the benefits of the Kaggle evaluation service. This attack was seen in a real competition with the older time series API:
        https://www.kaggle.com/competitions/riiid-test-answer-prediction/discussion/196066
        It's critically important that this check be run every time predict() is called.

        If your predictions may take a variable number of rows and you need to write a custom version of this check,
        you still must specify a minimum row count greater than zero per prediction batch.

        Raises:
            GatewayRuntimeError: INVALID_SUBMISSION for user-caused problems;
                GATEWAY_RAISED_EXCEPTION for malformed row IDs (a host-side bug).
        '''
        if prediction_batch is None:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'No prediction received')
        num_received_rows = None
        # Special handling for numpy ints only as numpy floats are python floats, but numpy ints aren't python ints.
        # Basic types are valid for prediction, but either don't have a length (int) or the length isn't relevant
        # for purposes of this check (str), so they always count as a single row.
        if isinstance(prediction_batch, (int, float, str, bool, np.int_)):
            num_received_rows = 1
        if num_received_rows is None:
            # Exact type check (not isinstance) is deliberate: only these specific containers are accepted.
            if type(prediction_batch) not in [pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]:
                raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, f'Invalid prediction data type, received: {type(prediction_batch)}')
            num_received_rows = len(prediction_batch)
        if type(row_ids) not in [pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, f'Invalid row ID type {type(row_ids)}; expected Polars DataFrame or similar')
        num_expected_rows = len(row_ids)
        if num_expected_rows == 0:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, 'Missing row IDs for batch')
        if num_received_rows != num_expected_rows:
            raise GatewayRuntimeError(
                GatewayRuntimeErrorType.INVALID_SUBMISSION,
                f'Invalid predictions: expected {num_expected_rows} rows but received {num_received_rows}'
            )

    def _standardize_and_validate_paths(
            self,
            input_paths: List[Union[str, pathlib.Path]]
        ) -> Tuple[List[str], List[str]]:
        ''' Validate the paths to be shared and compute their mirrored output paths.

        Accepts a list of str or pathlib.Path, but standardizes on absolute str paths.

        Returns:
            Tuple of (absolute input paths, corresponding output paths under file_share_dir).

        Raises:
            ValueError: if any path is of the wrong type, non-normalized, contains a
                parent-directory reference, doesn't exist, or is a duplicate.
        '''
        for path in input_paths:
            # Check the type first so the str()-based checks below only run on supported values.
            if type(path) not in (pathlib.Path, str):
                raise ValueError('All paths must be of type str or pathlib.Path')
            if os.pardir in str(path):
                raise ValueError(f'Send files path contains {os.pardir}: {path}')
            if str(path) != str(os.path.normpath(path)):
                # Raise an error rather than sending users unexpectedly altered paths
                raise ValueError(f'Send files path {path} must be normalized. See `os.path.normpath`')
            if not os.path.exists(path):
                raise ValueError(f'Input path {path} does not exist')

        input_paths = [os.path.abspath(path) for path in input_paths]
        if len(set(input_paths)) != len(input_paths):
            raise ValueError('Duplicate input paths found')

        # Bug fix: `output_dir` was previously only assigned when file_share_dir lacked a trailing
        # separator, leaving it unbound (NameError) for the default '/kaggle/shared/' value.
        output_dir = self.file_share_dir
        if not output_dir.endswith(os.path.sep):
            output_dir += os.path.sep

        if not os.path.exists(self.file_share_dir) or not os.path.isdir(self.file_share_dir):
            raise ValueError(f'Invalid output directory {self.file_share_dir}')
        # Can't use os.path.join for output_dir + path: os.path.join won't prepend to an abspath
        output_paths = [output_dir + path for path in input_paths]
        return input_paths, output_paths

    def share_files(
            self,
            input_paths: List[Union[str, pathlib.Path]],
        ) -> List[str]:
        ''' Makes files and/or directories available to the user's inference_server. They will be mirrored under the
        self.file_share_dir directory, using the full absolute path. An input like:
            /kaggle/input/mycomp/test.csv
        Would be written to:
            /kaggle/shared/kaggle/input/mycomp/test.csv

        Args:
            input_paths: List of paths to files and/or directories that should be shared.

        Returns:
            The output paths that were shared.

        Raises:
            ValueError if any invalid paths are passed.
        '''
        input_paths, output_paths = self._standardize_and_validate_paths(input_paths)
        for in_path, out_path in zip(input_paths, output_paths):
            os.makedirs(os.path.dirname(out_path), exist_ok=True)

            # This makes the files available to the InferenceServer as read-only. Only the Gateway can mount files.
            # mount will only work in live kaggle evaluation rerun sessions. Otherwise use a symlink.
            # List-form argv (no shell) so paths containing spaces or shell metacharacters
            # can't break the command or inject into a shell.
            if IS_RERUN:
                if not os.path.isdir(out_path):
                    # A bind mount target must already exist.
                    pathlib.Path(out_path).touch()
                subprocess.run(['mount', '--bind', in_path, out_path], check=True)
            else:
                subprocess.run(['ln', '-s', in_path, out_path], check=True)

        return output_paths

    def write_submission(self, predictions, row_ids: List[Union[pl.Series, pl.DataFrame, pd.Series, pd.DataFrame]]) -> pathlib.Path:
        ''' Export the predictions to a submission file.

        Args:
            predictions: Either a single dataframe or a list of per-batch predictions
                (dataframes, series, or scalar values).
            row_ids: Per-batch row IDs aligned with the prediction batches.

        Returns:
            Path to the written submission csv.

        Raises:
            GatewayRuntimeError: if prediction batches can't be combined or the scalar
                case is used without `target_column_name` set.
        '''
        if isinstance(predictions, list):
            if isinstance(predictions[0], pd.DataFrame):
                predictions = pd.concat(predictions, ignore_index=True)
            elif isinstance(predictions[0], pl.DataFrame):
                # vertical_relaxed coerces compatible dtypes (e.g. int batch + float batch).
                try:
                    predictions = pl.concat(predictions, how='vertical_relaxed')
                except pl.exceptions.SchemaError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction types')
                except pl.exceptions.ComputeError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction column counts')
            elif isinstance(predictions[0], pl.Series):
                try:
                    predictions = pl.concat(predictions, how='vertical')
                except pl.exceptions.SchemaError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction types')
                except pl.exceptions.ComputeError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction column counts')

            # Bug fix: this block previously ran unconditionally, so a list of dataframes that had
            # just been concatenated above was re-wrapped (clobbering the result and spuriously
            # requiring target_column_name). Only predictions that aren't already a dataframe
            # (a concatenated series, or a list of scalar values) need to be joined with row IDs.
            if not isinstance(predictions, (pd.DataFrame, pl.DataFrame)):
                if type(row_ids[0]) in [pl.Series, pl.DataFrame]:
                    row_ids = pl.concat(row_ids)
                elif type(row_ids[0]) in [pd.Series, pd.DataFrame]:
                    row_ids = pd.concat(row_ids).reset_index(drop=True)
                else:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, f'Invalid row ID datatype {type(row_ids[0])}. Expected Polars series or dataframe.')
                if self.target_column_name is None:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, '`target_column_name` must be set in order to use scalar value predictions.')
                # NOTE(review): `row_ids.columns[0]` assumes row_ids concatenated into a dataframe;
                # a pl.Series has no `.columns` — confirm row_ids batches are dataframes here.
                predictions = pl.DataFrame(data={row_ids.columns[0]: row_ids, self.target_column_name: predictions})

        submission_path = pathlib.Path('/kaggle/working/submission.csv')
        if not IS_RERUN:
            # Outside a rerun session /kaggle/working may not exist; use a temp file instead.
            with tempfile.NamedTemporaryFile(prefix='kaggle-evaluation-submission-', suffix='.csv', delete=False, mode='w+') as f:
                submission_path = pathlib.Path(f.name)

        if isinstance(predictions, pd.DataFrame):
            predictions.to_csv(submission_path, index=False)
        elif isinstance(predictions, pl.DataFrame):
            pl.DataFrame(predictions).write_csv(submission_path)
        else:
            raise ValueError(f"Unsupported predictions type {type(predictions)}; can't write submission file")

        return submission_path

    def write_result(self, error: Optional[GatewayRuntimeError]=None):
        ''' Export a result.json containing error details if applicable.'''
        result = { 'Succeeded': error is None }

        if error is not None:
            result['ErrorType'] = error.error_type.value
            result['ErrorName'] = error.error_type.name
            # Max error detail length is 8000
            result['ErrorDetails'] = str(error.error_details[:8000]) if error.error_details else None

        with open('result.json', 'w') as f_open:
            json.dump(result, f_open)

    def handle_server_error(self, exception: Exception, endpoint: str):
        ''' Determine how to handle an exception raised when calling the inference server. Typically just format the
        error into a GatewayRuntimeError and raise.

        Args:
            exception: The exception raised by the client call.
            endpoint: Name of the endpoint that was being called, used in error messages.

        Raises:
            GatewayRuntimeError: for every recognized failure mode (`from None` hides the
                noisy original traceback); re-raises the original exception otherwise.
        '''
        exception_str = str(exception)
        if isinstance(exception, gaierror) or (isinstance(exception, RuntimeError) and 'Failed to connect to server after waiting' in exception_str):
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_NEVER_STARTED) from None
        if f'No listener for {endpoint} was registered' in exception_str:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_MISSING_ENDPOINT, f'Server did not register a listener for {endpoint}') from None
        if 'Exception calling application' in exception_str:
            # Extract just the exception message raised by the inference server
            message_match = re.search('"Exception calling application: (.*)"', exception_str, re.IGNORECASE)
            message = message_match.group(1) if message_match else exception_str
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_RAISED_EXCEPTION, message) from None
        if isinstance(exception, grpc._channel._InactiveRpcError):
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_CONNECTION_FAILED, exception_str) from None

        raise exception