# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

# The filters are installed before the remaining imports, presumably so that
# warnings emitted while importing the audio / PaSST stack are silenced too.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import contextlib
from functools import partial

from tqdm import tqdm
import pickle
import numpy as np
import librosa
from hear21passt.base import get_basic_model
import pyloudnorm as pyln
import torch
import torch.nn.functional as F

# Sample rate (Hz) every audio file is resampled to before being fed to PaSST.
SAMPLING_RATE = 32000

# A partial (trailing) analysis window must be longer than this fraction of
# window_size to be analysed; shorter ones would be mostly zero padding and get
# predicted as silence (same rule as MusicGen's KLD metric).
MIN_WINDOW_FRACTION = 0.15


class _patch_passt_stft:
    """Context manager that makes PaSST's old-style ``torch.stft`` calls work.

    From torch 1.8.0 on, ``return_complex`` must always be given explicitly for
    real inputs, and ``return_complex=False`` has been deprecated. PaSST calls
    ``torch.stft`` without it, so while this context is active ``torch.stft``
    is replaced by a version with ``return_complex=False`` pre-bound. The
    original function is restored on exit.

    Note: this swaps the global ``torch.stft`` attribute, so every caller in
    the process sees the patched version while the context is active.

    Adapted from:
    https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py
    """

    def __init__(self):
        self.old_stft = torch.stft

    def __enter__(self):
        # return_complex is a mandatory parameter in recent torch versions;
        # torch throws a RuntimeError when it is not set.
        # see: https://pytorch.org/docs/1.7.1/generated/torch.stft.html?highlight=stft#torch.stft
        # see: https://github.com/kkoutini/passt_hear21/commit/dce83183674e559162b49924d666c0a916dc967a
        torch.stft = partial(torch.stft, return_complex=False)

    def __exit__(self, *exc):
        torch.stft = self.old_stft


def _check_collect(collect):
    """Raise ValueError unless ``collect`` is a supported pooling mode."""
    if collect not in ('mean', 'max'):
        raise ValueError(f"collect must be 'mean' or 'max', got {collect!r}")


def _passt_logits(model, window, device):
    """Run PaSST on one mono analysis window and return its squeezed logits."""
    audio_wave = torch.from_numpy(window.astype(np.float32)).unsqueeze(0).to(device)
    # silence anything the model prints to stdout during inference
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        with torch.no_grad(), _patch_passt_stft():
            logits = model(audio_wave)
    return torch.squeeze(logits)


def return_probabilities(model, audio_path, window_size=10, overlap=5, collect='mean'):
    """
    Given an audio and the PaSST model, return the probabilities of each AudioSet class.

    Audio is converted to mono at 32kHz and peak-normalized to -1.0 dB (all-zero audio is left
    untouched, since peak normalization is undefined for it).

    PaSST model is trained with 10 sec inputs. We refer to this parameter as the window_size.
    We set it to 10 sec for consistency with PaSST training.
    For longer audios, we split audio into overlapping analysis windows of window_size and overlap
    (10 and 5 seconds by default). Partial windows longer than 15% of window_size are zero-padded
    to window_size; shorter partial windows are discarded, as in MusicGen, unless the whole audio is
    that short, in which case it is analysed unpadded so that a prediction is still produced.
    PaSST supports 10, 20 or 30 sec inputs. Not longer inputs: https://github.com/kkoutini/PaSST/issues/19

    Note that AudioSet taggers normally use sigmoid output layers. Yet, to compute the KL we work
    with normalized probabilities by running a softmax over logits as in MusicGen:
    https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py

    Inference runs on the device holding the model's parameters (normally a GPU).

    Params:
    -- model: PaSST model (normally on a GPU).
    -- audio_path: path to the audio to be loaded with librosa.
    -- window_size (default=10 sec): analysis window (and receptive field) of PaSST.
    -- overlap (default=5 sec): overlap of the running analysis window for inputs longer than
       window_size (10 sec). Must be smaller than window_size.
    -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling
       along the logits vector.
    Returns:
    -- 527 probabilities (after softmax, no logarithm).
    Raises:
    -- ValueError: if overlap >= window_size, collect is not 'mean'/'max', or the audio is empty.
    """
    _check_collect(collect)
    window_samples = int(window_size * SAMPLING_RATE)
    step_size = int((window_size - overlap) * SAMPLING_RATE)
    if step_size <= 0:
        raise ValueError(f'overlap ({overlap}) must be smaller than window_size ({window_size})')
    min_window_samples = int(window_size * SAMPLING_RATE * MIN_WINDOW_FRACTION)

    # load the audio as mono at the PaSST sample rate
    audio, _ = librosa.load(audio_path, sr=SAMPLING_RATE, mono=True)
    if len(audio) == 0:
        raise ValueError(f'{audio_path} contains no audio samples')
    # pyln.normalize.peak divides by the current peak: on all-zero (silent) audio that yields
    # NaNs, which would turn the averaged KL into NaN, so silent audio is left as is.
    if np.max(np.abs(audio)) > 0:
        audio = pyln.normalize.peak(audio, -1.0)

    # run inference wherever the model lives (normally the GPU)
    device = next(model.parameters()).device

    # Iterate over the audio, creating overlapping analysis windows. The max() guarantees at
    # least one window (starting at 0) for audio shorter than two steps.
    window_logits = []
    for start in range(0, max(step_size, len(audio) - step_size), step_size):
        window = audio[start:start + window_samples]
        if len(window) < window_samples:
            if len(window) > min_window_samples:
                # pad the window with zeros up to the desired window size
                window = np.pad(window, (0, window_samples - len(window)))
            elif window_logits:
                # discard window if it's too small (avoid mostly zeros predicted as silence), as in MusicGen:
                # https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py
                continue
            # else: the whole audio is shorter than the threshold; analyse it unpadded so that
            # it still yields a prediction
        window_logits.append(_passt_logits(model, window, device))

    # aggregate the logits across windows, then normalize them into probabilities
    stacked = torch.stack(window_logits)
    if collect == 'mean':
        logits = torch.mean(stacked, dim=0)
    else:  # collect == 'max'
        logits, _ = torch.max(stacked, dim=0)
    return F.softmax(logits, dim=0).squeeze().cpu()


def _load_id2utt(list_path):
    """Parse a text file with one '<id> <audio_path>' entry per line into an {id: path} dict.

    Blank lines are skipped; everything after the first run of whitespace is the audio path,
    so paths may contain spaces.
    """
    id2utt = {}
    with open(list_path, 'r') as fp:
        for line in fp:
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue
            if len(fields) < 2:
                raise ValueError(f'malformed line in {list_path}: {line!r}')
            id2utt[fields[0]] = fields[1]
    return id2utt


def _index_audio_source(path, missing_message):
    """Validate an audio source: a directory (returns None) or an id list file (returns its dict)."""
    if os.path.isdir(path):
        return None
    if os.path.isfile(path):
        return _load_id2utt(path)
    raise ValueError(missing_message)


def _resolve_audio_path(root, id_, extension, id2utt):
    """Return the audio file of ``id_``: looked up in ``id2utt`` (list-file mode) or <root>/<id_><extension>."""
    if id2utt is not None:
        try:
            return id2utt[str(id_)]
        except KeyError:
            raise ValueError(f"id: {id_} not in {root}!") from None
    return os.path.join(root, str(id_) + extension)


def _load_passt_model():
    """Load the PaSST tagger in logits mode, switch it to eval mode and move it to the GPU."""
    # capturing all useless outputs from passt
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        model = get_basic_model(mode="logits")
        model.eval()
        model = model.cuda()
    return model


def _extract_reference_probabilities(model, ids, ref_path, ref_id2utt, ref_files_extension, no_ids, collect):
    """Compute PaSST probabilities for every reference id, pickle them under load/passt_kld/ and return them."""
    print('[EXTRACTING REFERENCE PROBABILITIES] ', ref_path)
    ref_p = {}
    for id_ in tqdm(ids):
        if id_ in no_ids:
            continue
        try:
            audio_path = _resolve_audio_path(ref_path, id_, ref_files_extension, ref_id2utt)
            if os.path.isfile(audio_path):
                ref_p[id_] = return_probabilities(model, audio_path, collect=collect)
        except Exception as e:
            # best effort: report the failing id and keep going
            print(f"An unexpected error occurred with {id_}: {e}\nIf you failed to download it you can add it to no_ids list.")

    # store reference probabilities so they can be reloaded via load_ref_probabilities
    save_dir = 'load/passt_kld/'
    os.makedirs(save_dir, exist_ok=True)
    save_ref_probabilities_path = (save_dir + str(ref_path).replace('/', '_') + '_collect' + str(collect)
                                   + '__reference_probabilities.pkl')
    with open(save_ref_probabilities_path, 'wb') as fp:
        pickle.dump(ref_p, fp)
    print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_probabilities_path)
    return ref_p


def passt_kld(ids, eval_path, eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav',
              load_ref_probabilities=None, no_ids=None, collect='mean'):
    """
    Compute the KL-divergence between the label probabilities of the generated audio and those of
    the original (reference) audio. Both generated audio (in eval_path) and original audio (in
    ref_path) are represented by the same prompt/description. Audios are identified by an id, that
    is the name of the file in both directories and links the audio with the prompt/description.

    eval_path and ref_path may each be either a directory holding '<id><extension>' files, or a
    text file with one '<id> <audio_path>' entry per line.

    For inputs longer than the 10 sec PaSST was trained on, we split the input into overlapping
    analysis windows. Subsequently, we aggregate/collect (across windows) the generated logits via
    'mean' (default) or 'max' pooling and then apply a softmax.

    The divergence computed per id is KL(eval || ref) = sum(p_eval * log(p_eval / (p_ref + 1e-6))),
    averaged over all evaluated ids.

    This evaluation script assumes that ids are in both ref_path and eval_path.

    We label probabilities via the PaSST model: https://github.com/kkoutini/PaSST

    GPU-based computation. Extracting the probabilities is time-consuming. After being computed
    once, we store them: reference probabilities are saved under load/passt_kld/.
    To load those and save computation, just set the path in load_ref_probabilities.
    If load_ref_probabilities is set, ref_path is not required.

    Params:
    -- ids: iterable of ids present in both eval_path and ref_path.
    -- eval_path: directory (or '<id> <audio_path>' list file) with the generated audio files to evaluate.
    -- eval_files_extension: files extension (default .wav) in eval_path (directory mode only).
    -- ref_path: directory (or '<id> <audio_path>' list file) with the reference audio files.
       (instead of load_ref_probabilities)
    -- ref_files_extension: files extension (default .wav) in ref_path (directory mode only).
    -- load_ref_probabilities: path to pickled reference probabilities. (instead of ref_path)
       Only load files you trust: unpickling can execute arbitrary code.
    -- no_ids: ids to skip, e.g. because the reference audio is corrupted or missing (default: none).
    -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling
       along the logits vector.
    Returns:
    -- mean KL divergence over the evaluated ids (a 0-dim tensor), or 0 if no id could be evaluated.
    Raises:
    -- ValueError: if a required path does not exist, neither ref_path nor load_ref_probabilities
       is given, or collect is not 'mean'/'max'.
    """
    _check_collect(collect)
    ids = list(ids)  # iterated twice (reference extraction and evaluation)
    no_ids = set() if no_ids is None else set(no_ids)

    # validate every input before paying for model loading
    eval_id2utt = _index_audio_source(eval_path, 'eval_path does not exist')
    if load_ref_probabilities:
        if not os.path.exists(load_ref_probabilities):
            raise ValueError('load_ref_probabilities does not exist')
        ref_id2utt = None
    elif ref_path:
        ref_id2utt = _index_audio_source(ref_path, "ref_path does not exist")
    else:
        raise ValueError('Must specify ref_path or load_ref_probabilities')

    model = _load_passt_model()

    if load_ref_probabilities:
        print('[LOADING REFERENCE PROBABILITIES] ', load_ref_probabilities)
        # NOTE: pickle.load can execute arbitrary code; only load files you created or trust.
        with open(load_ref_probabilities, 'rb') as fp:
            ref_p = pickle.load(fp)
    else:
        ref_p = _extract_reference_probabilities(
            model, ids, ref_path, ref_id2utt, ref_files_extension, no_ids, collect)

    print('[EVALUATING GENERATIONS] ', eval_path)
    passt_kl = 0
    count = 0
    for id_ in tqdm(ids):
        if id_ in no_ids:
            continue
        try:
            audio_path = _resolve_audio_path(eval_path, id_, eval_files_extension, eval_id2utt)
            if os.path.isfile(audio_path):
                # look the reference up first so a missing one does not waste a PaSST pass
                ref_probs = ref_p[id_]
                eval_p = return_probabilities(model, audio_path, collect=collect)
                # F.kl_div(input, target) takes input as log-probabilities and computes
                # sum(target * (log(target) - input)), i.e. KL(target || exp(input)).
                # With target = eval_p this is KL(eval || ref).
                # see: https://github.com/pytorch/pytorch/issues/7337
                # see: https://discuss.pytorch.org/t/kl-divergence-different-results-from-tf/56903/2
                passt_kl += F.kl_div((ref_probs + 1e-6).log(), eval_p, reduction='sum', log_target=False)
                count += 1
        except Exception as e:
            # best effort: report the failing id and keep going
            print(f"An unexpected error occurred with {id_}: {e}\nIf you failed to download it you can add it to no_ids list.")

    if count == 0:
        print('[WARNING] no generation could be evaluated; returning 0')
        return 0
    return passt_kl / count