|
import typer |
|
import torch |
|
import subprocess |
|
from pathlib import Path |
|
|
|
from expert import UpstreamExpert |
|
|
|
# Files that must exist in the current directory for a submission to be valid.
SUBMISSION_FILES = [
    "README.md",
    "expert.py",
    "model.pt",
]

# Samples per second used when synthesizing dummy waveforms (length = SAMPLE_RATE * seconds).
SAMPLE_RATE = 16_000

# Durations, in seconds, of the dummy waveforms fed to the upstream as one batch;
# the lengths differ, presumably to exercise variable-length batching.
SECONDS = [2, 1.8, 3.7]
|
|
|
# Typer CLI application; each function decorated with @app.command() below
# (validate, submit) is exposed as a subcommand.
app = typer.Typer()
|
|
|
def _check_hidden_states(task, hidden_states):
    """Validate one task's representation against the Upstream Specification.

    Args:
        task: Task name, used only to label error messages.
        hidden_states: Value taken from the upstream output for ``task``.

    Raises:
        TypeError: If ``hidden_states`` is not a list, or an element is not a
            ``torch.Tensor``.
        ValueError: If the list is empty, a tensor is not 3-D
            ``(batch_size, max_sequence_length_of_batch, hidden_size)``, or the
            tensors do not all share the same shape.
    """
    if not isinstance(hidden_states, list):
        raise TypeError(
            f"[{task}] hidden states must be a list of torch.Tensor, "
            f"got {type(hidden_states).__name__}"
        )
    if not hidden_states:
        # An empty list would otherwise pass every per-element check vacuously.
        raise ValueError(f"[{task}] hidden states must be a non-empty list")

    first_shape = None
    for index, state in enumerate(hidden_states):
        if not isinstance(state, torch.Tensor):
            raise TypeError(
                f"[{task}] hidden_states[{index}] must be a torch.Tensor, "
                f"got {type(state).__name__}"
            )
        if state.dim() != 3:
            raise ValueError(
                f"[{task}] hidden_states[{index}] must be 3-D "
                "(batch_size, max_sequence_length_of_batch, hidden_size), "
                f"got shape {tuple(state.shape)}"
            )
        if first_shape is None:
            first_shape = state.shape
        elif state.shape != first_shape:
            raise ValueError(
                f"[{task}] all hidden states must share one shape; "
                f"hidden_states[{index}] is {tuple(state.shape)} but "
                f"hidden_states[0] is {tuple(first_shape)}"
            )


@app.command()
def validate():
    """Validate the submission files and the upstream's outputs.

    Checks that every file in SUBMISSION_FILES exists in the current directory,
    then loads UpstreamExpert from model.pt, runs it on random waveforms and
    checks the per-task hidden states and downsample rates.
    """
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        # One dummy waveform per entry in SECONDS, SAMPLE_RATE * seconds samples long.
        wavs = [torch.rand(round(SAMPLE_RATE * sec)) for sec in SECONDS]
        results = upstream(wavs)

        if not isinstance(results, dict):
            raise TypeError(f"The upstream must return a dict, got {type(results).__name__}")

        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            # A task without its own key falls back to the generic "hidden_states" entry.
            if task in results:
                hidden_states = results[task]
            elif "hidden_states" in results:
                hidden_states = results["hidden_states"]
            else:
                raise ValueError(
                    f'The upstream output has no "{task}" key and no default "hidden_states" key'
                )
            _check_hidden_states(task, hidden_states)

        for task in tasks:
            downsample_rate = upstream.get_downsample_rates(task)
            if not isinstance(downsample_rate, int):
                raise TypeError(
                    f"get_downsample_rates({task!r}) must return an int, "
                    f"got {type(downsample_rate).__name__}"
                )
            print(f"The upstream's representation for {task}"
                  f" has the downsample rate of {downsample_rate}.")
    except Exception:
        # Point the participant at the spec, then surface the original error.
        # KeyboardInterrupt/SystemExit are not Exception subclasses and pass through.
        print("Please check the Upstream Specification on https://superbbenchmark.org/challenge")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can make a submission.")
|
|
|
|
|
@app.command()
def submit(submission_name: str):
    """Commit the current directory and push it as submission SUBMISSION_NAME.

    Runs `git pull origin main`, `git add .`, `git commit` and `git push` in
    order, stopping at the first command that exits with a non-zero status.
    """
    commands = [
        ["git", "pull", "origin", "main"],
        ["git", "add", "."],
        ["git", "commit", "-m", f"Submission: {submission_name} "],
        ["git", "push"],
    ]
    for command in commands:
        returncode = subprocess.call(command)
        if returncode != 0:
            # Stop at the first failure so a failed pull/commit/push is never
            # reported as a successful submission. Note that `git commit` also
            # exits non-zero when there is nothing to commit.
            typer.echo(
                f"Command `{' '.join(command)}` failed with exit code {returncode}; "
                "submission aborted.",
                err=True,
            )
            raise typer.Exit(code=returncode)

    typer.echo("Submission successful!")
|
|
|
|
|
# When executed as a script, hand control to the Typer app, which dispatches
# to the `validate` or `submit` subcommand given on the command line.
if __name__ == "__main__":
    app()
|
|