File size: 1,020 Bytes
6d15d50 2a9c4ee 6d15d50 2a9c4ee 28dcb23 6d15d50 2a9c4ee 6d15d50 2a9c4ee d87f154 2a9c4ee 6e1a9b7 2a9c4ee fee470d 6d15d50 8b1b948 6e1a9b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
from spleeter.separator import Separator
class SpleeterConfig(PretrainedConfig):
    """Configuration for the spleeter source-separation wrapper.

    Args:
        stems: Number of stems to separate the input into. Stored on the
            config as ``self.stems``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier used when registering this config with ``AutoConfig``.
    model_type = "spleeter"

    def __init__(self, stems: int = 2, **kwargs) -> None:
        # Let the base config consume its own keyword arguments first.
        super().__init__(**kwargs)
        self.stems = stems
class SpleeterModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a spleeter ``Separator``.

    This class registers no torch submodules of its own; all work is done by
    the spleeter ``Separator`` created in ``__init__``.
    """

    config_class = SpleeterConfig

    def __init__(self, config):
        super().__init__(config)
        # Builds a descriptor such as "spleeter:2stems" from config.stems.
        self.separator = Separator(f"spleeter:{config.stems}stems")

    def forward(self, audio_path: str, output_dir: str = "separated_audio") -> str:
        """Separate the stems in the given audio file and write them to disk.

        Args:
            audio_path: Path to the input audio file.
            output_dir: Destination directory handed to
                ``Separator.separate_to_file``. Defaults to
                ``"separated_audio"``, the value previously hard-coded here.

        Returns:
            The destination directory the separated stems were written under.
        """
        # separate_to_file writes files and does not return a path itself,
        # so return the destination explicitly to honour the documented contract.
        self.separator.separate_to_file(audio_path, output_dir)
        return output_dir
# In-process registration: map model_type "spleeter" to SpleeterConfig, and
# SpleeterConfig to SpleeterModel, so AutoConfig/AutoModel can resolve them.
AutoConfig.register(SpleeterConfig.model_type, SpleeterConfig)
AutoModel.register(SpleeterConfig, SpleeterModel)

# Record which Auto class each custom class belongs to (per the transformers
# custom-model docs, this is what lets saved checkpoints reference the custom
# code for loading with trust_remote_code).
SpleeterConfig.register_for_auto_class(auto_class="AutoConfig")
SpleeterModel.register_for_auto_class(auto_class="AutoModel")