Spaces:
Paused
Paused
import inspect
import pathlib
from types import ModuleType

from kaggle_evaluation.core import relay, templates
from kaggle_evaluation.svg_gateway import SVGGateway
def test(model_cls: type, data_path: str | pathlib.Path | None = None) -> None:
    '''Tests this competition's inference loop over the given Model class.

    The provided Model class should have a `predict` function which accepts input(s)
    and returns output(s) with the shapes and types required by this competition.
    This function performs best-effort validation of this by running an inference
    loop with a dummy test set over Model.predict.

    By default the test set is taken from the `kaggle_evaluation` directory, but you
    may override to another directory with the same test file structure via the
    `data_path` arg.

    Args:
        model_cls: The user's Model class itself (not an instance). It is
            instantiated here and must expose a bound `predict` method.
        data_path: Optional directory holding a test set with the same file
            structure as the default data; passed through to the gateway.

    Raises:
        ValueError: If the instantiated model has no `predict` method.
    '''
    print('Creating Model instance...')
    model = model_cls()
    # Fail fast with a clear message rather than a confusing RPC failure later.
    if not hasattr(model, 'predict') or not inspect.ismethod(model.predict):
        # Fixed: this was an f-string with no placeholders (`f'...'`).
        msg = 'Model does not have method predict.'
        raise ValueError(msg)
    print('Running inference tests...')
    server = relay.define_server(model.predict)
    server.start()
    try:
        gateway = SVGGateway(data_path)
        submission_path = gateway.run()
        # Dropped redundant str() call — f-string formatting already stringifies.
        print(f'Wrote test submission file to "{submission_path}".')
    except Exception as err:
        # Deliberately suppress exception chaining so users see only the
        # gateway's error, not the framework's internal context.
        raise err from None
    finally:
        # Always stop the inference server, whether or not the gateway run succeeded.
        server.stop(0)
    print('Success!')
def _run_gateway() -> None:
    '''Internal function for running the Gateway during a Kaggle scoring session.

    Starts a scoring session which assumes existence of an Inference Server to
    return inferences over the test set.
    '''
    # Construct the gateway with default data and drive the scoring run;
    # the returned submission path is not needed here.
    SVGGateway().run()
def _run_inference_server(module: ModuleType) -> None:
    '''Internal function for running the Inference Server during a Kaggle scoring session.

    Takes the user's submitted, imported module and sets up the inference server
    exposing their required method(s).
    '''
    # Instantiate the user's Model and serve its bound `predict` method.
    user_model = module.Model()
    inference_server = templates.InferenceServer(user_model.predict)
    inference_server.serve()