|
import datetime |
|
import re |
|
import subprocess |
|
from pathlib import Path |
|
|
|
import pandas as pd |
|
import typer |
|
from datasets import get_dataset_config_names, load_dataset |
|
|
|
# Expected ``DataFrame.shape`` -- (rows, columns) -- of each task's
# ``predictions.csv``, keyed by task name. Every file has the same two
# columns, so only the row count differs per task.
CSV_SCHEMA = {
    task_name: (row_count, 2)
    for task_name, row_count in (
        ("banking_77", 5000),
        ("overruling", 2350),
        ("semiconductor_org_types", 449),
        ("ade_corpus_v2", 5000),
        ("twitter_complaints", 3399),
        ("neurips_impact_statement_risks", 150),
        ("systematic_review_inclusion", 2244),
        ("terms_of_service", 5000),
        ("tai_safety_research", 1639),
        ("one_stop_english", 518),
        ("tweet_eval_hate", 2966),
    )
}
|
|
|
# Typer application object; functions decorated with ``@app.command()``
# are registered on it as CLI sub-commands.
app = typer.Typer()
|
|
|
|
|
def _update_submission_name(submission_name: str):
    """Rewrite the ``submission_name:`` entry of ``README.md`` in place.

    Every line of ``README.md`` (in the current working directory) that
    starts with ``submission_name:`` is replaced by
    ``submission_name: <submission_name>``; all other lines are written
    back unchanged.

    Args:
        submission_name: Name to store. It is inserted literally, so
            backslashes and other regex-special characters are safe.
    """
    # README.md may contain non-ASCII text (and so may the name), so do not
    # depend on the platform's default encoding.
    with open("README.md", "r", encoding="utf-8") as f:
        lines = f.readlines()

    updated_lines = []
    for line in lines:
        if line.startswith("submission_name:"):
            # Build the line directly instead of going through ``re.sub``:
            # a regex replacement template would interpret backslashes in
            # the user-supplied name (e.g. ``\1`` or ``\n``). This also
            # handles an entry whose current value is empty.
            line_ending = "\n" if line.endswith("\n") else ""
            line = f"submission_name: {submission_name}{line_ending}"
        updated_lines.append(line)

    with open("README.md", "w", encoding="utf-8") as f:
        f.write("".join(updated_lines))
|
|
|
|
|
@app.command()
def validate():
    """Check that the prediction files are complete, well-formed and loadable."""
    tasks = get_dataset_config_names("ought/raft")

    # Each task must be matched by a ``predictions.csv`` whose parent
    # directory is named after the task, and no other directory may
    # contain one.
    prediction_files = list(Path("data").rglob("predictions.csv"))
    expected_tasks = set(tasks)
    found_tasks = {f.parent.name for f in prediction_files}
    missing_tasks = sorted(expected_tasks - found_tasks)
    unexpected_dirs = sorted(found_tasks - expected_tasks)
    if missing_tasks or unexpected_dirs:
        # Name the offending tasks: the file counts alone can be equal even
        # when the directory names do not match the task names.
        raise ValueError(
            "Prediction files do not match the benchmark tasks! "
            f"Expected {len(tasks)} files, but got {len(prediction_files)}. "
            f"Missing tasks: {missing_tasks}. Unexpected directories: {unexpected_dirs}."
        )

    # Check every file before raising so that all offending files are
    # reported together.
    shape_errors = []
    column_errors = []
    for prediction_file in prediction_files:
        df = pd.read_csv(prediction_file)
        if df.shape != CSV_SCHEMA[prediction_file.parent.name]:
            shape_errors.append(prediction_file)
        if sorted(df.columns) != ["ID", "Label"]:
            column_errors.append(prediction_file)

    if shape_errors:
        raise ValueError(f"Incorrect CSV shapes in files: {shape_errors}")

    if column_errors:
        raise ValueError(f"Incorrect CSV columns in files: {column_errors}")

    # Try to load every task and collect the failures rather than stopping
    # at the first one; the broad ``except`` is deliberate because any
    # loading error makes the submission invalid.
    load_errors = []
    for task in tasks:
        try:
            load_dataset("../{{cookiecutter.repo_name}}", task)
        except Exception as e:
            load_errors.append(e)

    if load_errors:
        raise ValueError(f"Could not load predictions! Errors: {load_errors}")

    typer.echo("All submission files validated! β¨ π β¨")
    typer.echo("Now you can make a submission π€")
|
|
|
|
|
def _run_git(*args: str) -> None:
    """Run ``git`` with ``args``; abort the CLI if git exits non-zero."""
    command = ["git", *args]
    return_code = subprocess.call(command)
    if return_code != 0:
        typer.echo(f"Command failed with exit code {return_code}: {' '.join(command)}", err=True)
        raise typer.Exit(code=return_code)


@app.command()
def submit(submission_name: str = typer.Option(..., prompt="Please provide a name for your submission, e.g. GPT-4 π")):
    """Commit the predictions and README.md and push them as a new submission."""
    # Any failing git step aborts the command, so success is only reported
    # when the push actually went through. Note that ``git commit`` also
    # exits non-zero when there is nothing new to commit.
    _run_git("pull", "origin", "main")
    _update_submission_name(submission_name)
    # No shell is involved, so the glob is passed verbatim to git, which
    # expands it itself as a pathspec.
    _run_git("add", "data/*predictions.csv", "README.md")
    _run_git("commit", "-m", f"Submission: {submission_name} ")
    _run_git("push")

    today = datetime.date.today()
    # ``date.weekday()`` counts Monday as 0; shift it so that Sunday is 0.
    days_since_sunday = (today.weekday() + 1) % 7
    # Next Sunday strictly after today: on a Sunday this is a week ahead.
    sun = today + datetime.timedelta(7 - days_since_sunday)
    typer.echo("Submission successful! π π₯³ π")
    typer.echo(f"Your submission will be evaluated on {sun:%A %d %B %Y} β³")
|
|
|
|
|
# Run the Typer CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()
|
|