File size: 2,016 Bytes
ae29df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import yaml
from typing import Dict, List
import torch
import torch.nn as nn
import numpy as np
import librosa
from scipy.io.wavfile import write
from utils import ignore_warnings; ignore_warnings()
from utils import parse_yaml, load_ss_model
from models.clap_encoder import CLAP_Encoder


def build_audiosep(config_yaml, checkpoint_path, device):
    """Build the AudioSep model from a config file and a checkpoint.

    Args:
        config_yaml: Path of the YAML configuration, handed to ``parse_yaml``.
        checkpoint_path: Checkpoint location, forwarded to ``load_ss_model``.
        device: Target device the loaded model is moved to via ``.to(device)``.

    Returns:
        The object produced by ``load_ss_model``, switched to eval mode and
        moved to ``device``.
    """
    configs = parse_yaml(config_yaml)

    # The query encoder is put in eval mode before it is given to the loader.
    encoder = CLAP_Encoder().eval()

    separator = load_ss_model(
        configs=configs,
        checkpoint_path=checkpoint_path,
        query_encoder=encoder,
    )
    separator = separator.eval().to(device)

    print(f'Load AudioSep model from [{checkpoint_path}]')
    return separator


def inference(model, audio_file, text, output_file, device='cuda'):
    """Separate the sound described by ``text`` from ``audio_file`` and save it.

    The input is loaded as 32 kHz mono audio, the text query is embedded with
    ``model.query_encoder``, ``model.ss_model`` is run on the mixture and the
    resulting waveform is written to ``output_file`` as 16-bit PCM WAV.

    Args:
        model: Object exposing ``query_encoder.get_query_embed(...)`` and a
            callable ``ss_model`` that returns a mapping with a ``"waveform"``
            entry (this is all the function relies on).
        audio_file: Path of the mixture; anything ``librosa.load`` accepts.
        text: A single textual query (wrapped into a one-element list here).
        output_file: Destination passed to ``scipy.io.wavfile.write``.
        device: Device used for the mixture tensor and the query embedding.

    Returns:
        None. The separated audio is written to ``output_file``.
    """
    # Single source of truth for the rate used both to load and to save.
    sample_rate = 32000

    print(f'Separate audio from [{audio_file}] with textual query [{text}]')
    mixture, _ = librosa.load(audio_file, sr=sample_rate, mono=True)

    # Only the forward passes need gradient tracking disabled.
    with torch.no_grad():
        conditions = model.query_encoder.get_query_embed(
            modality='text',
            text=[text],
            device=device,
        )

        input_dict = {
            # Add two leading singleton axes in front of the sample axis.
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": conditions,
        }

        sep_segment = model.ss_model(input_dict)["waveform"]

    # Drop the two leading axes again.
    # NOTE(review): assumes the model output has two leading singleton
    # dimensions, mirroring the input layout -- confirm against ss_model.
    sep_segment = sep_segment.squeeze(0).squeeze(0).detach().cpu().numpy()

    # Clip before quantising: samples outside [-1, 1] would otherwise exceed
    # the int16 range and wrap around on the cast, producing loud artefacts.
    pcm = np.round(np.clip(sep_segment, -1.0, 1.0) * 32767).astype(np.int16)

    write(output_file, sample_rate, pcm)
    print(f'Write separated audio to [{output_file}]')


if __name__ == '__main__':
    # Demo inputs: mixture to process, text query, and where to save the result.
    audio_file = '/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav'
    text = 'pigeons are cooing in the background'
    output_file = 'separated_audio.wav'

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    model = build_audiosep(
        config_yaml='config/audiosep_base.yaml',
        checkpoint_path='checkpoint/step=3920000.ckpt',
        device=device,
    )

    inference(model, audio_file, text, output_file, device)