''' Lower-level implementation details of the gateway.
Hosts should not need to review this file before writing their competition-specific gateway.
'''
import enum
import json
import os
import pathlib
import re
import subprocess
import tempfile
from socket import gaierror
from typing import Any, List, Optional, Tuple, Union
import grpc
import numpy as np
import pandas as pd
import polars as pl
import kaggle_evaluation.core.relay
_FILE_SHARE_DIR = '/kaggle/shared/'
IS_RERUN = os.getenv('KAGGLE_IS_COMPETITION_RERUN') is not None
class GatewayRuntimeErrorType(enum.Enum):
''' Allow-listed error types that Gateways can raise, which map to canned error messages to show users.'''
UNSPECIFIED = 0
SERVER_NEVER_STARTED = 1
SERVER_CONNECTION_FAILED = 2
SERVER_RAISED_EXCEPTION = 3
SERVER_MISSING_ENDPOINT = 4
# Default error type if an exception was raised that was not explicitly handled by the Gateway
GATEWAY_RAISED_EXCEPTION = 5
INVALID_SUBMISSION = 6
class GatewayRuntimeError(Exception):
''' Gateways can raise this error to capture a user-visible error enum from above and host-visible error details.'''
def __init__(self, error_type: GatewayRuntimeErrorType, error_details: Optional[str]=None):
self.error_type = error_type
self.error_details = error_details
class BaseGateway():
def __init__(self, target_column_name: Optional[str]=None):
self.client = kaggle_evaluation.core.relay.Client('inference_server' if IS_RERUN else 'localhost')
self.server = None  # The gateway can have a server but it isn't typically necessary.
self.file_share_dir = _FILE_SHARE_DIR  # Directory that share_files() mirrors shared files into.
self.target_column_name = target_column_name  # Only used if the predictions are made as a primitive type (int, bool, etc.) rather than a dataframe.
def validate_prediction_batch(
self,
prediction_batch: Any,
row_ids: Union[pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]
):
''' If competitors can submit fewer rows than expected, they could defer all of their predictions to the final batch and
bypass the benefits of the Kaggle evaluation service. This attack was seen in a real competition with the older time series API:
https://www.kaggle.com/competitions/riiid-test-answer-prediction/discussion/196066
It's critically important that this check be run every time predict() is called.
If your predictions are allowed to contain a variable number of rows and you need to write a custom version of this check,
you must still require a minimum row count greater than zero per prediction batch. See the usage sketch after this method.
'''
if prediction_batch is None:
raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'No prediction received')
num_received_rows = None
# Special handling for numpy ints only, as numpy floats subclass Python floats but numpy ints aren't Python ints
for primitive_type in [int, float, str, bool, np.int_]:
if isinstance(prediction_batch, primitive_type):
# Types that only support one prediction per batch don't need to be validated.
# Basic types are valid for prediction, but either don't have a length (int) or the length isn't relevant for
# purposes of this check (str).
num_received_rows = 1
if num_received_rows is None:
if type(prediction_batch) not in [pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]:
raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, f'Invalid prediction data type, received: {type(prediction_batch)}')
num_received_rows = len(prediction_batch)
if type(row_ids) not in [pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]:
raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, f'Invalid row ID type {type(row_ids)}; expected Polars DataFrame or similar')
num_expected_rows = len(row_ids)
if len(row_ids) == 0:
raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, 'Missing row IDs for batch')
if num_received_rows != num_expected_rows:
raise GatewayRuntimeError(
GatewayRuntimeErrorType.INVALID_SUBMISSION,
f'Invalid predictions: expected {num_expected_rows} rows but received {num_received_rows}'
)
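# Usage sketch (hypothetical, not part of this module): a competition-specific gateway would
# typically call this once per batch inside its predict loop, e.g.
#   predictions = self.client.send('predict', test_batch)   # ask the user's inference server
#   self.validate_prediction_batch(predictions, test_batch['id'])
# so that a submission returning the wrong number of rows fails on the very batch that is short.
# The names test_batch and 'id' are illustrative only.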
def _standardize_and_validate_paths(
self,
input_paths: List[Union[str, pathlib.Path]]
) -> Tuple[List[str], List[str]]:
# Accept a list of str or pathlib.Path, but standardize on list of str
for path in input_paths:
if os.pardir in str(path):
raise ValueError(f'Send files path contains {os.pardir}: {path}')
if str(path) != str(os.path.normpath(path)):
# Raise an error rather than sending users unexpectedly altered paths
raise ValueError(f'Send files path {path} must be normalized. See `os.path.normpath`')
if type(path) not in (pathlib.Path, str):
raise ValueError('All paths must be of type str or pathlib.Path')
if not os.path.exists(path):
raise ValueError(f'Input path {path} does not exist')
input_paths = [os.path.abspath(path) for path in input_paths]
if len(set(input_paths)) != len(input_paths):
raise ValueError('Duplicate input paths found')
# Ensure the output dir ends with a separator so it can be prefixed to the absolute input paths below
output_dir = self.file_share_dir if self.file_share_dir.endswith(os.path.sep) else self.file_share_dir + os.path.sep
if not os.path.exists(self.file_share_dir) or not os.path.isdir(self.file_share_dir):
raise ValueError(f'Invalid output directory {self.file_share_dir}')
# Can't use os.path.join for output_dir + path: os.path.join won't prepend to an abspath
output_paths = [output_dir + path for path in input_paths]
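# For example, os.path.join('/kaggle/shared/', '/kaggle/input/test.csv') would return
# '/kaggle/input/test.csv', silently dropping the share directory, whereas plain string
# concatenation keeps every output path rooted under self.file_share_dir.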
return input_paths, output_paths
def share_files(
self,
input_paths: List[Union[str, pathlib.Path]],
) -> List[str]:
''' Makes files and/or directories available to the user's inference_server. They will be mirrored under the
self.file_share_dir directory, using the full absolute path. An input like:
/kaggle/input/mycomp/test.csv
Would be written to:
/kaggle/shared/kaggle/input/mycomp/test.csv
Args:
input_paths: List of paths to files and/or directories that should be shared.
Returns:
The output paths that were shared.
Raises:
ValueError if any invalid paths are passed.
'''
input_paths, output_paths = self._standardize_and_validate_paths(input_paths)
for in_path, out_path in zip(input_paths, output_paths):
os.makedirs(os.path.dirname(out_path), exist_ok=True)
# This makes the files available to the InferenceServer as read-only. Only the Gateway can mount files.
# mount will only work in live kaggle evaluation rerun sessions. Otherwise use a symlink.
if IS_RERUN:
if not os.path.isdir(out_path):
pathlib.Path(out_path).touch()
subprocess.run(f'mount --bind {in_path} {out_path}', shell=True, check=True)
else:
subprocess.run(f'ln -s {in_path} {out_path}', shell=True, check=True)
return output_paths
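# Rough usage sketch (hypothetical paths): a gateway that wants the inference server to read
# the test set directly from disk might run
#   shared_paths = self.share_files(['/kaggle/input/mycomp/test.csv'])
# and then send shared_paths[0] (a read-only mirror under self.file_share_dir) to the server
# instead of the original competition path.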
def write_submission(self, predictions, row_ids: List[Union[pl.Series, pl.DataFrame, pd.Series, pd.DataFrame]]) -> pathlib.Path:
''' Export the predictions to a submission file.'''
if isinstance(predictions, list):
if isinstance(predictions[0], pd.DataFrame):
predictions = pd.concat(predictions, ignore_index=True)
elif isinstance(predictions[0], pl.DataFrame):
try:
predictions = pl.concat(predictions, how='vertical_relaxed')
except pl.exceptions.SchemaError:
raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction types')
except pl.exceptions.ComputeError:
raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction column counts')
elif isinstance(predictions[0], pl.Series):
try:
predictions = pl.concat(predictions, how='vertical')
except pl.exceptions.SchemaError:
raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction types')
except pl.exceptions.ComputeError:
raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction column counts')
if type(row_ids[0]) in [pl.Series, pl.DataFrame]:
row_ids = pl.concat(row_ids)
elif type(row_ids[0]) in [pd.Series, pd.DataFrame]:
row_ids = pd.concat(row_ids).reset_index(drop=True)
else:
raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, f'Invalid row ID datatype {type(row_ids[0])}. Expected a Polars or pandas Series or DataFrame.')
if self.target_column_name is None:
raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, '`target_column_name` must be set in order to use scalar value predictions.')
predictions = pl.DataFrame(data={row_ids.columns[0]: row_ids, self.target_column_name: predictions})
submission_path = pathlib.Path('/kaggle/working/submission.csv')
if not IS_RERUN:
with tempfile.NamedTemporaryFile(prefix='kaggle-evaluation-submission-', suffix='.csv', delete=False, mode='w+') as f:
submission_path = pathlib.Path(f.name)
if isinstance(predictions, pd.DataFrame):
predictions.to_csv(submission_path, index=False)
elif isinstance(predictions, pl.DataFrame):
pl.DataFrame(predictions).write_csv(submission_path)
else:
raise ValueError(f"Unsupported predictions type {type(predictions)}; can't write submission file")
return submission_path
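# Illustrative example (hypothetical names): with target_column_name='target' and row IDs taken
# from an 'id' column, scalar predictions are paired with their row IDs, so the exported
# submission.csv is intended to have the two columns 'id' and 'target', one row per prediction.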
def write_result(self, error: Optional[GatewayRuntimeError]=None):
''' Export a result.json containing error details if applicable.'''
result = { 'Succeeded': error is None }
if error is not None:
result['ErrorType'] = error.error_type.value
result['ErrorName'] = error.error_type.name
# Max error detail length is 8000
result['ErrorDetails'] = str(error.error_details[:8000]) if error.error_details else None
with open('result.json', 'w') as f_open:
json.dump(result, f_open)
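# For example, write_result(GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'No prediction received'))
# writes a result.json along the lines of:
#   {"Succeeded": false, "ErrorType": 6, "ErrorName": "INVALID_SUBMISSION", "ErrorDetails": "No prediction received"}
# while a clean run (error=None) writes just {"Succeeded": true}.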
def handle_server_error(self, exception: Exception, endpoint: str):
''' Determine how to handle an exception raised when calling the inference server. Typically just format the
error into a GatewayRuntimeError and raise.
'''
exception_str = str(exception)
if isinstance(exception, gaierror) or (isinstance(exception, RuntimeError) and 'Failed to connect to server after waiting' in exception_str):
raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_NEVER_STARTED) from None
if f'No listener for {endpoint} was registered' in exception_str:
raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_MISSING_ENDPOINT, f'Server did not register a listener for {endpoint}') from None
if 'Exception calling application' in exception_str:
# Extract just the exception message raised by the inference server
message_match = re.search('"Exception calling application: (.*)"', exception_str, re.IGNORECASE)
message = message_match.group(1) if message_match else exception_str
raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_RAISED_EXCEPTION, message) from None
if isinstance(exception, grpc._channel._InactiveRpcError):
raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_CONNECTION_FAILED, exception_str) from None
raise exception
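# Typical call pattern (hypothetical sketch of a competition gateway's predict loop):
#   try:
#       predictions = self.client.send('predict', test_batch)
#   except Exception as e:
#       self.handle_server_error(e, 'predict')
# so that failures inside the user's inference server surface as the allow-listed
# GatewayRuntimeErrorType values rather than raw gRPC exceptions.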