File size: 2,283 Bytes
4c56ddf
 
5488167
4c56ddf
5488167
4c56ddf
 
5488167
4c56ddf
 
5488167
4c56ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5488167
4c56ddf
5488167
4c56ddf
 
 
5488167
 
4c56ddf
 
 
 
 
5488167
 
 
 
 
 
96ec844
5488167
4c56ddf
5488167
 
 
4c56ddf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License
"""

import os
import click

@click.command()
@click.option(
    "--checkpoint_path",
    type=str,
    default="",
    help="Path to the checkpoint directory. Downloads automatically if empty.",
)
@click.option(
    "--server_name",
    type=str,
    default="127.0.0.1",
    help="The server name to use for the Gradio app.",
)
@click.option(
    "--port", type=int, default=None, help="The port to use for the Gradio app."
)
@click.option("--device_id", type=int, default=0, help="The CUDA device ID to use.")
@click.option(
    "--share",
    type=click.BOOL,
    default=False,
    help="Whether to create a public, shareable link for the Gradio app.",
)
@click.option(
    "--bf16",
    type=click.BOOL,
    default=True,
    help="Whether to use bfloat16 precision. Turn off if using MPS.",
)
@click.option(
    "--torch_compile", type=click.BOOL, default=False, help="Whether to use torch.compile."
)
@click.option(
    "--cpu_offload",
    type=click.BOOL,
    default=False,
    help="Whether to use CPU offloading (only load current stage's model to GPU)",
)
@click.option(
    "--overlapped_decode",
    type=click.BOOL,
    default=False,
    help="Whether to use overlapped decoding (run dcae and vocoder using sliding windows)",
)
def main(
    checkpoint_path,
    server_name,
    port,
    device_id,
    share,
    bf16,
    torch_compile,
    cpu_offload,
    overlapped_decode,
):
    """Launch the ACE-Step pipeline demo as a Gradio web app.

    Args:
        checkpoint_path: Checkpoint directory handed to ``ACEStepPipeline`` as
            ``checkpoint_dir``. Per the option help, an empty string means
            "download automatically".
        server_name: Host/interface the Gradio server binds to.
        port: Port for the Gradio server. ``None`` leaves the choice to
            Gradio's own default port selection.
        device_id: Value written to ``CUDA_VISIBLE_DEVICES`` for this process.
        share: Whether Gradio should create a public share link.
        bf16: Selects ``"bfloat16"`` when true, otherwise ``"float32"``, as the
            pipeline dtype.
        torch_compile: Forwarded to ``ACEStepPipeline(torch_compile=...)``.
        cpu_offload: Forwarded to ``ACEStepPipeline(cpu_offload=...)``.
        overlapped_decode: Forwarded to
            ``ACEStepPipeline(overlapped_decode=...)``.
    """
    # Overwrites any existing value for the whole process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    # Imports are deferred until after CUDA_VISIBLE_DEVICES is set, presumably
    # so the device restriction is in place before these modules (and whatever
    # they import) initialize CUDA.
    from acestep.ui.components import create_main_demo_ui
    from acestep.pipeline_ace_step import ACEStepPipeline
    from acestep.data_sampler import DataSampler

    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode,
    )
    data_sampler = DataSampler()

    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
        load_data_func=data_sampler.load_json,
    )
    # Forward the networking flags; previously --server_name, --port and
    # --share were parsed but silently ignored.
    demo.queue().launch(
        server_name=server_name,
        server_port=port,
        share=share,
        inbrowser=True,
    )


if __name__ == "__main__":
    main()