youtube_spleeter / custom_model.py
Niral Patel
change in model config
6e1a9b7
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
from spleeter.separator import Separator
class SpleeterConfig(PretrainedConfig):
    """Configuration for the spleeter source-separation wrapper model.

    Args:
        stems: Number of stems to separate audio into. It selects the
            pretrained spleeter model named ``spleeter:{stems}stems``; only
            2, 4 and 5 are accepted.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``stems`` is not one of the supported stem counts.
    """

    model_type = "spleeter"

    def __init__(self, stems=2, **kwargs):
        # Spleeter only ships pretrained "2stems", "4stems" and "5stems"
        # models; reject other values here instead of letting the failure
        # surface later when the Separator is constructed from this config.
        supported = (2, 4, 5)
        if stems not in supported:
            raise ValueError(f"stems must be one of {supported}, got {stems!r}")
        super().__init__(**kwargs)
        self.stems = stems
class SpleeterModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a spleeter ``Separator``.

    The separator is built from the pretrained model name
    ``spleeter:{config.stems}stems``.
    """

    config_class = SpleeterConfig

    def __init__(self, config):
        super().__init__(config)
        # Spleeter pretrained model names follow the "spleeter:<N>stems" scheme.
        self.separator = Separator(f"spleeter:{config.stems}stems")

    def forward(self, audio_path: str, output_dir: str = "separated_audio") -> str:
        """Separate the stems of an audio file and write them to disk.

        Args:
            audio_path: Path to the input audio file.
            output_dir: Destination directory handed to spleeter's
                ``separate_to_file``. Defaults to ``"separated_audio"``; a
                relative path is resolved against the current working directory.

        Returns:
            ``output_dir``, the directory the separated stems are written
            under. ``separate_to_file`` writes files rather than returning
            their location, so the destination is returned explicitly.
        """
        self.separator.separate_to_file(audio_path, output_dir)
        return output_dir
# Make the custom classes resolvable through the Auto* factories in this
# process. The "spleeter" key has to agree with SpleeterConfig.model_type.
AutoConfig.register("spleeter", SpleeterConfig)
AutoModel.register(SpleeterConfig, SpleeterModel)

# Tell transformers which Auto* class each custom class maps to, so saving
# the model with its custom code records that mapping (the explicit values
# below are the methods' defaults).
SpleeterConfig.register_for_auto_class(auto_class="AutoConfig")
SpleeterModel.register_for_auto_class(auto_class="AutoModel")