|
import typer |
|
import torch |
|
import subprocess |
|
from pathlib import Path |
|
|
|
from expert import UpstreamExpert |
|
|
|
# Files that `validate` requires to exist (relative to the current working
# directory) before the upstream model is loaded.
SUBMISSION_FILES = [
    "expert.py",
    "model.pt",
]

# Samples per second; multiplied by each entry of SECONDS to obtain the
# length of the random test waveforms built in `validate`.
SAMPLE_RATE = 16000

# Durations, in seconds, of the random test waveforms. They deliberately
# differ so the batch handed to the upstream has unequal lengths.
SECONDS = [
    2,
    1.8,
    3.7,
]

# Typer application object; the commands below register themselves on it
# through the `@app.command()` decorator.
app = typer.Typer()
|
|
|
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used in place of bare ``assert`` statements so that the submission checks
    still run when Python is started with ``-O`` (which strips ``assert``),
    while keeping the same exception type.
    """
    if not condition:
        raise AssertionError(message)


@app.command()
def validate():
    """Validate the submission files and the upstream model's outputs.

    Checks that every file in SUBMISSION_FILES exists, builds
    ``UpstreamExpert(ckpt="model.pt")``, runs it on a batch of random
    waveforms of different lengths and verifies, for every task, that the
    returned features and the reported downsample rate are well-formed.

    Raises:
        ValueError: if a required submission file is missing.
        AssertionError: if the upstream output violates one of the checks.
        Exception: any error raised by the upstream itself is re-raised after
            a pointer to the specification has been printed.
    """
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
        wavs = [torch.rand(sample) for sample in samples]
        results = upstream(wavs)

        _require(isinstance(results, dict), "the upstream must return a dict")

        # Loop-invariant: length of the longest waveform in the batch.
        longest = max(samples)

        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            # Fall back to "hidden_states" lazily: it is only needed when the
            # task-specific key is absent, so an upstream that provides every
            # task key must not fail for lacking the default key.
            if task in results:
                hidden_states = results[task]
            else:
                _require(
                    "hidden_states" in results,
                    f"results has neither a '{task}' key nor the default 'hidden_states' key",
                )
                hidden_states = results["hidden_states"]

            _require(isinstance(hidden_states, list), f"features for task '{task}' must be a list")
            # Guard the hidden_states[0] accesses below against an empty list.
            _require(len(hidden_states) > 0, f"features for task '{task}' must not be an empty list")

            reference_shape = hidden_states[0].shape if isinstance(hidden_states[0], torch.Tensor) else None
            for state in hidden_states:
                _require(isinstance(state, torch.Tensor), f"every feature for task '{task}' must be a torch.Tensor")
                _require(state.dim() == 3, "(batch_size, max_sequence_length_of_batch, hidden_size)")
                _require(
                    state.shape == reference_shape,
                    f"all features for task '{task}' must share the same shape",
                )

            downsample_rate = upstream.get_downsample_rates(task)
            _require(isinstance(downsample_rate, int), f"downsample rate for task '{task}' must be an int")
            # A non-positive rate would otherwise surface as ZeroDivisionError
            # or a meaningless comparison in the check below.
            _require(downsample_rate > 0, f"downsample rate for task '{task}' must be positive")
            _require(
                abs(round(longest / downsample_rate) - hidden_states[0].size(1)) < 5,
                "wrong downsample rate",
            )

    except Exception:
        # `Exception` rather than a bare `except:` so that Ctrl-C and
        # SystemExit pass through without the misleading hint.
        print("Please check the Upstream Specification on https://superbbenchmark.org/challenge-slt2022/upstream")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")
|
|
|
|
|
def _git_step_or_exit(args, failure_message):
    """Run ``git <args>``; on a non-zero exit status report it and stop the CLI.

    Args:
        args: list of arguments passed to ``git``.
        failure_message: text written to stderr when the command fails.

    Raises:
        typer.Exit: with exit code 1 if git returned a non-zero status.
    """
    if subprocess.call(["git", *args]) != 0:
        typer.echo(failure_message, err=True)
        raise typer.Exit(code=1)


@app.command()
def upload(commit_message: str):
    """Commit everything in the working tree and push it to the remote.

    Runs ``git pull origin main``, ``git add .``, ``git commit`` and
    ``git push`` in that order, then prints the submission instructions.

    Args:
        commit_message: text appended to "Upload Upstream: " to form the
            commit message.

    Raises:
        typer.Exit: with exit code 1 if pulling, staging or pushing fails, so
            that "Upload successful!" is never printed after a failed push.
    """
    _git_step_or_exit(
        ["pull", "origin", "main"],
        "`git pull origin main` failed; resolve the problem and run upload again.",
    )
    _git_step_or_exit(
        ["add", "."],
        "`git add .` failed; nothing was committed or pushed.",
    )

    # `git commit` also exits non-zero when there is nothing new to commit,
    # which must not prevent pushing commits that already exist locally, so
    # this step only warns instead of aborting.
    commit_status = subprocess.call(["git", "commit", "-m", f"Upload Upstream: {commit_message}"])
    if commit_status != 0:
        typer.echo(
            "Warning: `git commit` did not create a new commit "
            "(possibly nothing to commit); pushing existing commits.",
            err=True,
        )

    _git_step_or_exit(
        ["push"],
        "`git push` failed; the files were NOT uploaded.",
    )

    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")
|
|
|
def _split_remote_url(url):
    """Return ``(organization, repository)`` parsed from a git remote URL.

    The last two path components are used. A trailing ``/`` and a trailing
    ``.git`` are ignored, and scp-like remotes (``user@host:org/repo``) are
    supported by discarding everything up to the first ``:``.

    Raises:
        ValueError: if the URL does not contain two non-empty trailing path
            components.
    """
    path = url.strip().rstrip("/")
    if path.endswith(".git"):
        path = path[: -len(".git")]
    # scp-like syntax has no scheme; the part before ':' is the host, not
    # part of the organization name.
    if "://" not in path and ":" in path:
        path = path.split(":", 1)[1]

    parts = path.split("/")
    if len(parts) < 2 or not parts[-1] or not parts[-2]:
        raise ValueError(f"Cannot determine organization and repository from remote URL: {url!r}")
    return parts[-2], parts[-1]


@app.command()
def info():
    """Print the organization, repository name and commit hash to submit.

    The organization and repository are derived from the URL of the
    ``origin`` remote; the commit hash is the current ``HEAD``.

    Raises:
        typer.Exit: with exit code 1 if the remote URL or the commit hash
            cannot be obtained from git, or the URL cannot be parsed.
    """
    result = subprocess.run(["git", "config", "--get", "remote.origin.url"], capture_output=True)
    url = result.stdout.decode("utf-8").strip()
    # Without this guard an empty URL would crash with a bare IndexError.
    if result.returncode != 0 or not url:
        typer.echo(
            "Could not read remote.origin.url; run this command inside your submission repository.",
            err=True,
        )
        raise typer.Exit(code=1)

    try:
        organization, repo = _split_remote_url(url)
    except ValueError as error:
        typer.echo(str(error), err=True)
        raise typer.Exit(code=1)

    result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True)
    commit_hash = result.stdout.decode("utf-8").strip()
    # `git rev-parse HEAD` fails when the repository has no commits yet.
    if result.returncode != 0 or not commit_hash:
        typer.echo("Could not determine the current commit hash (does the repository have a commit?).", err=True)
        raise typer.Exit(code=1)

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")
|
|
|
# Script entry point: when the file is executed directly (not imported),
# hand control to the Typer application defined above.
if __name__ == "__main__":
    app()
|
|