Spaces:
Running
Running
"""Gateway notebook for SVG Image Generation"""
import os
import tempfile
from collections.abc import Iterator
from pathlib import Path
from typing import Any

import pandas as pd
import polars as pl

import kaggle_evaluation.core.templates
from kaggle_evaluation.core.base_gateway import GatewayRuntimeError, GatewayRuntimeErrorType, IS_RERUN
from kaggle_evaluation.svg_constraints import SVGConstraints
class SVGGateway(kaggle_evaluation.core.templates.Gateway):
    """Gateway that requests one SVG per test description and collects the results.

    Test cases are read from ``test.csv``. Each description is passed to
    ``self.predict``; the returned SVG is type-checked, validated with
    ``SVGConstraints`` and finally written to a submission CSV.
    """

    def __init__(self, data_path: str | Path | None = None):
        """Configure the gateway.

        Args:
            data_path: Directory containing ``test.csv``. When omitted (or
                otherwise falsy) the directory of this source file is used.
        """
        super().__init__(target_column_name='svg')
        self.set_response_timeout_seconds(60 * 5)  # five minutes
        self.row_id_column_name = 'id'
        self.data_path: Path = Path(data_path) if data_path else Path(__file__).parent
        self.constraints: SVGConstraints = SVGConstraints()

    def generate_data_batches(self) -> Iterator[tuple[Any, Any]]:
        """Yield one ``(id, description)`` pair per distinct ``id`` in ``test.csv``.

        The values are taken positionally from the first row of each group:
        column 0 is used as the id and column 1 as the description.
        """
        test = pl.read_csv(self.data_path / 'test.csv')
        # Without maintain_order polars does not guarantee the order of the
        # groups; keep them in file order so runs are reproducible.
        for _, group in test.group_by('id', maintain_order=True):
            yield group.item(0, 0), group.item(0, 1)  # id, description

    def get_all_predictions(self) -> tuple[list[str], list[Any]]:
        """Run ``self.predict`` on every test description.

        Returns:
            ``(predictions, row_ids)``: two parallel lists, in the argument
            order expected by ``write_submission``.

        Raises:
            ValueError: If a prediction is not a ``str``.
            GatewayRuntimeError: If a prediction fails ``validate``.
        """
        row_ids: list[Any] = []
        predictions: list[str] = []
        for row_id, description in self.generate_data_batches():
            svg = self.predict(description)
            if not isinstance(svg, str):
                raise ValueError("Predicted SVG must have `str` type.")
            self.validate(svg)
            row_ids.append(row_id)
            predictions.append(svg)
        return predictions, row_ids

    def validate(self, svg: str) -> None:
        """Check ``svg`` against the configured ``SVGConstraints``.

        Raises:
            GatewayRuntimeError: With type ``INVALID_SUBMISSION`` when
                ``validate_svg`` raises ``ValueError``; the original error is
                attached as the cause.
        """
        try:
            self.constraints.validate_svg(svg)
        except ValueError as err:
            msg = f'SVG failed validation: {err}'
            raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, msg) from err

    def write_submission(self, predictions: list, row_ids: list) -> Path:
        """Write the predictions to a CSV file and return its path.

        Args:
            predictions: Predicted SVG strings.
            row_ids: Row ids, parallel to ``predictions``.

        Returns:
            Path of the written CSV: ``/kaggle/working/submission.csv`` when
            ``IS_RERUN`` is true, otherwise a freshly created temporary file.
        """
        submission = pl.DataFrame(
            data={
                self.row_id_column_name: row_ids,
                self.target_column_name: predictions,
            }
        )

        submission_path = Path('/kaggle/working/submission.csv')
        # NOTE(review): IS_RERUN presumably marks the official scoring rerun - confirm.
        if not IS_RERUN:
            # The temp file is created only to reserve a unique name; it is
            # closed (and kept, delete=False) before the CSV is written.
            with tempfile.NamedTemporaryFile(
                prefix='kaggle-evaluation-submission-',
                suffix='.csv',
                delete=False,
                mode='w+',
            ) as f:
                submission_path = Path(f.name)

        submission.write_csv(submission_path)
        return submission_path