Niral Patel
commited on
Commit
·
2a9c4ee
1
Parent(s):
99ea3bf
Add Transformers integration with custom model
Browse files- .gitignore +1 -0
- config.json +4 -0
- custom_model.py +25 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
env/
|
config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "spleeter",
|
3 |
+
"stems": 2
|
4 |
+
}
|
custom_model.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
2 |
+
from spleeter.separator import Separator
|
3 |
+
|
4 |
+
class SpleeterConfig(PretrainedConfig):
|
5 |
+
model_type = "spleeter"
|
6 |
+
def __init__(self, stems=2, **kwargs):
|
7 |
+
super().__init__(**kwargs)
|
8 |
+
self.stems = stems
|
9 |
+
|
10 |
+
class SpleeterModel(PreTrainedModel):
|
11 |
+
config_class = SpleeterConfig
|
12 |
+
|
13 |
+
def __init__(self, config):
|
14 |
+
super().__init__(config)
|
15 |
+
self.separator = Separator(f"{config.stems}stems")
|
16 |
+
|
17 |
+
def forward(self, audio_path: str):
|
18 |
+
"""
|
19 |
+
Separates the stems in the given audio file.
|
20 |
+
Args:
|
21 |
+
audio_path (str): Path to the input audio file.
|
22 |
+
Returns:
|
23 |
+
dict: Separated stems.
|
24 |
+
"""
|
25 |
+
return self.separator.separate(audio_path)
|