File size: 2,898 Bytes
6452bf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import typer
import torch
import subprocess
from pathlib import Path

from expert import UpstreamExpert

# Files that `validate` requires to exist in the current working directory.
SUBMISSION_FILES = ["expert.py", "model.pt"]
# Samples per second used to turn SECONDS into waveform lengths in `validate`.
SAMPLE_RATE = 16000
# Durations (in seconds) of the random test waveforms fed to the upstream.
# The lengths differ; `validate` compares the output time axis against the longest one.
SECONDS = [2, 1.8, 3.7]

# Typer application; every function decorated with @app.command() is a CLI subcommand.
app = typer.Typer()

def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used instead of ``assert`` statements so the specification checks still
    run when Python is started with ``-O`` (which strips ``assert``).
    """
    if not condition:
        raise AssertionError(message)


@app.command()
def validate():
    """Validate the submission files and the ``UpstreamExpert`` output format.

    Checks that every file in ``SUBMISSION_FILES`` exists, loads
    ``UpstreamExpert(ckpt="model.pt")``, runs it on random waveforms whose
    lengths are ``SECONDS`` at ``SAMPLE_RATE``, and for every task checks that
    the output is a non-empty list of equally shaped 3-D tensors whose time
    axis matches ``get_downsample_rates(task)`` to within 5 frames.

    Raises:
        ValueError: if a file listed in ``SUBMISSION_FILES`` is missing.
        AssertionError: if the upstream output violates the specification.
        Exception: any error raised by the upstream itself is re-raised after
            printing a pointer to the specification.
    """
    # Check that all the expected files exist.
    # NOTE(review): expert.py is also imported at module load (top of file), so a
    # missing expert.py fails with an ImportError before this check can run.
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
        wavs = [torch.rand(sample) for sample in samples]
        # Only the structure of the outputs is inspected, so skip building an autograd graph.
        with torch.no_grad():
            results = upstream(wavs)

        _require(isinstance(results, dict), "upstream(wavs) must return a dict")
        # "hidden_states" is the fallback for every task without its own key,
        # so it must always be present.
        _require("hidden_states" in results, 'the returned dict must contain the "hidden_states" key')
        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            hidden_states = results.get(task, results["hidden_states"])
            _require(isinstance(hidden_states, list), f"{task}: hidden states must be a list of tensors")
            # Guards the hidden_states[0] lookups below.
            _require(len(hidden_states) > 0, f"{task}: hidden states list must not be empty")

            for state in hidden_states:
                _require(isinstance(state, torch.Tensor), f"{task}: every hidden state must be a torch.Tensor")
                _require(state.dim() == 3, "(batch_size, max_sequence_length_of_batch, hidden_size)")
                _require(
                    state.shape == hidden_states[0].shape,
                    f"{task}: all hidden states must have the same shape",
                )

            downsample_rate = upstream.get_downsample_rates(task)
            _require(isinstance(downsample_rate, int), f"{task}: get_downsample_rates must return an int")
            # Avoids a ZeroDivisionError in the frame-count computation below.
            _require(downsample_rate > 0, f"{task}: get_downsample_rates must return a positive int")
            # Dim 1 is the batch's max sequence length, so compare against the longest input.
            expected_frames = round(max(samples) / downsample_rate)
            _require(abs(expected_frames - hidden_states[0].size(1)) < 5, "wrong downsample rate")

    except Exception:
        print("Please check the Upstream Specification on https://superbbenchmark.org/challenge-slt2022/upstream")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")


def _run_git_step(cmd):
    """Run one git command and exit the CLI with a non-zero code if it fails."""
    returncode = subprocess.call(cmd)
    if returncode != 0:
        shown = " ".join(cmd)
        typer.echo(f"Command failed (exit code {returncode}): {shown}", err=True)
        # A negative return code means the process was killed by a signal; report 1 then.
        raise typer.Exit(code=returncode if returncode > 0 else 1)


@app.command()
def upload(commit_message: str):
    """Commit and push the working tree, then print the submission instructions.

    Runs ``git pull origin main``, ``git add .``, ``git commit`` and ``git push``
    in order. It stops at the first failing step, so a failed pull or push is no
    longer reported as a successful upload.

    Args:
        commit_message: Appended to ``"Upload Upstream: "`` to form the commit message.

    Raises:
        typer.Exit: with a non-zero code when a git step fails.
    """
    _run_git_step("git pull origin main".split())
    _run_git_step(["git", "add", "."])
    # `git diff --cached --quiet` exits 0 when nothing is staged. Committing then would
    # fail with "nothing to commit", so skip it and still push any existing local commits.
    if subprocess.call(["git", "diff", "--cached", "--quiet"]) == 0:
        typer.echo("Nothing new to commit; pushing existing commits.")
    else:
        _run_git_step(["git", "commit", "-m", f"Upload Upstream: {commit_message} "])
    _run_git_step(["git", "push"])
    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")

def _git_output(cmd):
    """Run a git command and return its stripped stdout; exit the CLI with code 1 on failure."""
    result = subprocess.run(cmd, capture_output=True)
    if result.returncode != 0:
        shown = " ".join(cmd)
        typer.echo(f"Failed to run: {shown}", err=True)
        stderr = result.stderr.decode("utf-8", errors="replace").strip()
        if stderr:
            typer.echo(stderr, err=True)
        raise typer.Exit(code=1)
    return result.stdout.decode("utf-8").strip()


def _parse_remote_url(url):
    """Split a git remote URL into ``(organization, repository)``.

    Accepts HTTPS remotes (``https://huggingface.co/org/repo``) and scp-like SSH
    remotes (``git@hf.co:org/repo.git``); a trailing ``/`` or ``.git`` is dropped.

    Raises:
        ValueError: if the URL does not contain two non-empty path components.
    """
    path = url.strip().rstrip("/")
    if path.endswith(".git"):
        path = path[: -len(".git")]
    # Treat ":" as a separator too, so "git@host:org/repo" yields "org" rather than "git@host:org".
    parts = path.replace(":", "/").split("/")
    if len(parts) < 2 or not parts[-2] or not parts[-1]:
        raise ValueError(f"cannot extract organization/repository from {url!r}")
    return parts[-2], parts[-1]


@app.command()
def info():
    """Print the organization, repository name and commit hash needed for a submission.

    The organization and repository are the last two path components of the
    ``origin`` remote URL; the commit hash is the output of ``git rev-parse HEAD``.

    Raises:
        typer.Exit: with code 1 if git fails or the remote URL cannot be parsed.
    """
    url = _git_output(["git", "config", "--get", "remote.origin.url"])
    try:
        organization, repo = _parse_remote_url(url)
    except ValueError as err:
        typer.echo(str(err), err=True)
        raise typer.Exit(code=1)

    commit_hash = _git_output(["git", "rev-parse", "HEAD"])

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")

# Script entry point: hand control to Typer, which dispatches to the registered
# subcommands (validate / upload / info).
if __name__ == "__main__":
    app()