File size: 2,283 Bytes
4c56ddf 5488167 4c56ddf 5488167 4c56ddf 5488167 4c56ddf 5488167 4c56ddf 5488167 4c56ddf 5488167 4c56ddf 5488167 4c56ddf 5488167 96ec844 5488167 4c56ddf 5488167 4c56ddf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
"""
ACE-Step: A Step Towards Music Generation Foundation Model
https://github.com/ace-step/ACE-Step
Apache 2.0 License
"""
import os
import click
@click.command()
@click.option(
    "--checkpoint_path",
    type=str,
    default="",
    help="Path to the checkpoint directory. Downloads automatically if empty.",
)
@click.option(
    "--server_name",
    type=str,
    default="127.0.0.1",
    help="The server name to use for the Gradio app.",
)
@click.option(
    "--port", type=int, default=None, help="The port to use for the Gradio app."
)
@click.option("--device_id", type=int, default=0, help="The CUDA device ID to use.")
@click.option(
    "--share",
    type=click.BOOL,
    default=False,
    help="Whether to create a public, shareable link for the Gradio app.",
)
@click.option(
    "--bf16",
    type=click.BOOL,
    default=True,
    help="Whether to use bfloat16 precision. Turn off if using MPS.",
)
@click.option(
    "--torch_compile", type=click.BOOL, default=False, help="Whether to use torch.compile."
)
@click.option(
    "--cpu_offload",
    type=click.BOOL,
    default=False,
    help="Whether to use CPU offloading (only load current stage's model to GPU)",
)
@click.option(
    "--overlapped_decode",
    type=click.BOOL,
    default=False,
    help="Whether to use overlapped decoding (run dcae and vocoder using sliding windows)",
)
def main(
    checkpoint_path,
    server_name,
    port,
    device_id,
    share,
    bf16,
    torch_compile,
    cpu_offload,
    overlapped_decode,
):
    """
    Main function to launch the ACE Step pipeline demo.

    Builds an ``ACEStepPipeline`` and a ``DataSampler``, wires them into the
    demo UI returned by ``create_main_demo_ui``, and serves it with Gradio.

    Args:
        checkpoint_path: Passed to the pipeline as ``checkpoint_dir``. Per the
            option help, an empty value triggers an automatic download.
        server_name: Host/interface handed to Gradio's ``launch`` as
            ``server_name``.
        port: Handed to Gradio's ``launch`` as ``server_port``. ``None``
            defers to Gradio's own default port selection.
        device_id: Written to ``CUDA_VISIBLE_DEVICES`` before the model
            modules are imported.
        share: Handed to Gradio's ``launch`` as ``share``.
        bf16: Selects ``"bfloat16"`` when true, otherwise ``"float32"``, as the
            pipeline dtype.
        torch_compile: Forwarded to ``ACEStepPipeline``.
        cpu_offload: Forwarded to ``ACEStepPipeline``.
        overlapped_decode: Forwarded to ``ACEStepPipeline``.
    """
    # Set the device mask before importing the model modules below, so it is
    # already in effect when they (presumably) initialize CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    # Deliberately deferred imports: they must run after the env var above.
    from acestep.ui.components import create_main_demo_ui
    from acestep.pipeline_ace_step import ACEStepPipeline
    from acestep.data_sampler import DataSampler

    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode,
    )
    data_sampler = DataSampler()
    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
        load_data_func=data_sampler.load_json,
    )
    # Forward the networking options parsed above so that --server_name,
    # --port and --share actually take effect.
    demo.queue().launch(
        server_name=server_name,
        server_port=port,
        share=share,
        inbrowser=True,
    )
if __name__ == "__main__":
    # Script entry point: main is a click command, so invoking it with no
    # arguments lets click parse the process command line.
    main()
|