'''Template for the two classes hosts should customize for each competition.'''

import abc
import os
import pathlib
import sys
import time
import traceback
import warnings
from typing import Callable, Generator, Tuple

import polars as pl

import kaggle_evaluation.core.base_gateway
import kaggle_evaluation.core.relay

# Recorded at import time so submission startup duration can be measured against it.
_initial_import_time = time.time()
# Guards against emitting a slow-startup warning more than once.
_issued_startup_time_warning = False


class Gateway(kaggle_evaluation.core.base_gateway.BaseGateway, abc.ABC):
    '''
    Template to start with when writing a new gateway.
    In most cases, hosts should only need to write get_all_predictions.
    There are two main methods for sending data to the inference_server hosts should understand:
    - Small datasets: use `self.predict`. Competitors will receive the data passed to self.predict as
    Python objects in memory. This is just a wrapper for self.client.send(); you can write additional
    wrappers if necessary.
    - Large datasets: it's much faster to send data via self.share_files, which is equivalent to making
    files available via symlink. See base_gateway.BaseGateway.share_files for the full details.
    '''

    @abc.abstractmethod
    def generate_data_batches(self) -> Generator:
        '''
        Used by the default implementation of `get_all_predictions` so we can ensure
        `validate_prediction_batch` is run every time `predict` is called.

        This method must yield both the batch of data to be sent to `predict` and a
        series of row IDs to be sent to `validate_prediction_batch`.
        '''
        raise NotImplementedError

    def get_all_predictions(self):
        '''Run every data batch through `predict`, validating each prediction batch.

        Returns:
            Tuple of (list of prediction Series, list of row ID batches), one entry per batch.
        '''
        all_predictions = []
        all_row_ids = []
        for data_batch, row_ids in self.generate_data_batches():
            predictions = self.predict(*data_batch)
            # Normalize to a named Series so validation and submission writing see a consistent type.
            predictions = pl.Series(self.target_column_name, predictions)
            self.validate_prediction_batch(predictions, row_ids)
            all_predictions.append(predictions)
            all_row_ids.append(row_ids)
        return all_predictions, all_row_ids

    def predict(self, *args, **kwargs):
        '''
        self.predict will send all data in args and kwargs to the user container, and
        instruct the user container to generate a `predict` response.

        Returns the container's response; on failure, delegates to `handle_server_error`.
        '''
        try:
            return self.client.send('predict', *args, **kwargs)
        except Exception as e:
            self.handle_server_error(e, 'predict')

    def set_response_timeout_seconds(self, timeout_seconds: float):
        '''Set the per-response deadline applied after the very first response.'''
        # Also store timeout_seconds in an easy place for the competitor to access.
        self.timeout_seconds = timeout_seconds
        # Set a response deadline that will apply after the very first response.
        self.client.endpoint_deadline_seconds = timeout_seconds

    def run(self) -> pathlib.Path:
        '''Drive the full evaluation: collect predictions, write the submission, report errors.

        Returns:
            Path to the written submission file, or None if an error occurred.
        '''
        error = None
        submission_path = None
        try:
            predictions, row_ids = self.get_all_predictions()
            submission_path = self.write_submission(predictions, row_ids)
        except kaggle_evaluation.core.base_gateway.GatewayRuntimeError as gre:
            error = gre
        except Exception:
            # Get the full stack trace so the rerun result includes actionable detail.
            exc_type, exc_value, exc_traceback = sys.exc_info()
            error_str = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
            error = kaggle_evaluation.core.base_gateway.GatewayRuntimeError(
                kaggle_evaluation.core.base_gateway.GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, error_str
            )

        # Shut down the client connection and server before reporting results.
        self.client.close()
        if self.server:
            self.server.stop(0)

        if kaggle_evaluation.core.base_gateway.IS_RERUN:
            self.write_result(error)
        elif error:
            # For local testing
            raise error
        return submission_path


class InferenceServer(abc.ABC):
    '''
    Base class for competition participants to inherit from when writing their submission.
    In most cases, users should only need to implement a `predict` function or other endpoints
    to pass to this class's constructor, and hosts will provide a mock Gateway for testing.
    '''

    def __init__(self, endpoint_listeners: Tuple[Callable, ...]):
        self.server = kaggle_evaluation.core.relay.define_server(endpoint_listeners)
        self.client = None  # The inference_server can have a client but it isn't typically necessary.

    def serve(self):
        '''Start the server; block until termination when running as a competition rerun.'''
        self.server.start()
        if os.getenv('KAGGLE_IS_COMPETITION_RERUN') is not None:
            self.server.wait_for_termination()  # This will block all other code