File size: 826 Bytes
79f6504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import json
import torch
import random
# Seed the global RNGs of Python's `random`, torch and (below) NumPy as a side
# effect of importing this module, so any randomness drawn later in the process
# starts from a fixed state. Note this resets process-wide RNG state for every
# importer of this module, not only for calls to denoise().
random.seed(0)
torch.manual_seed(0)
import numpy as np
# Seed NumPy's global RNG used by the np.random.* functions.
np.random.seed(0)
# Project-local model class (defined outside this file).
from src.model import DenoisingModel


def _load_model(ckpt_path: str, cfg_path: str) -> "DenoisingModel":
    """Build a DenoisingModel on the CPU and load checkpoint weights into it.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load`` whose
            top-level mapping contains a ``'model_state_dict'`` entry.
        cfg_path: Path to a JSON file whose ``'network_config'`` object is
            passed as keyword arguments to ``DenoisingModel``.

    Returns:
        The restored model, switched to evaluation mode.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(cfg_path, encoding='utf-8') as f:
        config = json.load(f)

    net = DenoisingModel(**config['network_config']).to('cpu')

    # NOTE(review): torch.load is pickle-based; depending on the torch version
    # and its weights_only setting it may execute arbitrary code from the file.
    # Only point ckpt_path at checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


def denoise(

    wav: np.ndarray,

    ckpt_path: str = os.path.join(os.getcwd(), 'ckpt', 'full.pkl'),

    cfg_path: str = os.path.join(os.getcwd(), 'configs', 'full.json'),

) -> np.ndarray:
    """Denoise a waveform with a DenoisingModel restored from disk.

    The config and checkpoint are re-read on every call.

    Args:
        wav: Input samples. A leading batch axis of size 1 is added before the
            forward pass; presumably a 1-D mono signal -- TODO confirm against
            DenoisingModel's expected input layout. Floating-point input is
            cast to the model's parameter dtype (so float64 arrays work with a
            float32 model); non-floating input is passed through unchanged.
        ckpt_path: Checkpoint path. The default is built from the working
            directory at *import* time (default arguments are evaluated once),
            not at call time.
        cfg_path: JSON config path; same import-time default caveat.

    Returns:
        The model output with its leading size-1 axis dropped, flattened to 1-D.
    """
    net = _load_model(ckpt_path, cfg_path)

    # torch.from_numpy rejects arrays with negative strides (e.g. wav[::-1]);
    # ascontiguousarray is a no-op for arrays that are already C-contiguous.
    samples = torch.from_numpy(np.ascontiguousarray(wav))
    if samples.is_floating_point():
        # NumPy defaults to float64 while torch modules are created with torch's
        # default dtype (float32 unless changed); mixing them raises a
        # dtype-mismatch error inside the network.
        first_param = next(net.parameters(), None)
        target_dtype = (
            torch.get_default_dtype() if first_param is None else first_param.dtype
        )
        samples = samples.to(target_dtype)

    # Inference only: no_grad skips building the autograd graph, saving memory
    # and time.
    with torch.no_grad():
        wav_denoised = net(samples.unsqueeze(0))

    return wav_denoised.squeeze(0).detach().numpy().reshape(-1)