"""Hugging Face ``transformers`` wrappers around a Spleeter source separator.

Defines a config/model pair so that a Spleeter ``Separator`` can be created
through ``AutoConfig`` / ``AutoModel``. Both classes are registered with the
auto classes at import time.
"""

from spleeter.separator import Separator
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class SpleeterConfig(PretrainedConfig):
    """Configuration for :class:`SpleeterModel`.

    Args:
        stems: Number of stems to separate the audio into. It is interpolated
            into the Spleeter model descriptor ``"spleeter:{stems}stems"``.
            NOTE(review): presumably must match one of Spleeter's pretrained
            configurations (e.g. 2, 4 or 5); this is not validated here.
        **kwargs: Forwarded unchanged to :class:`transformers.PretrainedConfig`.
    """

    model_type = "spleeter"

    def __init__(self, stems=2, **kwargs):
        super().__init__(**kwargs)
        self.stems = stems


class SpleeterModel(PreTrainedModel):
    """``PreTrainedModel`` facade over a Spleeter ``Separator``.

    The model holds no trainable parameters of its own. All work is delegated
    to ``self.separator``, which is built from ``config.stems``.
    """

    config_class = SpleeterConfig

    def __init__(self, config):
        super().__init__(config)
        # Spleeter resolves "spleeter:<N>stems" to its bundled pretrained model.
        self.separator = Separator(f"spleeter:{config.stems}stems")

    def forward(self, audio_path: str, output_dir: str = "separated_audio", **kwargs) -> str:
        """Separate the stems of an audio file and write them to disk.

        Args:
            audio_path: Path to the input audio file.
            output_dir: Directory the separated stems are written into.
                Defaults to ``"separated_audio"``, the previously hard-coded
                location.
            **kwargs: Extra options forwarded unchanged to
                ``Separator.separate_to_file`` (e.g. codec or filename format).

        Returns:
            The directory the separated stems were written to (``output_dir``).
        """
        # separate_to_file writes the stems as a side effect rather than
        # returning a path, so return the destination ourselves.
        self.separator.separate_to_file(audio_path, output_dir, **kwargs)
        return output_dir


# Make the classes discoverable through the transformers auto-class machinery.
AutoConfig.register("spleeter", SpleeterConfig)
AutoModel.register(SpleeterConfig, SpleeterModel)
# Mark the classes so their code is exported with saved checkpoints
# (custom-model / remote-code workflow).
SpleeterConfig.register_for_auto_class()
SpleeterModel.register_for_auto_class("AutoModel")