|
from multiprocessing import Pool, cpu_count |
|
import os |
|
import torchaudio |
|
import torch |
|
from tqdm import tqdm |
|
|
|
def process_file(args):
    """
    Processes a single audio file to check for NaN values and get its length.

    Args:
        args (tuple): A tuple containing the file path and the root directory.
            The root directory is accepted so callers can keep passing it,
            but it is not used here.

    Returns:
        str or None: A tab-separated, newline-terminated string with the
            ABSOLUTE path and the number of samples (the size of the
            waveform's second dimension), or None if the file is empty,
            contains NaN, or causes an error.
    """
    file_path, _root_dir = args

    # Samples scanned per NaN check. Large enough to keep the Python-level
    # loop short, while the temporary boolean mask stays limited to one chunk.
    nan_scan_chunk = 1 << 20

    try:
        abs_path = os.path.abspath(file_path)
        waveform, _sample_rate = torchaudio.load(file_path)

        if waveform.numel() == 0:
            return None

        # Scan the flattened waveform chunk by chunk; any() stops at the
        # first chunk that contains a NaN.
        chunks = waveform.reshape(-1).split(nan_scan_chunk)
        if any(torch.isnan(chunk).any() for chunk in chunks):
            print(f"NaN found in: {abs_path}")
            return None

        nsample = waveform.shape[1]
        return f"{abs_path}\t{nsample}\n"

    except Exception as e:
        # Deliberate best-effort: one unreadable file must not abort the run.
        print(f"Error processing {file_path}: {e}")
        return None
|
|
|
def list_audio_files(root_dir, output_file, exclude_dirs=None):
    """
    Lists audio files in a directory, processes them in parallel to get their
    lengths, and writes the results to a file with ABSOLUTE paths.

    Files are matched case-insensitively by extension (.wav, .flac, .mp3),
    sorted by path, and passed to ``process_file`` in a worker pool that uses
    half of the reported CPUs (at least one). Only non-empty results are
    written, in the sorted order.

    Args:
        root_dir (str): The root directory to search for audio files.
        output_file (str): The path to the output file; it is overwritten.
        exclude_dirs (list, optional): Directories whose subtrees are skipped.
            Empty entries are ignored. Defaults to None.
    """
    # Drop empty entries before resolving: os.path.abspath('') is the current
    # working directory, which would otherwise be excluded by accident.
    excluded = {os.path.abspath(d) for d in (exclude_dirs or []) if d}

    audio_extensions = ('.wav', '.flac', '.mp3')
    audio_files = []
    print("Finding audio files...")

    for root, dirs, files in os.walk(root_dir, topdown=True):
        # Prune in place so os.walk never descends into excluded directories.
        dirs[:] = [
            d for d in dirs
            if os.path.abspath(os.path.join(root, d)) not in excluded
        ]
        audio_files.extend(
            (os.path.join(root, filename), root_dir)
            for filename in files
            if filename.lower().endswith(audio_extensions)
        )

    # Sort by path so the output order does not depend on directory walk order.
    audio_files.sort(key=lambda x: x[0])
    print(f"Found {len(audio_files)} audio files to process.")

    num_processes = max(1, cpu_count() // 2)
    print(f"Starting processing with {num_processes} processes...")

    with Pool(processes=num_processes) as pool:
        # imap yields results in input order, so the sorted order is kept.
        results = list(tqdm(pool.imap(process_file, audio_files),
                            total=len(audio_files),
                            desc="Processing audio files"))

    print(f"Writing results to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as out:
        # process_file returns None for skipped files; write only real rows.
        out.writelines(result for result in results if result)

    print("Processing complete.")
|
|
|
|
|
# Run configuration: adjust these paths for the local machine.
root_directory = '/home/ubuntu/respair/test_wav'
output_tsv = '/home/ubuntu/X-Codec-2.0/audio_high_quality_TEST.txt'
# No directories are excluded. An empty list is used rather than [''] because
# an empty string resolves to the current working directory through
# os.path.abspath and would silently exclude it from the scan.
exclude_folders = []

# Guard the entry point: multiprocessing start methods other than fork import
# the main module in every worker, which would otherwise re-run this call.
if __name__ == "__main__":
    list_audio_files(root_directory, output_tsv, exclude_dirs=exclude_folders)