'''
Core implementation of the client module. Implements generic Python-in /
Python-out communication patterns supporting many (nested) primitives plus
special data science types like DataFrames and np.ndarrays, with gRPC +
protobuf as the backing implementation.
'''
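
# Typical flow (illustrative sketch; the `predict` endpoint name and handler are
# hypothetical and defined per-competition):
#
#   # Inference server side:
#   server = define_server(predict)       # `predict` is a user-written function
#   server.start()
#
#   # Gateway side:
#   client = Client()                     # connects to localhost:50051 by default
#   result = client.send('predict', batch)
#   client.close()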

import grpc
import io
import json
import socket
import time

from concurrent import futures
from typing import Callable, List

import numpy as np
import pandas as pd
import polars as pl
import pyarrow

import kaggle_evaluation.core.generated.kaggle_evaluation_pb2 as kaggle_evaluation_proto
import kaggle_evaluation.core.generated.kaggle_evaluation_pb2_grpc as kaggle_evaluation_grpc


_SERVICE_CONFIG = {
    # Service config proto: https://github.com/grpc/grpc-proto/blob/ec886024c2f7b7f597ba89d5b7d60c3f94627b17/grpc/service_config/service_config.proto#L377
    'methodConfig': [
        {
            'name': [{}],  # Applies to all methods
            # See retry policy docs: https://grpc.io/docs/guides/retry/
            'retryPolicy': {
                'maxAttempts': 5,
                'initialBackoff': '0.1s',
                'maxBackoff': '1s',
                'backoffMultiplier': 1,  # Ensure relatively rapid feedback in the event of a crash
                'retryableStatusCodes': ['UNAVAILABLE'],
            },
        }
    ]
}
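# Worked timing for the config above (assuming gRPC's standard jittered retry
# schedule): backoffMultiplier=1 keeps the delay flat, so each of the up-to-4
# retries after the initial attempt waits at most initialBackoff (0.1s), meaning
# a transient UNAVAILABLE is retried and resolved (or surfaced) in well under a second.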
_GRPC_PORT = 50051
_GRPC_CHANNEL_OPTIONS = [
    # -1 for unlimited message send/receive size
    # https://github.com/grpc/grpc/blob/v1.64.x/include/grpc/impl/channel_arg_names.h#L39
    ('grpc.max_send_message_length', -1),
    ('grpc.max_receive_message_length', -1),
    # https://github.com/grpc/grpc/blob/master/doc/keepalive.md
    ('grpc.keepalive_time_ms', 60_000),  # Time between heartbeat pings
    ('grpc.keepalive_timeout_ms', 5_000),  # Time allowed to respond to pings
    ('grpc.http2.max_pings_without_data', 0),  # Remove another cap on pings
    ('grpc.keepalive_permit_without_calls', 1),  # Allow heartbeat pings at any time
    ('grpc.http2.min_ping_interval_without_data_ms', 1_000),
    ('grpc.service_config', json.dumps(_SERVICE_CONFIG)),
]


DEFAULT_DEADLINE_SECONDS = 60 * 60
_RETRY_SLEEP_SECONDS = 1
# Enforce a relatively strict server startup time so users can get feedback quickly if they're not
# configuring KaggleEvaluation correctly. We really don't want notebooks timing out after nine hours
# because somebody forgot to start their inference_server. Slow steps like loading models
# can happen during the first inference call if necessary.
STARTUP_LIMIT_SECONDS = 60 * 15

### Utils shared by client and server for data transfer

# pl.Enum is currently unstable, but we should eventually consider supporting it.
# https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.Enum.html#polars.datatypes.Enum
_POLARS_TYPE_DENYLIST = {pl.Enum, pl.Object, pl.Unknown}

def _serialize(data) -> kaggle_evaluation_proto.Payload:
    '''Maps input data of one of several allow-listed types to a protobuf message to be sent over gRPC.

    Args:
        data: The input data to be mapped. Any of the types listed below are accepted.

    Returns:
        The Payload protobuf message.

    Raises:
        TypeError if data is of an unsupported type.
    '''
    # Python primitives and Numpy scalars
    if isinstance(data, np.generic):
        # Numpy functions that return a single number return numpy scalars instead of python primitives.
        # In some cases this difference matters: https://numpy.org/devdocs/release/2.0.0-notes.html#representation-of-numpy-scalars-changed
        # Ex: np.mean([1, 2]) yields np.float64(1.5) instead of 1.5.
        # Check for numpy scalars first since most of them also inherit from python primitives.
        # For example, `np.float64(1.5)` is an instance of `float` among many other things.
        # https://numpy.org/doc/stable/reference/arrays.scalars.html
        assert data.shape == ()  # Additional validation that the np.generic type remains solely for scalars
        assert isinstance(data, np.number) or isinstance(data, np.bool_)  # No support for bytes, strings, objects, etc
        buffer = io.BytesIO()
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_scalar_value=buffer.getvalue())
    elif isinstance(data, str):
        return kaggle_evaluation_proto.Payload(str_value=data)
    elif isinstance(data, bool):  # bool is a subclass of int, so check that first
        return kaggle_evaluation_proto.Payload(bool_value=data)
    elif isinstance(data, int):
        return kaggle_evaluation_proto.Payload(int_value=data)
    elif isinstance(data, float):
        return kaggle_evaluation_proto.Payload(float_value=data)
    elif data is None:
        return kaggle_evaluation_proto.Payload(none_value=True)
    # Iterables for nested types
    if isinstance(data, list):
        return kaggle_evaluation_proto.Payload(list_value=kaggle_evaluation_proto.PayloadList(payloads=map(_serialize, data)))
    elif isinstance(data, tuple):
        return kaggle_evaluation_proto.Payload(tuple_value=kaggle_evaluation_proto.PayloadList(payloads=map(_serialize, data)))
    elif isinstance(data, dict):
        serialized_dict = {}
        for key, value in data.items():
            if not isinstance(key, str):
                raise TypeError(f'KaggleEvaluation only supports dicts with keys of type str, found {type(key)}.')
            serialized_dict[key] = _serialize(value)
        return kaggle_evaluation_proto.Payload(dict_value=kaggle_evaluation_proto.PayloadMap(payload_map=serialized_dict))
    # Allowlisted special types
    if isinstance(data, pd.DataFrame):
        buffer = io.BytesIO()
        data.to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_dataframe_value=buffer.getvalue())
    elif isinstance(data, pl.DataFrame):
        data_types = set(i.base_type() for i in data.dtypes)
        banned_types = _POLARS_TYPE_DENYLIST.intersection(data_types)
        if len(banned_types) > 0:
            raise TypeError(f'Unsupported Polars data type(s): {banned_types}')

        table = data.to_arrow()
        buffer = io.BytesIO()
        with pyarrow.ipc.new_stream(buffer, table.schema, options=pyarrow.ipc.IpcWriteOptions(compression='lz4')) as writer:
            writer.write_table(table)
        return kaggle_evaluation_proto.Payload(polars_dataframe_value=buffer.getvalue())
    elif isinstance(data, pd.Series):
        buffer = io.BytesIO()
        # Can't serialize a pd.Series directly to parquet, must use intermediate DataFrame
        pd.DataFrame(data).to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_series_value=buffer.getvalue())
    elif isinstance(data, pl.Series):
        buffer = io.BytesIO()
        # Can't serialize a pl.Series directly to parquet, must use intermediate DataFrame
        pl.DataFrame(data).write_parquet(buffer, compression='lz4', statistics=False)
        return kaggle_evaluation_proto.Payload(polars_series_value=buffer.getvalue())
    elif isinstance(data, np.ndarray):
        buffer = io.BytesIO()
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_array_value=buffer.getvalue())
    elif isinstance(data, io.BytesIO):
        return kaggle_evaluation_proto.Payload(bytes_io_value=data.getvalue())

    raise TypeError(f'Type {type(data)} not supported for KaggleEvaluation.')


def _deserialize(payload: kaggle_evaluation_proto.Payload):
    '''Maps a Payload protobuf message to a value of whichever type was set on the message.

    Args:
        payload: The message to be mapped.

    Returns:
        A value of one of several allow-listed types.

    Raises:
        TypeError if an unexpected value data type is found.
    '''
    # Primitives
    if payload.WhichOneof('value') == 'str_value':
        return payload.str_value
    elif payload.WhichOneof('value') == 'bool_value':
        return payload.bool_value
    elif payload.WhichOneof('value') == 'int_value':
        return payload.int_value
    elif payload.WhichOneof('value') == 'float_value':
        return payload.float_value
    elif payload.WhichOneof('value') == 'none_value':
        return None
    # Iterables for nested types
    elif payload.WhichOneof('value') == 'list_value':
        return list(map(_deserialize, payload.list_value.payloads))
    elif payload.WhichOneof('value') == 'tuple_value':
        return tuple(map(_deserialize, payload.tuple_value.payloads))
    elif payload.WhichOneof('value') == 'dict_value':
        return {key: _deserialize(value) for key, value in payload.dict_value.payload_map.items()}
    # Allowlisted special types
    elif payload.WhichOneof('value') == 'pandas_dataframe_value':
        return pd.read_parquet(io.BytesIO(payload.pandas_dataframe_value))
    elif payload.WhichOneof('value') == 'polars_dataframe_value':
        with pyarrow.ipc.open_stream(payload.polars_dataframe_value) as reader:
            table = reader.read_all()
        return pl.from_arrow(table)
    elif payload.WhichOneof('value') == 'pandas_series_value':
        # Pandas will still read a single-column parquet file as a DataFrame.
        df = pd.read_parquet(io.BytesIO(payload.pandas_series_value))
        return pd.Series(df[df.columns[0]])
    elif payload.WhichOneof('value') == 'polars_series_value':
        return pl.Series(pl.read_parquet(io.BytesIO(payload.polars_series_value)))
    elif payload.WhichOneof('value') == 'numpy_array_value':
        return np.load(io.BytesIO(payload.numpy_array_value), allow_pickle=False)
    elif payload.WhichOneof('value') == 'numpy_scalar_value':
        data = np.load(io.BytesIO(payload.numpy_scalar_value), allow_pickle=False)
        # As of NumPy 2.0.2, np.load on a saved scalar yields a zero-dimensional array instead of a scalar.
        data = data.dtype.type(data)  # Restore the expected numpy scalar type.
        assert data.shape == ()  # Additional validation that the np.generic type remains solely for scalars
        assert isinstance(data, np.number) or isinstance(data, np.bool_)  # No support for bytes, strings, objects, etc
        return data
    elif payload.WhichOneof('value') == 'bytes_io_value':
        return io.BytesIO(payload.bytes_io_value)

    raise TypeError(f'Found unknown Payload case {payload.WhichOneof("value")}')
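
# Round-trip sketch (illustrative; uses only the helpers defined above):
#   payload = _serialize({'scores': pl.Series([0.1, 0.2]), 'n': 3})
#   assert _deserialize(payload)['n'] == 3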

### Client code

class Client:
    '''
    Class which allows callers to make KaggleEvaluation requests.
    '''
    def __init__(self, channel_address: str = 'localhost'):
        self.channel_address = channel_address
        self.channel = grpc.insecure_channel(f'{channel_address}:{_GRPC_PORT}', options=_GRPC_CHANNEL_OPTIONS)
        self._made_first_connection = False
        self.endpoint_deadline_seconds = DEFAULT_DEADLINE_SECONDS
        self.stub = kaggle_evaluation_grpc.KaggleEvaluationServiceStub(self.channel)

    def _send_with_deadline(self, request):
        ''' Sends a message to the server while also:
        - Throwing an error as soon as the inference_server container has been shut down.
        - Setting a deadline of STARTUP_LIMIT_SECONDS for the inference_server to start up.
        '''
        if self._made_first_connection:
            return self.stub.Send(request, wait_for_ready=False, timeout=self.endpoint_deadline_seconds)

        first_call_time = time.time()
        # Allow time for the server to start as long as its container is running
        while time.time() - first_call_time < STARTUP_LIMIT_SECONDS:
            try:
                response = self.stub.Send(request, wait_for_ready=False)
                self._made_first_connection = True
                break
            except grpc.RpcError as err:
                # Use the public status-code API rather than string matching on the private
                # _InactiveRpcError representation.
                if err.code() != grpc.StatusCode.UNAVAILABLE:
                    raise err
            # Confirm the inference_server container is still alive & it's worth waiting on the server.
            # If the inference_server container is no longer running this will throw a socket.gaierror.
            socket.gethostbyname(self.channel_address)
            time.sleep(_RETRY_SLEEP_SECONDS)

        if not self._made_first_connection:
            raise RuntimeError(f'Failed to connect to server after waiting {STARTUP_LIMIT_SECONDS} seconds')
        return response
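
    # Timing sketch for the method above: with _RETRY_SLEEP_SECONDS = 1 and
    # STARTUP_LIMIT_SECONDS = 900, the first call retries roughly once per second
    # for up to 15 minutes before giving up; once connected, later calls skip the
    # wait (wait_for_ready=False) and run under the one-hour default deadline.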

    def serialize_request(self, name: str, *args, **kwargs) -> kaggle_evaluation_proto.KaggleEvaluationRequest:
        ''' Serialize a single request. Exists as a separate function from `send`
        to enable gateway concurrency for some competitions.
        '''
        already_serialized = (len(args) == 1) and isinstance(args[0], kaggle_evaluation_proto.KaggleEvaluationRequest)
        if already_serialized:
            return args[0]  # args is a tuple of length 1 containing the request
        return kaggle_evaluation_proto.KaggleEvaluationRequest(
                name=name,
                args=map(_serialize, args),
                kwargs={key: _serialize(value) for key, value in kwargs.items()}
        )
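
    # Illustrative use (hypothetical call site): serialize off the hot path, then pass
    # the request straight through `send`, which detects it as already serialized:
    #   request = client.serialize_request('predict', batch)
    #   client.send('predict', request)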

    def send(self, name: str, *args, **kwargs):
        '''Sends a single KaggleEvaluation request.

        Args:
            name: The endpoint name for the request.
            *args: Variable-length/type arguments to be supplied on the request.
            **kwargs: Key-value arguments to be supplied on the request.

        Returns:
            The response, which is of one of several allow-listed data types.
        '''
        request = self.serialize_request(name, *args, **kwargs)
        response = self._send_with_deadline(request)
        return _deserialize(response.payload)

    def close(self):
        self.channel.close()


### Server code

class KaggleEvaluationServiceServicer(kaggle_evaluation_grpc.KaggleEvaluationServiceServicer):
    '''
    Class which allows serving responses to KaggleEvaluation requests. The inference_server will run this service to listen for and respond
    to requests from the Gateway. The Gateway may also listen for requests from the inference_server in some cases.
    '''
    def __init__(self, listeners: List[Callable]):
        self.listeners_map = {func.__name__: func for func in listeners}

    # pylint: disable=unused-argument
    def Send(self, request: kaggle_evaluation_proto.KaggleEvaluationRequest, context: grpc.ServicerContext) -> kaggle_evaluation_proto.KaggleEvaluationResponse:
        '''Handler for gRPC requests that deserializes arguments, calls a user-registered function for handling the
        requested endpoint, then serializes and returns the response.

        Args:
            request: The KaggleEvaluationRequest protobuf message.
            context: (Unused) gRPC context.

        Returns:
            The KaggleEvaluationResponse protobuf message.

        Raises:
            NotImplementedError if the caller has not registered a handler for the requested endpoint.
        '''
        if request.name not in self.listeners_map:
            raise NotImplementedError(f'No listener for {request.name} was registered.')

        args = map(_deserialize, request.args)
        kwargs = {key: _deserialize(value) for key, value in request.kwargs.items()}
        response_function = self.listeners_map[request.name]
        response_payload = _serialize(response_function(*args, **kwargs))
        return kaggle_evaluation_proto.KaggleEvaluationResponse(payload=response_payload)

def define_server(*endpoint_listeners: Callable) -> grpc.Server:
    '''Registers the endpoints that the container is able to respond to, then creates a server that will listen
    for those endpoints. The endpoints that need to be implemented will depend on the specific competition.

    Args:
        endpoint_listeners: Functions that define how requests to the endpoint of the function name should be
            handled.

    Returns:
        The gRPC server object. The caller is responsible for starting it, and for stopping it at exit time.

    Raises:
        ValueError if parameter values are invalid.
    '''
    if not endpoint_listeners:
        raise ValueError('Must pass at least one endpoint listener, e.g. `predict`')
    for func in endpoint_listeners:
        if not callable(func):
            raise ValueError('Endpoint listeners passed to `define_server` must be functions')
        if func.__name__ == '<lambda>':
            raise ValueError('Functions passed as endpoint listeners must be named')

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1), options=_GRPC_CHANNEL_OPTIONS)
    kaggle_evaluation_grpc.add_KaggleEvaluationServiceServicer_to_server(KaggleEvaluationServiceServicer(endpoint_listeners), server)
    server.add_insecure_port(f'[::]:{_GRPC_PORT}')
    return server
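
# Lifecycle sketch (illustrative; `predict` is a hypothetical competition handler):
#   server = define_server(predict)
#   server.start()
#   server.wait_for_termination()  # or server.stop(grace) at shutdown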