# Jinglong Xiong
# first commit
# 15369ca
"""Gateway notebook for SVG Image Generation"""
import os
import tempfile
from pathlib import Path
from typing import Any
import pandas as pd
import polars as pl
from kaggle_evaluation.core.base_gateway import GatewayRuntimeError, GatewayRuntimeErrorType, IS_RERUN
import kaggle_evaluation.core.templates
from kaggle_evaluation.svg_constraints import SVGConstraints
class SVGGateway(kaggle_evaluation.core.templates.Gateway):
    """Evaluation gateway that requests one SVG per test description.

    Reads ``test.csv`` from ``data_path``, obtains an SVG for each description
    through the inherited ``predict`` method, validates every SVG against
    ``SVGConstraints`` and writes an ``id``/``svg`` submission CSV.
    """

    def __init__(self, data_path: str | Path | None = None):
        """Configure the gateway.

        Args:
            data_path: Directory containing ``test.csv``. When falsy (``None``
                or ``''``) the directory of this module is used.
        """
        super().__init__(target_column_name='svg')
        # 5-minute response timeout; presumably applied to each ``predict``
        # round-trip by the base Gateway — confirm there.
        self.set_response_timeout_seconds(60 * 5)
        self.row_id_column_name = 'id'
        self.data_path: Path = Path(data_path) if data_path else Path(__file__).parent
        self.constraints: SVGConstraints = SVGConstraints()

    def generate_data_batches(self):
        """Yield ``(row_id, description)`` pairs, one per distinct id in ``test.csv``.

        Ids are yielded in order of first appearance in the file. When an id
        occurs on several rows, only its first row is used. The description is
        read from the second CSV column.
        """
        test = pl.read_csv(self.data_path / 'test.csv')
        id_column = self.row_id_column_name
        # NOTE(review): assumes the description is the second CSV column — confirm schema.
        description_column = test.columns[1]
        # ``group_by`` does not preserve group order by default, which made the
        # prediction/submission order vary between runs. ``unique`` with
        # ``keep='first'`` + ``maintain_order=True`` keeps the first row of each
        # id, in file order, without materialising one DataFrame per group.
        first_rows = test.unique(subset=id_column, keep='first', maintain_order=True)
        yield from zip(first_rows[id_column], first_rows[description_column])

    def get_all_predictions(self):
        """Request, type-check and validate one SVG for every test row.

        Returns:
            ``(predictions, row_ids)``: parallel lists in the order produced by
            ``generate_data_batches``.

        Raises:
            ValueError: if ``predict`` returns something other than a ``str``.
            GatewayRuntimeError: if an SVG fails ``SVGConstraints`` validation.
        """
        row_ids, predictions = [], []
        for row_id, description in self.generate_data_batches():
            svg = self.predict(description)
            # NOTE(review): ``validate`` reports invalid SVGs as
            # GatewayRuntimeError(INVALID_SUBMISSION); this stays a ValueError
            # so the exception type seen by callers is unchanged.
            if not isinstance(svg, str):
                raise ValueError("Predicted SVG must have `str` type.")
            self.validate(svg)
            row_ids.append(row_id)
            predictions.append(svg)
        return predictions, row_ids

    def validate(self, svg: str):
        """Check ``svg`` against ``self.constraints``.

        Raises:
            GatewayRuntimeError: of type ``INVALID_SUBMISSION`` when
                ``SVGConstraints.validate_svg`` raises ``ValueError``.
        """
        try:
            self.constraints.validate_svg(svg)
        except ValueError as err:
            msg = f'SVG failed validation: {err}'
            # Chain the original error so its traceback is kept for debugging.
            raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, msg) from err

    def write_submission(self, predictions: list, row_ids: list) -> Path:
        """Write the submission CSV and return its path.

        Columns are ``self.row_id_column_name`` (``row_ids``) and
        ``self.target_column_name`` (``predictions``). When ``IS_RERUN`` is set
        the file is ``/kaggle/working/submission.csv``. Otherwise it is a new,
        persistent temporary file named ``kaggle-evaluation-submission-*.csv``.
        """
        submission = pl.DataFrame(
            data={
                self.row_id_column_name: row_ids,
                self.target_column_name: predictions,
            }
        )
        if IS_RERUN:
            submission_path = Path('/kaggle/working/submission.csv')
        else:
            # delete=False: the file must outlive this handle. The handle is
            # only used to reserve a unique name; polars writes the data after
            # the handle has been closed.
            with tempfile.NamedTemporaryFile(
                prefix='kaggle-evaluation-submission-', suffix='.csv', delete=False, mode='w+'
            ) as f:
                submission_path = Path(f.name)
        submission.write_csv(submission_path)
        return submission_path