Spaces:
Running
Running
File size: 16,787 Bytes
15369ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 |
'''
Core implementation of the client module, implementing generic communication
patterns with Python in / Python out supporting many (nested) primitives +
special data science types like DataFrames or np.ndarrays, with gRPC + protobuf
as a backing implementation.
'''
import grpc
import io
import json
import socket
import time
from concurrent import futures
from typing import Callable, List, Tuple
import numpy as np
import pandas as pd
import polars as pl
import pyarrow
import kaggle_evaluation.core.generated.kaggle_evaluation_pb2 as kaggle_evaluation_proto
import kaggle_evaluation.core.generated.kaggle_evaluation_pb2_grpc as kaggle_evaluation_grpc
# gRPC service config, JSON-encoded into the 'grpc.service_config' channel option below.
_SERVICE_CONFIG = {
    # Service config proto: https://github.com/grpc/grpc-proto/blob/ec886024c2f7b7f597ba89d5b7d60c3f94627b17/grpc/service_config/service_config.proto#L377
    'methodConfig': [
        {
            'name': [{}],  # Applies to all methods
            # See retry policy docs: https://grpc.io/docs/guides/retry/
            'retryPolicy': {
                'maxAttempts': 5,
                'initialBackoff': '0.1s',
                'maxBackoff': '1s',
                'backoffMultiplier': 1,  # Ensure relatively rapid feedback in the event of a crash
                'retryableStatusCodes': ['UNAVAILABLE'],
            },
        }
    ]
}
# Port used both for the client channel address and for the server's listening socket.
_GRPC_PORT = 50051
# Options handed to both grpc.insecure_channel (client side) and grpc.server (server side).
_GRPC_CHANNEL_OPTIONS = [
    # -1 for unlimited message send/receive size
    # https://github.com/grpc/grpc/blob/v1.64.x/include/grpc/impl/channel_arg_names.h#L39
    ('grpc.max_send_message_length', -1),
    ('grpc.max_receive_message_length', -1),
    # https://github.com/grpc/grpc/blob/master/doc/keepalive.md
    ('grpc.keepalive_time_ms', 60_000),  # Time between heartbeat pings
    ('grpc.keepalive_timeout_ms', 5_000),  # Time allowed to respond to pings
    ('grpc.http2.max_pings_without_data', 0),  # Remove another cap on pings
    ('grpc.keepalive_permit_without_calls', 1),  # Allow heartbeat pings at any time
    ('grpc.http2.min_ping_interval_without_data_ms', 1_000),
    ('grpc.service_config', json.dumps(_SERVICE_CONFIG)),
]
# Default per-call timeout (one hour) used by Client for calls made after the first successful connection.
DEFAULT_DEADLINE_SECONDS = 60 * 60
# Pause between connection attempts while the client waits for the server to come up.
_RETRY_SLEEP_SECONDS = 1
# Enforce a relatively strict server startup time so users can get feedback quickly if they're not
# configuring KaggleEvaluation correctly. We really don't want notebooks timing out after nine hours
# because somebody forgot to start their inference_server. Slow steps like loading models
# can happen during the first inference call if necessary.
STARTUP_LIMIT_SECONDS = 60 * 15
### Utils shared by client and server for data transfer
# Polars base dtypes that _serialize rejects with a TypeError.
# pl.Enum is currently unstable, but we should eventually consider supporting it.
# https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.Enum.html#polars.datatypes.Enum
_POLARS_TYPE_DENYLIST = {pl.Enum, pl.Object, pl.Unknown}
def _serialize(data) -> kaggle_evaluation_proto.Payload:
    '''Maps input data of one of several allow-listed types to a protobuf message to be sent over gRPC.

    Accepted inputs: numpy numeric/bool scalars, str, bool, int, float, None, list, tuple,
    dict (str keys only), pd.DataFrame, pl.DataFrame, pd.Series, pl.Series, np.ndarray and io.BytesIO.
    Elements of lists/tuples and values of dicts are serialized recursively.

    Args:
        data: The input data to be mapped. Any of the types listed above are accepted.

    Returns:
        The Payload protobuf message.

    Raises:
        TypeError if data (or any nested element) is of an unsupported type, including
        numpy scalars that are neither numeric nor boolean and dicts with non-str keys.
    '''
    # Python primitives and Numpy scalars
    if isinstance(data, np.generic):
        # Numpy functions that return a single number return numpy scalars instead of python primitives.
        # In some cases this difference matters: https://numpy.org/devdocs/release/2.0.0-notes.html#representation-of-numpy-scalars-changed
        # Ex: np.mean([1, 2]) yields np.float64(1.5) instead of 1.5.
        # Check for numpy scalars first since most of them also inherit from python primitives.
        # For example, `np.float64(1.5)` is an instance of `float` among many other things.
        # https://numpy.org/doc/stable/reference/arrays.scalars.html
        #
        # Validate with an explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let bytes/str/object scalars through unchecked.
        is_scalar = data.shape == ()  # np.generic should only ever describe scalars
        is_supported = isinstance(data, (np.number, np.bool_))  # No support for bytes, strings, objects, etc
        if not (is_scalar and is_supported):
            raise TypeError(f'Type {type(data)} not supported for KaggleEvaluation.')
        buffer = io.BytesIO()
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_scalar_value=buffer.getvalue())
    if isinstance(data, str):
        return kaggle_evaluation_proto.Payload(str_value=data)
    if isinstance(data, bool):  # bool is a subclass of int, so it must be checked before int
        return kaggle_evaluation_proto.Payload(bool_value=data)
    if isinstance(data, int):
        return kaggle_evaluation_proto.Payload(int_value=data)
    if isinstance(data, float):
        return kaggle_evaluation_proto.Payload(float_value=data)
    if data is None:
        return kaggle_evaluation_proto.Payload(none_value=True)

    # Iterables for nested types
    if isinstance(data, list):
        payloads = [_serialize(item) for item in data]
        return kaggle_evaluation_proto.Payload(list_value=kaggle_evaluation_proto.PayloadList(payloads=payloads))
    if isinstance(data, tuple):
        payloads = [_serialize(item) for item in data]
        return kaggle_evaluation_proto.Payload(tuple_value=kaggle_evaluation_proto.PayloadList(payloads=payloads))
    if isinstance(data, dict):
        serialized_dict = {}
        for key, value in data.items():
            if not isinstance(key, str):
                raise TypeError(f'KaggleEvaluation only supports dicts with keys of type str, found {type(key)}.')
            serialized_dict[key] = _serialize(value)
        return kaggle_evaluation_proto.Payload(dict_value=kaggle_evaluation_proto.PayloadMap(payload_map=serialized_dict))

    # Allowlisted special types
    if isinstance(data, pd.DataFrame):
        buffer = io.BytesIO()
        # index=False: the DataFrame index is not written to the parquet bytes.
        data.to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_dataframe_value=buffer.getvalue())
    if isinstance(data, pl.DataFrame):
        data_types = {dtype.base_type() for dtype in data.dtypes}
        banned_types = _POLARS_TYPE_DENYLIST.intersection(data_types)
        if banned_types:
            raise TypeError(f'Unsupported Polars data type(s): {banned_types}')
        # Polars frames travel as an lz4-compressed Arrow IPC stream.
        table = data.to_arrow()
        buffer = io.BytesIO()
        with pyarrow.ipc.new_stream(buffer, table.schema, options=pyarrow.ipc.IpcWriteOptions(compression='lz4')) as writer:
            writer.write_table(table)
        return kaggle_evaluation_proto.Payload(polars_dataframe_value=buffer.getvalue())
    if isinstance(data, pd.Series):
        buffer = io.BytesIO()
        # Can't serialize a pd.Series directly to parquet, must use intermediate DataFrame
        pd.DataFrame(data).to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_series_value=buffer.getvalue())
    if isinstance(data, pl.Series):
        buffer = io.BytesIO()
        # Can't serialize a pl.Series directly to parquet, must use intermediate DataFrame
        pl.DataFrame(data).write_parquet(buffer, compression='lz4', statistics=False)
        return kaggle_evaluation_proto.Payload(polars_series_value=buffer.getvalue())
    if isinstance(data, np.ndarray):
        buffer = io.BytesIO()
        # allow_pickle=False: object arrays are rejected by numpy instead of being pickled.
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_array_value=buffer.getvalue())
    if isinstance(data, io.BytesIO):
        return kaggle_evaluation_proto.Payload(bytes_io_value=data.getvalue())

    raise TypeError(f'Type {type(data)} not supported for KaggleEvaluation.')
def _deserialize(payload: kaggle_evaluation_proto.Payload):
    '''Maps a Payload protobuf message to a value of whichever type was set on the message.

    Args:
        payload: The message to be mapped.

    Returns:
        A value of one of several allow-listed types. List, tuple and dict payloads are
        rebuilt recursively from their nested payloads.

    Raises:
        TypeError if an unexpected value data type is found, including a numpy scalar payload
        that does not decode to a numeric or boolean scalar.
    '''
    # Resolve the populated oneof case once instead of re-querying it on every branch.
    value_case = payload.WhichOneof('value')

    # Primitives
    if value_case == 'str_value':
        return payload.str_value
    if value_case == 'bool_value':
        return payload.bool_value
    if value_case == 'int_value':
        return payload.int_value
    if value_case == 'float_value':
        return payload.float_value
    if value_case == 'none_value':
        return None

    # Iterables for nested types
    if value_case == 'list_value':
        return [_deserialize(item) for item in payload.list_value.payloads]
    if value_case == 'tuple_value':
        return tuple(_deserialize(item) for item in payload.tuple_value.payloads)
    if value_case == 'dict_value':
        return {key: _deserialize(value) for key, value in payload.dict_value.payload_map.items()}

    # Allowlisted special types
    if value_case == 'pandas_dataframe_value':
        return pd.read_parquet(io.BytesIO(payload.pandas_dataframe_value))
    if value_case == 'polars_dataframe_value':
        with pyarrow.ipc.open_stream(payload.polars_dataframe_value) as reader:
            table = reader.read_all()
        return pl.from_arrow(table)
    if value_case == 'pandas_series_value':
        # The series was written as a single-column parquet DataFrame, so it is read back
        # as a DataFrame and its only column is extracted.
        df = pd.read_parquet(io.BytesIO(payload.pandas_series_value))
        return pd.Series(df[df.columns[0]])
    if value_case == 'polars_series_value':
        return pl.Series(pl.read_parquet(io.BytesIO(payload.polars_series_value)))
    if value_case == 'numpy_array_value':
        return np.load(io.BytesIO(payload.numpy_array_value), allow_pickle=False)
    if value_case == 'numpy_scalar_value':
        data = np.load(io.BytesIO(payload.numpy_scalar_value), allow_pickle=False)
        # As of Numpy 2.0.2, np.load for a numpy scalar yields a dimensionless array instead of a scalar
        data = data.dtype.type(data)  # Restore the expected numpy scalar type.
        # Validate with an explicit raise rather than `assert`, which is stripped under `python -O`.
        is_scalar = data.shape == ()  # The scalar field must only ever carry scalars
        is_supported = isinstance(data, (np.number, np.bool_))  # No support for bytes, strings, objects, etc
        if not (is_scalar and is_supported):
            raise TypeError(f'Numpy scalar payload decoded to unsupported value of type {type(data)}.')
        return data
    if value_case == 'bytes_io_value':
        return io.BytesIO(payload.bytes_io_value)

    raise TypeError(f'Found unknown Payload case {value_case}')
### Client code
class Client:
    '''
    Class which allows callers to make KaggleEvaluation requests.

    Opens an insecure gRPC channel to `channel_address` on the module's fixed port and sends
    requests through the generated KaggleEvaluationService stub.
    '''

    def __init__(self, channel_address: str = 'localhost'):
        '''
        Args:
            channel_address: Host name of the server to connect to (the port is fixed by the module).
        '''
        self.channel_address = channel_address
        self.channel = grpc.insecure_channel(f'{channel_address}:{_GRPC_PORT}', options=_GRPC_CHANNEL_OPTIONS)
        # Flipped to True after the first successful Send; selects the fast path in _send_with_deadline.
        self._made_first_connection = False
        # Per-call timeout applied to every call after the first successful one.
        self.endpoint_deadline_seconds = DEFAULT_DEADLINE_SECONDS
        self.stub = kaggle_evaluation_grpc.KaggleEvaluationServiceStub(self.channel)

    def _send_with_deadline(self, request):
        ''' Sends a message to the server while also:
        - Throwing an error as soon as the inference_server container has been shut down.
        - Setting a deadline of STARTUP_LIMIT_SECONDS for the inference_server to startup.

        Args:
            request: The KaggleEvaluationRequest to send.

        Returns:
            The response returned by the stub's Send call.

        Raises:
            grpc.RpcError: if a call fails with any status other than UNAVAILABLE during startup,
                or with any status once the first connection has been made.
            socket.gaierror: if the server's host name can no longer be resolved while waiting.
            RuntimeError: if no connection could be made within STARTUP_LIMIT_SECONDS.
        '''
        if self._made_first_connection:
            return self.stub.Send(request, wait_for_ready=False, timeout=self.endpoint_deadline_seconds)

        # Measure the startup window with a monotonic clock so wall-clock adjustments can't distort it.
        startup_began = time.monotonic()
        # Allow time for the server to start as long as its container is running
        while time.monotonic() - startup_began < STARTUP_LIMIT_SECONDS:
            try:
                # NOTE(review): this first call carries no timeout, unlike later calls — presumably
                # so slow first-call work on the server is not cut off; confirm this is intended.
                response = self.stub.Send(request, wait_for_ready=False)
            except grpc.RpcError as err:
                # Use the public status-code API rather than the private grpc._channel error class
                # and string matching. Errors raised by stub calls also implement grpc.Call.
                status = err.code() if isinstance(err, grpc.Call) else None
                if status != grpc.StatusCode.UNAVAILABLE:
                    raise
            else:
                self._made_first_connection = True
                return response
            # Confirm the inference_server container is still alive & it's worth waiting on the server.
            # If the inference_server container is no longer running this will throw a socket.gaierror.
            socket.gethostbyname(self.channel_address)
            time.sleep(_RETRY_SLEEP_SECONDS)

        raise RuntimeError(f'Failed to connect to server after waiting {STARTUP_LIMIT_SECONDS} seconds')

    def serialize_request(self, name: str, *args, **kwargs) -> kaggle_evaluation_proto.KaggleEvaluationRequest:
        ''' Serialize a single request. Exists as a separate function from `send`
        to enable gateway concurrency for some competitions.

        If the only positional argument is already a KaggleEvaluationRequest it is returned
        unchanged (in that case `name` and `kwargs` are ignored).

        Args:
            name: The endpoint name for the request.
            *args: Positional arguments to serialize onto the request.
            **kwargs: Keyword arguments to serialize onto the request.

        Returns:
            The KaggleEvaluationRequest protobuf message.
        '''
        already_serialized = len(args) == 1 and isinstance(args[0], kaggle_evaluation_proto.KaggleEvaluationRequest)
        if already_serialized:
            return args[0]  # args is a tuple of length 1 containing the request
        return kaggle_evaluation_proto.KaggleEvaluationRequest(
            name=name,
            args=[_serialize(arg) for arg in args],
            kwargs={key: _serialize(value) for key, value in kwargs.items()},
        )

    def send(self, name: str, *args, **kwargs):
        '''Sends a single KaggleEvaluation request.

        Args:
            name: The endpoint name for the request.
            *args: Variable-length/type arguments to be supplied on the request.
            **kwargs: Key-value arguments to be supplied on the request.

        Returns:
            The response, which is of one of several allow-listed data types.
        '''
        request = self.serialize_request(name, *args, **kwargs)
        response = self._send_with_deadline(request)
        return _deserialize(response.payload)

    def close(self):
        '''Closes the underlying gRPC channel.'''
        self.channel.close()
### Server code
class KaggleEvaluationServiceServicer(kaggle_evaluation_grpc.KaggleEvaluationServiceServicer):
    '''
    Class which allows serving responses to KaggleEvaluation requests. The inference_server will run this service to listen for and respond
    to requests from the Gateway. The Gateway may also listen for requests from the inference_server in some cases.
    '''

    def __init__(self, listeners: List[Callable]):
        '''
        Args:
            listeners: Functions to expose as endpoints. Each one is registered under its
                `__name__`; if two share a name, the later one replaces the earlier one.
        '''
        self.listeners_map = {func.__name__: func for func in listeners}

    # pylint: disable=unused-argument
    def Send(self, request: kaggle_evaluation_proto.KaggleEvaluationRequest, context: grpc.ServicerContext) -> kaggle_evaluation_proto.KaggleEvaluationResponse:
        '''Handler for gRPC requests that deserializes arguments, calls a user-registered function for handling the
        requested endpoint, then serializes and returns the response.

        Args:
            request: The KaggleEvaluationRequest protobuf message.
            context: (Unused) gRPC context.

        Returns:
            The KaggleEvaluationResponse protobuf message.

        Raises:
            NotImplementedError if the caller has not registered a handler for the requested endpoint.
        '''
        if request.name not in self.listeners_map:
            raise NotImplementedError(f'No listener for {request.name} was registered.')
        # Deserialize eagerly so a malformed argument fails here, before the listener is invoked.
        args = [_deserialize(arg) for arg in request.args]
        kwargs = {key: _deserialize(value) for key, value in request.kwargs.items()}
        response_function = self.listeners_map[request.name]
        response_payload = _serialize(response_function(*args, **kwargs))
        return kaggle_evaluation_proto.KaggleEvaluationResponse(payload=response_payload)
def define_server(*endpoint_listeners: Callable) -> 'grpc.Server':
    '''Registers the endpoints that the container is able to respond to and builds a gRPC server which serves
    those endpoints. The endpoints that need to be implemented will depend on the specific competition.

    This function does not call `start()` on the server; the caller must start it, and should stop it
    at exit time.

    Args:
        endpoint_listeners: Functions that define how requests to the endpoint of the function name should be
            handled.

    Returns:
        The configured (not yet started) gRPC server object.

    Raises:
        ValueError if parameter values are invalid: no listeners were passed, a listener is not
        callable, or a listener has no usable name (lambdas, callables without `__name__`).
    '''
    if not endpoint_listeners:
        raise ValueError('Must pass at least one endpoint listener, e.g. `predict`')
    for func in endpoint_listeners:
        if not callable(func):
            raise ValueError('Endpoint listeners passed to `serve` must be functions')
        # Endpoints are keyed by __name__, so lambdas and callables lacking __name__
        # (e.g. functools.partial objects) cannot be registered.
        if getattr(func, '__name__', '<lambda>') == '<lambda>':
            raise ValueError('Functions passed as endpoint listeners must be named')

    # The server is backed by a thread pool with a single worker thread.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1), options=_GRPC_CHANNEL_OPTIONS)
    kaggle_evaluation_grpc.add_KaggleEvaluationServiceServicer_to_server(KaggleEvaluationServiceServicer(endpoint_listeners), server)
    server.add_insecure_port(f'[::]:{_GRPC_PORT}')
    return server
|