"""
ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License
"""

import os

import click


@click.command()
@click.option(
    "--checkpoint_path",
    type=str,
    default="",
    help="Path to the checkpoint directory. Downloads automatically if empty.",
)
@click.option(
    "--server_name",
    type=str,
    default="127.0.0.1",
    help="The server name to use for the Gradio app.",
)
@click.option(
    "--port", type=int, default=None, help="The port to use for the Gradio app."
)
@click.option("--device_id", type=int, default=0, help="The CUDA device ID to use.")
@click.option(
    "--share",
    type=click.BOOL,
    default=False,
    help="Whether to create a public, shareable link for the Gradio app.",
)
@click.option(
    "--bf16",
    type=click.BOOL,
    default=True,
    help="Whether to use bfloat16 precision. Turn off if using MPS.",
)
@click.option(
    "--torch_compile",
    type=click.BOOL,
    default=False,
    help="Whether to use torch.compile.",
)
@click.option(
    "--cpu_offload",
    type=click.BOOL,
    default=False,
    help="Whether to use CPU offloading (only load current stage's model to GPU)",
)
@click.option(
    "--overlapped_decode",
    type=click.BOOL,
    default=False,
    help="Whether to use overlapped decoding (run dcae and vocoder using sliding windows)",
)
def main(
    checkpoint_path,
    server_name,
    port,
    device_id,
    share,
    bf16,
    torch_compile,
    cpu_offload,
    overlapped_decode,
):
    """Main function to launch the ACE Step pipeline demo.

    \f
    Builds an ``ACEStepPipeline`` and a ``DataSampler``, wires them into the
    Gradio UI returned by ``create_main_demo_ui``, and launches it with a
    request queue enabled.

    Args:
        checkpoint_path: Checkpoint directory passed to ``ACEStepPipeline``
            as ``checkpoint_dir``.
        server_name: Host/interface passed to Gradio's ``launch``.
        port: Port passed to Gradio's ``launch`` as ``server_port``. ``None``
            leaves the choice to Gradio.
        device_id: Written to ``CUDA_VISIBLE_DEVICES`` before the pipeline
            modules are imported.
        share: Passed to Gradio's ``launch`` to request a public share link.
        bf16: Selects ``dtype="bfloat16"`` when true, ``"float32"`` otherwise.
        torch_compile: Forwarded to ``ACEStepPipeline``.
        cpu_offload: Forwarded to ``ACEStepPipeline``.
        overlapped_decode: Forwarded to ``ACEStepPipeline``.
    """
    # CUDA reads CUDA_VISIBLE_DEVICES only when it initializes, so set the
    # variable before importing the acestep modules, which may pull in
    # torch/CUDA. That is why these imports are local rather than at the
    # top of the file.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    from acestep.ui.components import create_main_demo_ui
    from acestep.pipeline_ace_step import ACEStepPipeline
    from acestep.data_sampler import DataSampler

    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode,
    )
    data_sampler = DataSampler()

    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
        load_data_func=data_sampler.load_json,
    )

    # Forward the network-related CLI options so --server_name, --port and
    # --share actually take effect. server_port=None lets Gradio pick its
    # default port.
    demo.queue().launch(
        server_name=server_name,
        server_port=port,
        share=share,
        inbrowser=True,
    )


if __name__ == "__main__":
    main()