hainazhu commited on
Commit
c4aaa82
·
1 Parent(s): 208580f

add separator.py

Browse files
Files changed (1) hide show
  1. separator.py +50 -0
separator.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ import os
3
+ import torch
4
+ from third_party.demucs.models.pretrained import get_model_from_yaml
5
+
6
+
7
+ class Separator(torch.nn.Module):
8
+ def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
9
+ super().__init__()
10
+ if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
11
+ self.device = torch.device(f"cuda:{gpu_id}")
12
+ else:
13
+ self.device = torch.device("cpu")
14
+ self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
15
+
16
+ def init_demucs_model(self, model_path, config_path):
17
+ model = get_model_from_yaml(config_path, model_path)
18
+ model.to(self.device)
19
+ model.eval()
20
+ return model
21
+
22
+ def load_audio(self, f):
23
+ a, fs = torchaudio.load(f)
24
+ if (fs != 48000):
25
+ a = torchaudio.functional.resample(a, fs, 48000)
26
+ if a.shape[-1] >= 48000*10:
27
+ a = a[..., :48000*10]
28
+ else:
29
+ a = torch.cat([a, a], -1)
30
+ return a[:, 0:48000*10]
31
+
32
+ def run(self, audio_path, output_dir='tmp', ext=".flac"):
33
+ os.makedirs(output_dir, exist_ok=True)
34
+ name, _ = os.path.splitext(os.path.split(audio_path)[-1])
35
+ output_paths = []
36
+
37
+ for stem in self.demucs_model.sources:
38
+ output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
39
+ if os.path.exists(output_path):
40
+ output_paths.append(output_path)
41
+ if len(output_paths) == 1: # 4
42
+ vocal_path = output_paths[0]
43
+ else:
44
+ drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
45
+ for path in [drums_path, bass_path, other_path]:
46
+ os.remove(path)
47
+ full_audio = self.load_audio(audio_path)
48
+ vocal_audio = self.load_audio(vocal_path)
49
+ bgm_audio = full_audio - vocal_audio
50
+ return full_audio, vocal_audio, bgm_audio