File size: 2,115 Bytes
091b1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77c6c52
 
 
091b1e0
 
 
 
 
 
 
 
77c6c52
 
bd0a813
091b1e0
77c6c52
bd0a813
091b1e0
 
 
 
 
 
 
 
 
 
 
 
 
3f8f152
091b1e0
 
 
 
 
77c6c52
091b1e0
 
 
 
 
77c6c52
091b1e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
from tqdm import tqdm

from utils import load_wav, collect_valentini_paths
from metrics import Metrics
from denoisers.SpectralGating import SpectralGating


# Dataset registry: dataset-type name -> callable that takes the dataset path
# and returns a pair (clean wav paths, noisy wav paths).
PARSERS = dict(
    valentini=collect_valentini_paths,
)
# Denoiser registry: model name -> class that is instantiated with no arguments
# and whose instances are called on a noisy waveform to produce a denoised one.
MODELS = dict(
    baseline=SpectralGating,
)



def evaluate_on_dataset(model_name, dataset_path, dataset_type):
    """Compute the mean PESQ and STOI scores of a denoiser over a dataset.

    Args:
        model_name: Key into ``MODELS`` selecting the denoiser to evaluate, or
            ``None`` to score the noisy inputs unchanged against the clean
            references (no model is applied).
        dataset_path: Path handed to the dataset parser.
        dataset_type: Key into ``PARSERS`` selecting the function that turns
            ``dataset_path`` into parallel lists of clean and noisy wav paths.

    Returns:
        dict with keys ``'PESQ'`` and ``'STOI'``, each mapped to the per-file
        mean of that score.

    Raises:
        KeyError: if ``model_name`` or ``dataset_type`` is not registered.
        ValueError: if the parser returns different numbers of clean and noisy
            files, or returns no files at all.
    """
    metric_names = ('PESQ', 'STOI')

    # Build the model once, up front; None means "score the noisy input as-is".
    model = MODELS[model_name]() if model_name is not None else None
    parser = PARSERS[dataset_type]
    clean_wavs, noisy_wavs = parser(dataset_path)

    # zip() would silently drop unmatched files while the mean below divides by
    # the full count, producing a skewed average -- treat a mismatch as an error.
    if len(clean_wavs) != len(noisy_wavs):
        raise ValueError(
            f"Dataset parser returned {len(clean_wavs)} clean files but "
            f"{len(noisy_wavs)} noisy files; they must be paired one-to-one")
    if len(clean_wavs) == 0:
        raise ValueError(f"No audio files found in dataset at {dataset_path!r}")

    metrics = Metrics()
    totals = {name: 0 for name in metric_names}
    for clean_path, noisy_path in tqdm(zip(clean_wavs, noisy_wavs), total=len(clean_wavs)):
        clean_wav = load_wav(clean_path)
        noisy_wav = load_wav(noisy_path)

        denoised_wav = noisy_wav if model is None else model(noisy_wav)
        scores = metrics.calculate(denoised=denoised_wav, clean=clean_wav)

        for name in metric_names:
            totals[name] += scores[name]

    num_files = len(clean_wavs)
    mean_scores = {}
    for name, total in totals.items():
        # Presumably Metrics.calculate returns tensor-like scores exposing
        # .numpy() (the original code relied on it); plain numbers are used as-is.
        if hasattr(total, 'numpy'):
            total = total.numpy()
        mean_scores[name] = total / num_files

    return mean_scores


def main():
    """Parse command-line arguments, run the evaluation and print the mean scores."""
    # `description` (not `prog`) is the field for a sentence describing the tool;
    # `prog` replaces the program name in the usage line.
    arg_parser = argparse.ArgumentParser(description='Program to evaluate denoising')
    arg_parser.add_argument('--dataset_path', type=str,
                            default='/media/public/dataset/denoising/DS_10283_2791/',
                            help='Path to dataset folder')
    # Choices come from the registries so the CLI cannot drift out of sync with them.
    arg_parser.add_argument('--dataset_type', type=str, required=True,
                            choices=sorted(PARSERS),
                            help='Dataset layout used to collect clean/noisy wav pairs')
    arg_parser.add_argument('--model_name', type=str, default=None,
                            choices=sorted(MODELS),
                            help='Denoiser to evaluate; if omitted, the noisy inputs '
                                 'are scored unchanged')
    args = arg_parser.parse_args()

    mean_scores = evaluate_on_dataset(model_name=args.model_name,
                                      dataset_path=args.dataset_path,
                                      dataset_type=args.dataset_type)
    model_label = args.model_name if args.model_name is not None else 'ideal denoising'
    print(f"Metrics on {args.dataset_type} dataset with {model_label} = {mean_scores}")


if __name__ == '__main__':
    main()