Niral Patel commited on
Commit
2a9c4ee
·
1 Parent(s): 99ea3bf

Add Transformers integration with custom model

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. config.json +4 -0
  3. custom_model.py +25 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ env/
config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "model_type": "spleeter",
3
+ "stems": 2
4
+ }
custom_model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ from spleeter.separator import Separator
3
+
4
+ class SpleeterConfig(PretrainedConfig):
5
+ model_type = "spleeter"
6
+ def __init__(self, stems=2, **kwargs):
7
+ super().__init__(**kwargs)
8
+ self.stems = stems
9
+
10
+ class SpleeterModel(PreTrainedModel):
11
+ config_class = SpleeterConfig
12
+
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ self.separator = Separator(f"{config.stems}stems")
16
+
17
+ def forward(self, audio_path: str):
18
+ """
19
+ Separates the stems in the given audio file.
20
+ Args:
21
+ audio_path (str): Path to the input audio file.
22
+ Returns:
23
+ dict: Separated stems.
24
+ """
25
+ return self.separator.separate(audio_path)