|
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel |
|
from spleeter.separator import Separator |
|
|
|
|
|
class SpleeterConfig(PretrainedConfig):
    """Configuration for a Spleeter-backed source-separation model.

    Args:
        stems: Number of stems to split audio into. The model uses it to
            select the ``spleeter:<stems>stems`` pretrained separator.
            Defaults to 2.
        **config_kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    # Model-type identifier for this configuration family.
    model_type = "spleeter"

    def __init__(self, stems: int = 2, **config_kwargs):
        super().__init__(**config_kwargs)
        # Kept as a plain instance attribute so it travels with the
        # config (PretrainedConfig serializes instance attributes).
        self.stems = stems
|
|
|
|
|
class SpleeterModel(PreTrainedModel):
    """Hugging Face model wrapper around a Spleeter ``Separator``.

    The wrapped separator is built from the ``spleeter:<stems>stems``
    pretrained configuration selected by ``config.stems``. ``forward``
    delegates to ``Separator.separate_to_file``, which writes the separated
    stems to disk.
    """

    config_class = SpleeterConfig

    def __init__(self, config):
        super().__init__(config)
        # Spleeter resolves "spleeter:<N>stems" to one of its pretrained
        # models (presumably fetching it on first use -- verify for offline use).
        self.separator = Separator(f"spleeter:{config.stems}stems")

    def forward(self, audio_path: str, output_dir: str = "separated_audio") -> str:
        """Separate the stems in the given audio file and write them to disk.

        Args:
            audio_path: Path to the input audio file.
            output_dir: Directory the separated stems are written to.
                Defaults to ``"separated_audio"`` (relative to the current
                working directory), the value previously hard-coded here.

        Returns:
            ``output_dir``: the directory that received the separated stem
            files. The file naming/layout inside it is decided by Spleeter.
        """
        # separate_to_file writes the stems as a side effect and returns
        # None, so return the destination explicitly to honour the
        # documented contract instead of handing callers None.
        self.separator.separate_to_file(audio_path, output_dir)
        return output_dir
|
|
|
|
|
# Map the custom config/model pair into the Auto* factories for this
# process, so AutoConfig / AutoModel can resolve model_type "spleeter".
AutoConfig.register(SpleeterConfig.model_type, SpleeterConfig)
AutoModel.register(SpleeterConfig, SpleeterModel)

# Tag both classes with their Auto* class for the transformers
# custom-code (trust_remote_code) mechanism. "AutoConfig" is the default
# auto_class for configs; it is spelled out here for symmetry.
SpleeterConfig.register_for_auto_class("AutoConfig")
SpleeterModel.register_for_auto_class("AutoModel")