|
import os
|
|
import json
|
|
import torch
|
|
import random
|
|
import numpy as np


def _seed_all_rngs(seed):
    """Seed the global RNGs of Python's ``random``, PyTorch and NumPy with *seed*."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


# Importing this module leaves every global RNG seeded with the same fixed value.
_seed_all_rngs(0)
|
|
from src.model import DenoisingModel
|
|
|
|
|
|
def denoise(
    wav: np.ndarray,
    ckpt_path: str = os.path.join(os.getcwd(), 'ckpt', 'full.pkl'),
    cfg_path: str = os.path.join(os.getcwd(), 'configs', 'full.json'),
) -> np.ndarray:
    """Denoise a waveform with a ``DenoisingModel`` restored from disk.

    The network is rebuilt from ``cfg_path`` and its weights are reloaded
    from ``ckpt_path`` on every call. Inference runs on the CPU.

    Args:
        wav: Input waveform. A leading batch axis is added before it is fed
            to the network. Its shape should therefore be whatever the
            network expects minus that batch axis (presumably
            ``(channels, samples)`` -- TODO confirm against ``DenoisingModel``).
            The array's dtype is kept as-is. It presumably has to match the
            network's parameter dtype (float32 unless the model says
            otherwise) -- verify.
        ckpt_path: Path to a ``torch.save``'d mapping that holds a
            ``'model_state_dict'`` entry. Note that the default is built from
            the working directory at *import* time, not at call time.
        cfg_path: Path to a JSON file. Its ``'network_config'`` object is
            passed as keyword arguments to ``DenoisingModel``. Like
            ``ckpt_path``, the default is resolved at import time.

    Returns:
        The network output with the batch axis removed, flattened to 1-D.

    Raises:
        OSError: If the config file cannot be opened.
        KeyError: If the config has no ``'network_config'`` key or the
            checkpoint has no ``'model_state_dict'`` key.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(cfg_path, encoding='utf-8') as f:
        config = json.load(f)

    net = DenoisingModel(**config['network_config']).to('cpu')

    # NOTE(review): torch.load unpickles the file, so it can execute arbitrary
    # code. Only point ckpt_path at trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # torch.from_numpy rejects arrays with negative strides (e.g. wav[::-1]).
    # ascontiguousarray copies only when the input is not already C-contiguous.
    wav_t = torch.from_numpy(np.ascontiguousarray(wav)).unsqueeze(0)

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        wav_denoised = net(wav_t)

    return wav_denoised.squeeze(0).detach().numpy().reshape(-1)