# Source snapshot: commit 79f6504 (826 bytes).
import os
import json
import torch
import random
# NOTE: module-level side effect -- importing this file reseeds the global
# Python, torch and numpy RNGs with 0 (presumably for reproducible runs).
random.seed(0)
torch.manual_seed(0)
import numpy as np
np.random.seed(0)
from src.model import DenoisingModel
def denoise(
    wav: np.ndarray,
    ckpt_path: str = os.path.join(os.getcwd(), 'ckpt', 'full.pkl'),
    cfg_path: str = os.path.join(os.getcwd(), 'configs', 'full.json'),
) -> np.ndarray:
    """Denoise a waveform with a ``DenoisingModel`` restored from disk.

    The model is rebuilt from ``cfg_path`` and its weights are reloaded from
    ``ckpt_path`` on every call. It is then run once on CPU in eval mode
    with autograd disabled.

    Args:
        wav: Input waveform. A leading axis (presumably the batch axis) is
            prepended before the forward pass, so pass the array without
            it (presumably shape ``(T,)`` -- TODO confirm against the model).
            Floating-point input is cast to the model's parameter dtype.
        ckpt_path: Checkpoint file holding a ``'model_state_dict'`` entry.
            NOTE: the default is built from ``os.getcwd()`` when this module
            is imported, not when ``denoise`` is called.
        cfg_path: JSON config file with a ``'network_config'`` object whose
            entries are passed as keyword arguments to ``DenoisingModel``.
            Same import-time ``os.getcwd()`` caveat for the default.

    Returns:
        The model output with the leading axis removed, flattened to 1-D.
    """
    with open(cfg_path, encoding='utf-8') as f:
        config = json.load(f)
    net = DenoisingModel(**config['network_config']).to('cpu')

    # load checkpoint
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # inference
    array = np.asarray(wav)
    if any(stride < 0 for stride in array.strides):
        # torch.from_numpy rejects negative strides (e.g. wav[::-1]).
        array = array.copy()
    x = torch.from_numpy(array)
    param = next(net.parameters(), None)
    if param is not None and x.is_floating_point():
        # e.g. float64 audio vs float32 weights would otherwise fail in forward;
        # integer PCM is left alone rather than silently cast without scaling.
        x = x.to(param.dtype)
    with torch.no_grad():  # inference only: no autograd graph needed
        wav_denoised = net(x.unsqueeze(0))
    return wav_denoised.squeeze(0).detach().numpy().reshape(-1)