Spaces:
Running
Running
'''Template for the two classes hosts should customize for each competition.''' | |
import abc | |
import os | |
import pathlib | |
import polars as pl | |
import time | |
import sys | |
import traceback | |
import warnings | |
from typing import Callable, Generator, Tuple | |
import kaggle_evaluation.core.base_gateway | |
import kaggle_evaluation.core.relay | |
# Wall-clock timestamp captured at module import — presumably used elsewhere to
# measure container startup latency; consumers are not visible in this chunk.
_initial_import_time = time.time()
# Flag guarding a one-time startup-slowness warning — NOTE(review): no reader/writer
# of this flag is visible here; confirm against the rest of the module.
_issued_startup_time_warning = False
class Gateway(kaggle_evaluation.core.base_gateway.BaseGateway, abc.ABC):
    '''
    Template to start with when writing a new gateway.
    In most cases, hosts should only need to write get_all_predictions.
    There are two main methods for sending data to the inference_server hosts should understand:
    - Small datasets: use `self.predict`. Competitors will receive the data passed to self.predict as
      Python objects in memory. This is just a wrapper for self.client.send(); you can write additional
      wrappers if necessary.
    - Large datasets: it's much faster to send data via self.share_files, which is equivalent to making
      files available via symlink. See base_gateway.BaseGateway.share_files for the full details.
    '''

    def generate_data_batches(self) -> Generator:
        '''Used by the default implementation of `get_all_predictions` so we can
        ensure `validate_prediction_batch` is run every time `predict` is called.
        This method must yield both the batch of data to be sent to `predict` and a series
        of row IDs to be sent to `validate_prediction_batch`.
        '''
        raise NotImplementedError

    def get_all_predictions(self):
        '''Run the full prediction loop over every batch from `generate_data_batches`.

        Each batch's predictions are validated via `validate_prediction_batch`
        before being accumulated.

        Returns:
            Tuple of (all_predictions, all_row_ids): parallel lists with one
            entry per batch.
        '''
        all_predictions = []
        all_row_ids = []
        for data_batch, row_ids in self.generate_data_batches():
            predictions = self.predict(*data_batch)
            # Wrap in a named Series so validation and submission writing see a
            # consistent column name.
            predictions = pl.Series(self.target_column_name, predictions)
            self.validate_prediction_batch(predictions, row_ids)
            all_predictions.append(predictions)
            all_row_ids.append(row_ids)
        return all_predictions, all_row_ids

    def predict(self, *args, **kwargs):
        '''self.predict will send all data in args and kwargs to the user container, and
        instruct the user container to generate a `predict` response.

        Returns the user's predictions; on failure, defers to `handle_server_error`.
        '''
        try:
            return self.client.send('predict', *args, **kwargs)
        except Exception as e:
            self.handle_server_error(e, 'predict')

    def set_response_timeout_seconds(self, timeout_seconds: float):
        '''Configure how long the client waits for each inference response.'''
        # Also store timeout_seconds in an easy place for the competitor to access.
        self.timeout_seconds = timeout_seconds
        # Set a response deadline that will apply after the very first response.
        self.client.endpoint_deadline_seconds = timeout_seconds

    def run(self) -> pathlib.Path:
        '''Execute the gateway end to end: gather predictions, write the submission,
        and report any error.

        Client/server cleanup now runs in a `finally` block so it is guaranteed
        even on BaseException (e.g. KeyboardInterrupt), which the except clauses
        below do not catch.

        Returns:
            Path to the written submission file, or None if an error occurred.
        '''
        error = None
        submission_path = None
        try:
            predictions, row_ids = self.get_all_predictions()
            submission_path = self.write_submission(predictions, row_ids)
        except kaggle_evaluation.core.base_gateway.GatewayRuntimeError as gre:
            error = gre
        except Exception:
            # Get the full stack trace so the error report is actionable.
            exc_type, exc_value, exc_traceback = sys.exc_info()
            error_str = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
            error = kaggle_evaluation.core.base_gateway.GatewayRuntimeError(
                kaggle_evaluation.core.base_gateway.GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION,
                error_str
            )
        finally:
            # Guaranteed cleanup: close the connection to the user container and
            # stop the gateway's own server, regardless of how the try block exits.
            self.client.close()
            if self.server:
                self.server.stop(0)
        if kaggle_evaluation.core.base_gateway.IS_RERUN:
            self.write_result(error)
        elif error:
            # For local testing
            raise error
        return submission_path
class InferenceServer(abc.ABC):
    '''
    Base class for competition participants to inherit from when writing their submission. In most cases, users should
    only need to implement a `predict` function or other endpoints to pass to this class's constructor, and hosts will
    provide a mock Gateway for testing.
    '''

    def __init__(self, endpoint_listeners: Tuple[Callable]):
        # Stand up the relay server that exposes the user's endpoint callables.
        self.server = kaggle_evaluation.core.relay.define_server(endpoint_listeners)
        # The inference_server can have a client but it isn't typically necessary.
        self.client = None

    def serve(self):
        '''Start the server; block indefinitely when inside a competition rerun.'''
        self.server.start()
        running_in_rerun = os.getenv('KAGGLE_IS_COMPETITION_RERUN')
        if running_in_rerun is not None:
            self.server.wait_for_termination()  # This will block all other code