leo19941227's picture
Submission: update-template
fa98f1c
raw
history blame
2.01 kB
import typer
import torch
import subprocess
from pathlib import Path
from expert import UpstreamExpert
# The Typer application; the functions below attach themselves to it as
# sub-commands through the ``@app.command()`` decorator.
app = typer.Typer()

# Files that ``validate`` requires to exist in the current working directory.
SUBMISSION_FILES = [
    "README.md",
    "expert.py",
    "model.pt",
]

# Settings for the random dummy waveforms fed to the upstream in ``validate``:
# samples per second, and one duration (in seconds) per waveform. The
# durations differ, so the upstream receives a batch of unequal-length inputs.
SAMPLE_RATE = 16000
SECONDS = [2, 1.8, 3.7]
def _require(condition, message=""):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used instead of the ``assert`` statement because ``assert`` is removed
    when Python runs with ``-O``, which would let validation pass vacuously.
    """
    if not condition:
        raise AssertionError(message)


def _check_hidden_states(results, task):
    """Validate the hidden states the upstream returned for *task*.

    A task-specific entry ``results[task]`` takes precedence; otherwise the
    shared ``results["hidden_states"]`` entry is used. The selected value
    must be a non-empty list of 3-D tensors that all share one shape.

    Raises:
        AssertionError: if neither key is present or a check fails.
    """
    if task in results:
        hidden_states = results[task]
    else:
        # Fall back to the shared entry itself (not the literal string
        # "hidden_states"), so upstreams without per-task outputs are valid.
        _require(
            "hidden_states" in results,
            f"The upstream output has neither a '{task}' key nor a 'hidden_states' key",
        )
        hidden_states = results["hidden_states"]

    _require(isinstance(hidden_states, list), f"Hidden states for {task} must be a list")
    _require(len(hidden_states) > 0, f"Hidden states for {task} must not be empty")
    for state in hidden_states:
        _require(isinstance(state, torch.Tensor), f"Hidden states for {task} must be torch.Tensor")
        _require(state.dim() == 3, "(batch_size, max_sequence_length_of_batch, hidden_size)")
        _require(
            state.shape == hidden_states[0].shape,
            f"All hidden states for {task} must have the same shape",
        )


@app.command()
def validate():
    """Sanity-check the submission in the current directory.

    Checks that every file in ``SUBMISSION_FILES`` exists. It then builds
    ``UpstreamExpert`` from ``model.pt`` and runs it on a batch of random
    waveforms of different lengths. Finally it checks, for every task, the
    returned hidden states and the value of ``get_downsample_rates``.

    Raises:
        ValueError: if a required submission file is missing.
        AssertionError: if the upstream output violates a check.
        Exception: anything raised by the upstream is re-raised after the
            user is pointed to the Upstream Specification.
    """
    # NOTE(review): ``UpstreamExpert`` is imported at module level
    # (``from expert import UpstreamExpert``), so a missing expert.py fails
    # at import time before this check can report it.
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(
                f"File {file} not found! Please include {file} in your submission"
            )

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        wavs = [torch.rand(round(SAMPLE_RATE * sec)) for sec in SECONDS]
        results = upstream(wavs)
        _require(isinstance(results, dict), "The upstream must return a dict")

        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            _check_hidden_states(results, task)

        for task in tasks:
            downsample_rate = upstream.get_downsample_rates(task)
            _require(
                isinstance(downsample_rate, int),
                f"get_downsample_rates({task!r}) must return an int",
            )
            print(f"The upstream's representation for {task}"
                  f" has the downsample rate of {downsample_rate}.")
    except Exception:
        # Add a pointer to the spec, then re-raise with the original traceback.
        print("Please check the Upstream Specification on https://superbbenchmark.org/challenge")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can make a submission.")
def _run_git(args):
    """Run ``git`` with *args*, aborting the CLI if the command fails.

    On a non-zero return code, report the failed command on stderr and raise
    ``typer.Exit`` so the process exits with a failure status.
    """
    command = ["git", *args]
    returncode = subprocess.call(command)
    if returncode != 0:
        typer.echo(
            f"Command failed with exit code {returncode}: {' '.join(command)}",
            err=True,
        )
        # A negative return code means git was killed by a signal; still
        # report a plain failure status.
        raise typer.Exit(code=returncode if returncode > 0 else 1)


@app.command()
def submit(submission_name: str):
    """Commit the working directory as a submission and push it.

    Runs ``git pull origin main``, ``git add .``, a commit whose message
    contains *submission_name*, and ``git push``. It stops at the first git
    command that fails instead of reporting success regardless.
    Note that ``git commit`` also fails when there is nothing new to commit.
    """
    _run_git(["pull", "origin", "main"])
    _run_git(["add", "."])
    _run_git(["commit", "-m", f"Submission: {submission_name} "])
    _run_git(["push"])
    typer.echo("Submission successful!")
if __name__ == "__main__":
    # Script entry point: hand control to the Typer app, which parses the
    # command line and dispatches to the matching command (validate/submit).
    app()