"""Run inference of a trained audio separation experiment on its test set.

For each processed test sample the mixed, estimated and original signals are
saved as audio files together with a small JSON file holding the input/output
SNR. A record file (``infer_record.csv``) stored at the root of the output
directory keeps track of the inferences that were already computed so that
they are not recomputed unless explicitly requested.
"""
import sys
from pathlib import Path
from typing import List, Optional, Tuple

import pandas as pd
import torch
import torchaudio
from tqdm import tqdm

from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder
from gyraudio.audio_separation.metrics import snr
from gyraudio.io.dump import Dump

# Files paths
DEFAULT_RECORD_FILE = "infer_record.csv"  # Store the characteristics of the inference record file
DEFAULT_EVALUATION_FILE = "eval_df.csv"  # Store the per-sample evaluation (SNR in / SNR out) of one inference
# Record keys
NBATCH = "nb_batch"
BEST_SNR = "best_snr"
BEST_SAVE_SNR = "best_save_snr"
WORST_SNR = "worst_snr"
WORST_SAVE_SNR = "worst_save_snr"
RECORD_KEYS = [NAME, SHORT_NAME, CURRENT_EPOCH, NBATCH, SNR_FILTER, BEST_SAVE_SNR, BEST_SNR, WORST_SAVE_SNR, WORST_SNR]
# Evaluation keys
SAVE_IDX = "save_idx"
SNR_IN = "snr_in"
SNR_OUT = "snr_out"
EVAL_KEYS = [SAVE_IDX, SNR_IN, SNR_OUT]


def load_file(path: Path, keys: List[str]) -> pd.DataFrame:
    """Load a CSV file as a DataFrame, creating an empty one first if needed.

    Args:
        path: Location of the CSV file.
        keys: Column names used when the file has to be created.

    Returns:
        The content of the CSV file.
    """
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        # index=False: otherwise the (empty) index is written as an unnamed
        # column which comes back as a spurious "Unnamed: 0" column on read.
        pd.DataFrame(columns=keys).to_csv(path, index=False)
    return pd.read_csv(path)


def _existing_record_mask(record_df: pd.DataFrame, conditions: dict, snr_filter: Optional[list]) -> pd.Series:
    """Return a boolean mask of the record rows matching the requested inference."""
    mask = pd.Series(True, index=record_df.index)
    for key, value in conditions.items():
        column = record_df[key]
        # A None value is stored as an empty CSV cell (NaN once reloaded);
        # `column == None` is always False so null values need a dedicated test.
        mask = mask & (column.isnull() if value is None else (column == value))
    if snr_filter is None:
        mask = mask & record_df[SNR_FILTER].isnull()
    else:
        # The filter list is serialized through its string representation in the CSV.
        mask = mask & (record_df[SNR_FILTER] == str(snr_filter))
    return mask.astype(bool)


def _run_inference(model, test_loader, device: str, save_dir: Path, epoch: int,
                   max_batches: Optional[int], ext: str) -> Tuple[list, float, int, float, int, float]:
    """Run the model on the test loader and save audio + JSON for every sample.

    Returns:
        ``(eval_rows, mean_loss, best_idx, best_gain, worst_idx, worst_gain)``
        where ``eval_rows`` is the list of per-sample dictionaries (in saving
        order), ``mean_loss`` the MSE averaged over the processed batches and
        the best/worst values describe the extreme ``snr_out - snr_in`` gains.

    Raises:
        RuntimeError: If the test loader did not yield any batch.
    """
    sample_rate = test_loader.dataset.sampling_rate
    total = max_batches if max_batches is not None else len(test_loader)
    eval_rows = []
    cumulated_loss = 0.
    save_idx = 0
    processed_batches = 0
    # Start from +/- infinity so that the first batch always initializes the extrema.
    best_snr, best_save_idx = float("-inf"), None
    worst_snr, worst_save_idx = float("inf"), None
    with torch.no_grad():
        for batch_mix, batch_signal, batch_noise in tqdm(test_loader, desc=f"Inference epoch {epoch}", total=total):
            batch_mix, batch_signal, batch_noise = batch_mix.to(
                device), batch_signal.to(device), batch_noise.to(device)
            batch_output_signal, _batch_output_noise = model(batch_mix)
            loss = torch.nn.functional.mse_loss(batch_output_signal, batch_signal)
            cumulated_loss += loss.item()

            # SNR stats - presumably one value per batch element with reduce=None
            snr_in = snr(batch_mix, batch_signal, reduce=None)
            snr_out = snr(batch_output_signal, batch_signal, reduce=None)
            snr_gain = snr_out - snr_in
            best_current, best_idx_current = torch.max(snr_gain, dim=0)
            worst_current, worst_idx_current = torch.min(snr_gain, dim=0)
            # save_idx is the index of the first sample of the current batch here.
            if float(best_current) > best_snr:
                best_snr = float(best_current)
                best_save_idx = save_idx + int(best_idx_current)
            if float(worst_current) < worst_snr:
                worst_snr = float(worst_current)
                worst_save_idx = save_idx + int(worst_idx_current)

            # Save by signal
            batch_output_signal = batch_output_signal.detach().cpu()
            batch_signal = batch_signal.detach().cpu()
            batch_mix = batch_mix.detach().cpu()
            named_batches = (("mixed", batch_mix), ("out", batch_output_signal), ("original", batch_signal))
            for audio_idx in range(batch_output_signal.shape[0]):
                dic = {
                    SAVE_IDX: save_idx,
                    SNR_IN: float(snr_in[audio_idx]),
                    SNR_OUT: float(snr_out[audio_idx]),
                }
                eval_rows.append(dic)
                # Save audio files
                for suffix, batch in named_batches:
                    torchaudio.save(
                        str(save_dir/f"{save_idx:04d}_{suffix}{ext}"),
                        batch[audio_idx, :, :],
                        sample_rate=sample_rate,
                        channels_first=True
                    )
                Dump.save_json(dic, save_dir/f"{save_idx:04d}.json")
                save_idx += 1
            processed_batches += 1
            if max_batches is not None and processed_batches >= max_batches:
                break
    if processed_batches == 0:
        raise RuntimeError("The test dataloader did not provide any batch, nothing to infer")
    # Average over the batches which were really processed (may be less than len(test_loader)).
    mean_loss = cumulated_loss/processed_batches
    return eval_rows, mean_loss, best_save_idx, best_snr, worst_save_idx, worst_snr


def launch_infer(exp: int, snr_filter: list = None, device: str = "cuda", model_dir: Path = None,
                 output_dir: Path = EXPERIMENT_STORAGE_ROOT, force_reload=False, max_batches=None,
                 ext=".wav"):
    """Run (or retrieve) the inference of an experiment on its test set.

    Args:
        exp: Experiment identifier, forwarded to ``get_experience``.
        snr_filter: Optional list of SNR values used to filter the test dataloader.
        device: Torch device used for the inference.
        model_dir: Root directory where the trained experiment is stored.
        output_dir: Root directory where inference results and the record file are written.
        force_reload: Recompute the inference even when it is already recorded.
        max_batches: Maximum number of batches to process (None: whole test set).
        ext: Extension of the saved audio files.

    Returns:
        A tuple ``(record_row_df, evaluation_path)``: the record row(s)
        describing this inference and the path to its evaluation CSV file.

    Raises:
        ValueError: If ``output_dir`` is None.
        AssertionError: If the experiment folder does not exist, or if a
            recorded inference has no evaluation file.
        RuntimeError: If the test dataloader is empty.
    """
    if output_dir is None:
        raise ValueError("output_dir is required to store the inference results")
    # Load experience
    if snr_filter is not None:
        snr_filter = sorted(snr_filter)
    short_name, model, config, dl = get_experience(exp, snr_filter_test=snr_filter)
    exists, exp_dir = get_output_folder(config, root_dir=model_dir, override=False)
    assert exp_dir.exists(), f"Experiment {short_name} does not exist in {model_dir}"
    model.eval()
    model.to(device)
    model, optimizer, epoch, config_checkpt = load_checkpoint(model, exp_dir, epoch=None, device=device)

    # Look for a previous identical inference in the record file
    record_path = output_dir/DEFAULT_RECORD_FILE
    record_df = load_file(record_path, RECORD_KEYS)
    exist_conditions = {
        NAME: config[NAME],
        SHORT_NAME: config[SHORT_NAME],
        CURRENT_EPOCH: epoch,
        NBATCH: max_batches,
    }
    combined_mask = _existing_record_mask(record_df, exist_conditions, snr_filter)
    filtered_df = record_df[combined_mask]

    # Folder naming
    nb_batches = max_batches if max_batches is not None else len(dl[TEST])
    save_dir = output_dir/(exp_dir.name+"_infer" + f"_epoch_{epoch:04d}_nbatch_{nb_batches}" +
                           ("" if snr_filter is None else f"_snrs_{'_'.join(map(str, snr_filter))}"))
    evaluation_path = save_dir/DEFAULT_EVALUATION_FILE

    if not filtered_df.empty and not force_reload:
        assert evaluation_path.exists()
        print(f"Inference already exists, see folder {save_dir}")
        return filtered_df, evaluation_path

    record_row_df = pd.DataFrame({
        NAME: config[NAME],
        SHORT_NAME: config[SHORT_NAME],
        CURRENT_EPOCH: epoch,
        NBATCH: max_batches,
        SNR_FILTER: [None],
    }, index=[0], columns=RECORD_KEYS)
    # The list is assigned to a single cell afterwards (it cannot be broadcast at construction).
    record_row_df.at[0, SNR_FILTER] = snr_filter

    save_dir.mkdir(parents=True, exist_ok=True)
    eval_rows, test_loss, best_save_idx, best_snr, worst_save_idx, worst_snr = _run_inference(
        model, dl[TEST], device, save_dir, epoch, max_batches, ext)

    # Saved indexes restart from 0 and audio files are overwritten, so the
    # evaluation table is rebuilt from scratch instead of being appended to a
    # previous one (which would duplicate every save index on a forced reload).
    # Rows are stored most recent first.
    evaluation_df = pd.DataFrame(eval_rows[::-1], columns=EVAL_KEYS)
    evaluation_df.to_csv(evaluation_path)

    record_row_df[BEST_SAVE_SNR] = best_save_idx
    record_row_df[BEST_SNR] = best_snr
    record_row_df[WORST_SAVE_SNR] = worst_save_idx
    record_row_df[WORST_SNR] = worst_snr
    # Replace the outdated record of the same inference (if any) by the new one.
    record_df = pd.concat([record_row_df, record_df[~combined_mask]], ignore_index=True)
    record_df.to_csv(record_path, index=False)

    print(f"Test loss: {test_loss:.3e}, \nbest snr performance: {best_save_idx} with {best_snr:.1f}dB, \nworst snr performance: {worst_save_idx} with {worst_snr:.1f}dB")
    return record_row_df, evaluation_path


def main(argv):
    """Parse the command line arguments and launch inference for each requested experiment."""
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser_def = shared_parser(help="Launch inference on a specific model" +
                               ("\n<<>>" if default_device == "cuda" else ""))
    parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-d", "--device", type=str, default=default_device,
                            help="Training device", choices=["cpu", "cuda"])
    parser_def.add_argument("-r", "--reload", action="store_true",
                            help="Force reload files")
    parser_def.add_argument("-b", "--nb-batch", type=int, default=None,
                            help="Number of batches to process")
    parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
                            help="SNR filters on the inference dataloader")
    parser_def.add_argument("-ext", "--extension", type=str, default=".wav",
                            help="Extension of the audio files to save", choices=[".wav", ".mp4"])
    args = parser_def.parse_args(argv)
    for exp in args.experiments:
        launch_infer(
            exp,
            model_dir=Path(args.input_dir),
            output_dir=Path(args.output_dir),
            device=args.device,
            force_reload=args.reload,
            max_batches=args.nb_batch,
            snr_filter=args.snr_filter,
            ext=args.extension
        )


if __name__ == "__main__":
    main(sys.argv[1:])

# Example : python src\gyraudio\audio_separation\infer.py -i ./__output_audiosep -e 1002 -d cpu -b 2 -s 4 5 6