Spaces:
Paused
Paused
import yaml | |
from typing import Dict, List | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
import librosa | |
from scipy.io.wavfile import write | |
from utils import ignore_warnings; ignore_warnings() | |
from utils import parse_yaml, load_ss_model | |
from models.clap_encoder import CLAP_Encoder | |
def build_audiosep(config_yaml, checkpoint_path, device):
    """Build an AudioSep model and prepare it for inference.

    The YAML config is parsed with ``parse_yaml``, a ``CLAP_Encoder`` is
    created in eval mode to serve as the query encoder, and ``load_ss_model``
    assembles the separation model from the config and ``checkpoint_path``.
    The result is put in eval mode and moved to ``device``.

    Args:
        config_yaml: Path to the model's YAML configuration file.
        checkpoint_path: Path to the checkpoint handed to ``load_ss_model``.
        device: Target device (e.g. a ``torch.device``) for the model.

    Returns:
        The model returned by ``load_ss_model``, in eval mode on ``device``.
    """
    parsed_configs = parse_yaml(config_yaml)
    clap_encoder = CLAP_Encoder().eval()

    loaded_model = load_ss_model(
        configs=parsed_configs,
        checkpoint_path=checkpoint_path,
        query_encoder=clap_encoder,
    )
    ready_model = loaded_model.eval().to(device)

    print(f'Load AudioSep model from [{checkpoint_path}]')
    return ready_model
def inference(model, audio_file, text, output_file, device='cuda'):
    """Extract the sound described by ``text`` from ``audio_file`` and save it.

    The mixture is loaded as mono audio resampled to 32 kHz. The text query is
    embedded with ``model.query_encoder.get_query_embed``, and
    ``model.ss_model`` predicts the separated waveform. That waveform is
    written to ``output_file`` as a 16-bit PCM WAV at the same sample rate.

    Args:
        model: Object exposing ``query_encoder.get_query_embed`` and a
            callable ``ss_model`` (e.g. the result of ``build_audiosep``).
        audio_file: Path to the input mixture (any format ``librosa.load``
            can read).
        text: Natural-language description of the source to extract.
        output_file: Destination path for the separated WAV file.
        device: Device the query embedding and mixture tensor are placed on;
            it should match the device the model lives on.
    """
    # NOTE(review): presumably the rate the checkpoint was trained at — confirm.
    target_sr = 32000

    print(f'Separate audio from [{audio_file}] with textual query [{text}]')
    # librosa resamples to target_sr and returns that rate as `fs`.
    mixture, fs = librosa.load(audio_file, sr=target_sr, mono=True)

    with torch.no_grad():
        conditions = model.query_encoder.get_query_embed(
            modality='text',
            text=[text],  # the encoder is called with a batch of one query
            device=device,
        )
        input_dict = {
            # 1-D mono signal -> (1, 1, num_samples) via the two new leading axes.
            "mixture": torch.as_tensor(mixture, dtype=torch.float32)[None, None, :].to(device),
            "condition": conditions,
        }
        sep_segment = model.ss_model(input_dict)["waveform"]

    # Assumes the output waveform is (1, 1, num_samples) — TODO confirm;
    # the two squeezes drop those leading singleton axes.
    sep_segment = sep_segment.squeeze(0).squeeze(0).detach().cpu().numpy()

    # The model's float output is not guaranteed to stay inside [-1, 1].
    # Clip before scaling so the int16 cast saturates instead of wrapping
    # around (wrap-around turns slight overshoots into full-scale clicks).
    pcm = np.round(np.clip(sep_segment, -1.0, 1.0) * 32767).astype(np.int16)
    write(output_file, fs, pcm)
    print(f'Write separated audio to [{output_file}]')
if __name__ == '__main__':
    # Local import: only the command-line entry point needs argparse.
    import argparse

    # Every option defaults to the value this script previously hard-coded,
    # so running it with no arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Separate the sound described by a text query from an audio file using AudioSep.'
    )
    parser.add_argument(
        '--config_yaml',
        default='config/audiosep_base.yaml',
        help='Path to the AudioSep YAML config (default: %(default)s).',
    )
    parser.add_argument(
        '--checkpoint_path',
        default='checkpoint/step=3920000.ckpt',
        help='Path to the model checkpoint (default: %(default)s).',
    )
    parser.add_argument(
        '--audio_file',
        default='/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav',
        help='Input mixture to separate.',
    )
    parser.add_argument(
        '--text',
        default='pigeons are cooing in the background',
        help='Text description of the sound to extract (default: %(default)s).',
    )
    parser.add_argument(
        '--output_file',
        default='separated_audio.wav',
        help='Where to write the separated audio (default: %(default)s).',
    )
    # Parse before loading the model so bad arguments fail fast.
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_audiosep(
        config_yaml=args.config_yaml,
        checkpoint_path=args.checkpoint_path,
        device=device,
    )
    inference(model, args.audio_file, args.text, args.output_file, device)