import json
import os
import random

import numpy as np
import torch

# Seed every RNG before the project import, mirroring the original statement
# order (seeds were set before `src.model` was imported).
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

from src.model import DenoisingModel  # noqa: E402


def denoise(
    wav: np.ndarray,
    ckpt_path: str = os.path.join(os.getcwd(), 'ckpt', 'full.pkl'),
    cfg_path: str = os.path.join(os.getcwd(), 'configs', 'full.json'),
) -> np.ndarray:
    """Run the denoising network on a waveform and return the result.

    The model is built from the ``network_config`` section of the JSON file
    at ``cfg_path``, its weights are restored from the ``model_state_dict``
    entry of the checkpoint at ``ckpt_path``, and inference runs on the CPU
    in evaluation mode.

    Args:
        wav: Input signal as a NumPy array. A leading dimension is added
            before it is passed to the network. NOTE(review): the expected
            shape and dtype are dictated by ``DenoisingModel`` and are not
            visible here -- confirm against the model definition.
        ckpt_path: Path to the checkpoint file. The default is resolved
            against the working directory at import time.
        cfg_path: Path to the JSON configuration file. The default is
            resolved against the working directory at import time.

    Returns:
        The network output flattened to a one-dimensional NumPy array.
    """
    with open(cfg_path, encoding='utf-8') as f:
        config = json.load(f)

    net = DenoisingModel(**config['network_config']).to('cpu')

    # SECURITY: torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # torch.from_numpy shares memory with `wav`; unsqueeze adds the leading
    # dimension the network is fed with.
    batch = torch.from_numpy(wav).unsqueeze(0)

    # Inference only: skip autograd bookkeeping instead of building a graph
    # and discarding it afterwards.
    with torch.no_grad():
        output = net(batch)

    # detach() is kept as a cheap safeguard in case the model re-enables
    # gradients internally; reshape(-1) flattens the result to 1-D.
    return output.squeeze(0).detach().numpy().reshape(-1)