File size: 9,904 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder
from gyraudio.audio_separation.metrics import snr
from gyraudio.io.dump import Dump
from pathlib import Path
import sys
import torch
from tqdm import tqdm
import torchaudio
import pandas as pd
from typing import List
# File names of the persisted results
DEFAULT_RECORD_FILE = "infer_record.csv"  # Global record: one summary row per inference run
DEFAULT_EVALUATION_FILE = "eval_df.csv"  # Per-sample evaluation: one row per saved audio signal
# Record keys (columns of the inference record file)
NBATCH = "nb_batch"
BEST_SNR = "best_snr"
BEST_SAVE_SNR = "best_save_snr"
WORST_SNR = "worst_snr"
WORST_SAVE_SNR = "worst_save_snr"
RECORD_KEYS = [NAME, SHORT_NAME, CURRENT_EPOCH, NBATCH, SNR_FILTER, BEST_SAVE_SNR, BEST_SNR, WORST_SAVE_SNR, WORST_SNR]
# Evaluation keys (columns of the per-sample evaluation file)
SAVE_IDX = "save_idx"
SNR_IN = "snr_in"
SNR_OUT = "snr_out"
EVAL_KEYS = [SAVE_IDX, SNR_IN, SNR_OUT]


def load_file(path: Path, keys: List[str]) -> pd.DataFrame:
    """Read a CSV into a DataFrame, creating an empty one first if needed.

    If `path` does not exist, an empty CSV with `keys` as its columns is
    written, so the caller always gets a DataFrame with the expected schema.

    Fix: write without the row index (`index=False`) — the original wrote the
    index, so the file read back contained a spurious "Unnamed: 0" column.

    :param path: CSV file location.
    :param keys: column names used when the file has to be created.
    :return: DataFrame loaded from `path`.
    """
    if not path.exists():
        pd.DataFrame(columns=keys).to_csv(path, index=False)
    return pd.read_csv(path)


def launch_infer(exp: int, snr_filter: list = None, device: str = "cuda", model_dir: Path = None,
                 output_dir: Path = EXPERIMENT_STORAGE_ROOT, force_reload=False, max_batches=None,
                 ext=".wav"):
    """Run inference for experiment `exp` on its test dataloader and persist results.

    For each test sample this saves the mixed/output/original audio, a json of
    its SNR metrics and a row in the per-run evaluation CSV; a summary row
    (best/worst SNR improvement and their sample indices) is appended to the
    global inference record file. If a matching record row already exists the
    computation is skipped, unless `force_reload` is set.

    :param exp: experiment id, resolved through `get_experience`.
    :param snr_filter: optional list of SNR values filtering the test set (sorted before use).
    :param device: torch device string ("cuda" or "cpu").
    :param model_dir: root directory holding the experiment checkpoints.
    :param output_dir: root directory for inference artifacts; None disables all saving.
    :param force_reload: recompute even when a matching record row exists.
    :param max_batches: optional cap on the number of test batches processed.
    :param ext: audio file extension (".wav" or ".mp4").
    :return: (record_row_df, evaluation_path) — summary row(s) and the path of the
        per-sample evaluation CSV (both None when output_dir is None).
    """
    # Ensure the returned names are bound even when output_dir is None
    # (the original raised NameError on that path).
    record_row_df, evaluation_path = None, None
    # Load experience
    if snr_filter is not None:
        snr_filter = sorted(snr_filter)
    short_name, model, config, dl = get_experience(exp, snr_filter_test=snr_filter)
    exists, exp_dir = get_output_folder(config, root_dir=model_dir, override=False)
    assert exp_dir.exists(), f"Experiment {short_name} does not exist in {model_dir}"
    model.eval()
    model.to(device)
    model, optimizer, epoch, config_checkpt = load_checkpoint(model, exp_dir, epoch=None, device=device)
    # Folder creation
    if output_dir is not None:
        record_path = output_dir/DEFAULT_RECORD_FILE
        record_df = load_file(record_path, RECORD_KEYS)

        # Define conditions for filtering: an existing row with the same
        # experiment, epoch, batch count and SNR filter means a previous run.
        exist_conditions = {
            NAME: config[NAME],
            SHORT_NAME: config[SHORT_NAME],
            CURRENT_EPOCH: epoch,
            NBATCH: max_batches,
        }
        # Create boolean masks and combine them
        masks = [(record_df[key] == value) for key, value in exist_conditions.items()]
        if snr_filter is None:
            masks.append((record_df[SNR_FILTER]).isnull())
        else:
            # The filter list is stored stringified in the CSV, compare as str.
            masks.append(record_df[SNR_FILTER] == str(snr_filter))
        combined_mask = pd.Series(True, index=record_df.index)
        for mask in masks:
            combined_mask = combined_mask & mask
        filtered_df = record_df[combined_mask]

        save_dir = output_dir/(exp_dir.name+"_infer" + (f"_epoch_{epoch:04d}_nbatch_{max_batches if max_batches is not None else len(dl[TEST])}")
                               + ("" if snr_filter is None else f"_snrs_{'_'.join(map(str, snr_filter))}"))
        evaluation_path = save_dir/DEFAULT_EVALUATION_FILE
        if not (filtered_df.empty) and not (force_reload):
            assert evaluation_path.exists()
            print(f"Inference already exists, see folder {save_dir}")
            record_row_df = filtered_df
        else:
            record_row_df = pd.DataFrame({
                NAME: config[NAME],
                SHORT_NAME: config[SHORT_NAME],
                CURRENT_EPOCH: epoch,
                NBATCH: max_batches,
                SNR_FILTER: [None],
            }, index=[0], columns=RECORD_KEYS)
            record_row_df.at[0, SNR_FILTER] = snr_filter

            save_dir.mkdir(parents=True, exist_ok=True)
            evaluation_df = load_file(evaluation_path, EVAL_KEYS)
            with torch.no_grad():
                test_loss = 0.
                save_idx = 0
                best_snr = 0
                worst_snr = 0
                # Initialize the tracked sample indices so they are defined even
                # when no batch crosses the 0 dB thresholds (avoids NameError).
                best_save_idx = 0
                worst_save_idx = 0
                processed_batches = 0
                for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                        enumerate(dl[TEST]), desc=f"Inference epoch {epoch}", total=max_batches if max_batches is not None else len(dl[TEST])):
                    batch_mix, batch_signal, batch_noise = batch_mix.to(
                        device), batch_signal.to(device), batch_noise.to(device)
                    batch_output_signal, _batch_output_noise = model(batch_mix)
                    loss = torch.nn.functional.mse_loss(batch_output_signal, batch_signal)
                    test_loss += loss.item()

                    # SNR stats: per-sample SNR improvement (output vs input)
                    snr_in = snr(batch_mix, batch_signal, reduce=None)
                    snr_out = snr(batch_output_signal, batch_signal, reduce=None)
                    best_current, best_idx_current = torch.max(snr_out-snr_in, axis=0)
                    worst_current, worst_idx_current = torch.min(snr_out-snr_in, axis=0)
                    if best_current > best_snr:
                        best_snr = best_current
                        best_save_idx = save_idx + best_idx_current
                    # Bug fix: track the minimum (original compared with `>`,
                    # which never captures a degradation below the 0 init).
                    if worst_current < worst_snr:
                        worst_snr = worst_current
                        worst_save_idx = save_idx + worst_idx_current

                    # Save by signal
                    batch_output_signal = batch_output_signal.detach().cpu()
                    batch_signal = batch_signal.detach().cpu()
                    batch_mix = batch_mix.detach().cpu()
                    for audio_idx in range(batch_output_signal.shape[0]):
                        dic = {SAVE_IDX: save_idx, SNR_IN: float(
                            snr_in[audio_idx]), SNR_OUT: float(snr_out[audio_idx])}
                        new_eval_row = pd.DataFrame(dic, index=[0])
                        evaluation_df = pd.concat([new_eval_row, evaluation_df.loc[:]], ignore_index=True)

                        # Save .wav
                        torchaudio.save(
                            str(save_dir/f"{save_idx:04d}_mixed{ext}"),
                            batch_mix[audio_idx, :, :],
                            sample_rate=dl[TEST].dataset.sampling_rate,
                            channels_first=True
                        )
                        torchaudio.save(
                            str(save_dir/f"{save_idx:04d}_out{ext}"),
                            batch_output_signal[audio_idx, :, :],
                            sample_rate=dl[TEST].dataset.sampling_rate,
                            channels_first=True
                        )
                        torchaudio.save(
                            str(save_dir/f"{save_idx:04d}_original{ext}"),
                            batch_signal[audio_idx, :, :],
                            sample_rate=dl[TEST].dataset.sampling_rate,
                            channels_first=True
                        )
                        Dump.save_json(dic, save_dir/f"{save_idx:04d}.json")
                        save_idx += 1
                    processed_batches += 1
                    if max_batches is not None and processed_batches >= max_batches:
                        break
            # Average over the batches actually processed (the original divided
            # by the full dataloader length even when max_batches truncated it).
            test_loss = test_loss/max(processed_batches, 1)
            evaluation_df.to_csv(evaluation_path)

            record_row_df[BEST_SAVE_SNR] = int(best_save_idx)
            record_row_df[BEST_SNR] = float(best_snr)
            record_row_df[WORST_SAVE_SNR] = int(worst_save_idx)
            record_row_df[WORST_SNR] = float(worst_snr)
            record_df = pd.concat([record_row_df, record_df.loc[:]], ignore_index=True)
            record_df.to_csv(record_path, index=0)

            print(f"Test loss: {test_loss:.3e}, \nbest snr performance: {best_save_idx} with {best_snr:.1f}dB, \nworst snr performance: {worst_save_idx} with {worst_snr:.1f}dB")

    return record_row_df, evaluation_path


def main(argv):
    """CLI entry point: build the argument parser and launch inference for
    every experiment id requested on the command line."""
    device_default = "cuda" if torch.cuda.is_available() else "cpu"
    cuda_note = "\n<<<Cuda available>>>" if device_default == "cuda" else ""
    parser = shared_parser(help="Launch inference on a specific model" + cuda_note)
    parser.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser.add_argument("-d", "--device", type=str, default=device_default,
                        help="Training device", choices=["cpu", "cuda"])
    parser.add_argument("-r", "--reload", action="store_true",
                        help="Force reload files")
    parser.add_argument("-b", "--nb-batch", type=int, default=None,
                        help="Number of batches to process")
    parser.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
                        help="SNR filters on the inference dataloader")
    parser.add_argument("-ext", "--extension", type=str, default=".wav",
                        help="Extension of the audio files to save",
                        choices=[".wav", ".mp4"])
    args = parser.parse_args(argv)
    # One inference run per requested experiment id.
    for experiment in args.experiments:
        launch_infer(
            experiment,
            model_dir=Path(args.input_dir),
            output_dir=Path(args.output_dir),
            device=args.device,
            force_reload=args.reload,
            max_batches=args.nb_batch,
            snr_filter=args.snr_filter,
            ext=args.extension,
        )


# Script entry point: forward the CLI arguments (program name excluded).
if __name__ == "__main__":
    main(sys.argv[1:])

# Example : python src\gyraudio\audio_separation\infer.py -i ./__output_audiosep -e 1002 -d cpu -b 2 -s 4 5 6