# prepare_csv_wavs.py — prepare a metadata.csv + wavs/ dataset for F5-TTS training.
import os | |
import sys | |
import signal | |
import subprocess # For invoking ffprobe | |
import shutil | |
import concurrent.futures | |
import multiprocessing | |
from contextlib import contextmanager | |
sys.path.append(os.getcwd()) | |
import argparse | |
import csv | |
import json | |
from importlib.resources import files | |
from pathlib import Path | |
import torchaudio | |
from tqdm import tqdm | |
from datasets.arrow_writer import ArrowWriter | |
from f5_tts.model.utils import ( | |
convert_char_to_pinyin, | |
) | |
# Pinyin vocab shipped with the pretrained checkpoint; copied verbatim into the
# output dir when fine-tuning (see save_prepped_dataset).
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
def is_csv_wavs_format(input_dataset_dir):
    """Return True iff *input_dataset_dir* contains a metadata.csv file and a wavs/ directory."""
    root = Path(input_dataset_dir)
    has_metadata = (root / "metadata.csv").exists() and (root / "metadata.csv").is_file()
    has_wavs = (root / "wavs").exists() and (root / "wavs").is_dir()
    return has_metadata and has_wavs
# Configuration constants
BATCH_SIZE = 100  # Batch size for text conversion (texts per convert_char_to_pinyin call)
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free for the main thread
THREAD_NAME_PREFIX = "AudioProcessor"  # Worker thread name prefix, visible in debuggers/profilers
CHUNK_SIZE = 100  # Number of files to process per worker batch
executor = None  # Global ThreadPoolExecutor handle so signal handlers can shut it down
@contextmanager
def graceful_exit():
    """Context manager for graceful shutdown on SIGINT/SIGTERM.

    Installs signal handlers that shut down the global `executor` (if any)
    and exit with status 1; on normal exit, best-effort shuts the executor
    down without waiting.

    BUG FIX: this generator is used as `with graceful_exit():` but was
    missing the @contextmanager decorator (imported at the top of the file),
    so entering the `with` block raised a TypeError.
    """

    def signal_handler(signum, frame):
        print("\nReceived signal to terminate. Cleaning up...")
        if executor is not None:
            print("Shutting down executor...")
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)

    # Intercept Ctrl-C and SIGTERM so worker threads are torn down promptly.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    try:
        yield
    finally:
        # Cleanup on normal exit or exception propagating out of the body.
        if executor is not None:
            executor.shutdown(wait=False)
def process_audio_file(audio_path, text, polyphone):
    """Validate one audio file and return (audio_path, text, duration_seconds).

    Returns None (and prints a warning) when the file is missing, its
    duration cannot be determined, or the duration is non-positive.
    The `polyphone` flag is accepted for call-site compatibility; text
    conversion itself happens later in batch.
    """
    if not Path(audio_path).exists():
        print(f"audio {audio_path} not found, skipping")
        return None
    try:
        duration = get_audio_duration(audio_path)
        if duration <= 0:
            raise ValueError(f"Duration {duration} is non-positive.")
    except Exception as err:
        print(f"Warning: Failed to process {audio_path} due to error: {err}. Skipping corrupt file.")
        return None
    return (audio_path, text, duration)
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
    """Convert *texts* to pinyin, `batch_size` entries at a time, preserving order."""
    converted = []
    start = 0
    total = len(texts)
    while start < total:
        slice_ = texts[start : start + batch_size]
        converted.extend(convert_char_to_pinyin(slice_, polyphone=polyphone))
        start += batch_size
    return converted
def prepare_csv_wavs_dir(input_dir, num_workers=None):
    """Read metadata.csv + wavs/ under *input_dir* and build per-sample records.

    Args:
        input_dir: dataset directory in csv_wavs format (see is_csv_wavs_format).
        num_workers: explicit thread count; defaults to min(MAX_WORKERS, #files).

    Returns:
        (sub_result, durations, vocab_set):
        - sub_result: list of {"audio_path", "text" (pinyin), "duration"} dicts
        - durations: list of durations in seconds, same order
        - vocab_set: set of all symbols appearing in the converted texts

    Raises:
        RuntimeError: if no audio file could be processed successfully.
    """
    global executor
    assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
    input_dir = Path(input_dir)
    metadata_path = input_dir / "metadata.csv"
    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())

    polyphone = True
    total_files = len(audio_path_text_pairs)

    # BUG FIX: clamp to >= 1 — min(MAX_WORKERS, 0) for an empty dataset would
    # pass max_workers=0 to ThreadPoolExecutor, which raises ValueError.
    if num_workers is not None:
        worker_count = num_workers
    else:
        worker_count = max(1, min(MAX_WORKERS, total_files))
    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")

    with graceful_exit():
        # Threads, not processes: the work is I/O-bound (ffprobe subprocesses
        # and file reads), so the GIL is not a bottleneck here.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
        ) as pool:  # renamed from `exec`, which shadowed the builtin
            executor = pool  # expose to the signal handler for cleanup
            results = []
            num_chunks = (total_files + CHUNK_SIZE - 1) // CHUNK_SIZE
            # Process files in chunks for better efficiency
            for i in range(0, total_files, CHUNK_SIZE):
                chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
                # Submit futures in order
                chunk_futures = [pool.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
                # Iterate futures in submission order so results stay aligned
                # with the metadata order.
                for future in tqdm(
                    chunk_futures,
                    total=len(chunk),
                    desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{num_chunks}",
                ):
                    try:
                        result = future.result()
                        if result is not None:
                            results.append(result)
                    except Exception as e:
                        print(f"Error processing file: {e}")
            executor = None

    # Failures already returned None and were dropped above, so `results`
    # holds only valid (audio_path, text, duration) tuples — no re-filter needed.
    processed = results
    if not processed:
        raise RuntimeError("No valid audio files were processed!")

    # Batch process text conversion
    raw_texts = [item[1] for item in processed]
    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)

    # Prepare final results
    sub_result = []
    durations = []
    vocab_set = set()
    for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
        sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(conv_text))

    return sub_result, durations, vocab_set
def get_audio_duration(audio_path, timeout=5):
    """Return the duration of *audio_path* in seconds.

    Tries ffmpeg's ffprobe first (fast, no full decode); falls back to
    torchaudio.load() when ffprobe is missing, fails, or times out.

    Raises:
        RuntimeError: if both ffprobe and torchaudio fail.
    """
    try:
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            audio_path,
        ]
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
        )
        duration_str = result.stdout.strip()
        if duration_str:
            return float(duration_str)
        raise ValueError("Empty duration string from ffprobe.")
    # BUG FIX: OSError added — if the ffprobe binary is not installed,
    # subprocess.run raises FileNotFoundError (an OSError), which the old
    # tuple did not catch, making the documented torchaudio fallback
    # unreachable despite the CLI only warning about a missing ffprobe.
    except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError, ValueError) as e:
        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
        try:
            audio, sample_rate = torchaudio.load(audio_path)
            # Tensor is (channels, frames): frames / sample_rate = seconds.
            return audio.shape[1] / sample_rate
        except Exception as exc:
            raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {exc}") from exc
def read_audio_text_pairs(csv_file_path):
    """Parse a pipe-delimited metadata CSV into (posix_audio_path, text) pairs.

    The first row is treated as a header and skipped. Rows with fewer than
    two columns are ignored. Audio paths are resolved relative to the CSV's
    parent directory.
    """
    audio_text_pairs = []
    parent = Path(csv_file_path).parent
    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        # BUG FIX: bare next(reader) raised StopIteration on an empty file;
        # the default lets an empty metadata file yield an empty list instead.
        next(reader, None)  # skip the header row
        for row in reader:
            if len(row) >= 2:
                audio_file = row[0].strip()  # first column: audio file path
                text = row[1].strip()  # second column: text
                audio_file_path = parent / audio_file
                audio_text_pairs.append((audio_file_path.as_posix(), text))
    return audio_text_pairs
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Persist the prepared dataset to *out_dir*.

    Writes raw.arrow (sample records), duration.json (durations in seconds),
    and vocab.txt (copied from the pretrained vocab when fine-tuning,
    otherwise built from *text_vocab_set*). Prints a short summary.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    # Save dataset with improved batch size for better I/O performance
    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # Save durations to JSON
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # Handle vocab file - write only once based on finetune flag
    voca_out_path = out_dir / "vocab.txt"
    if is_finetune:
        file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
        shutil.copy2(file_vocab_finetune, voca_out_path)
    else:
        # BUG FIX: explicit utf-8 — the vocab contains pinyin/CJK symbols, and
        # the platform default encoding (e.g. cp1252 on Windows) would raise
        # UnicodeEncodeError here. The duration.json writer above already
        # uses utf-8; this makes the two consistent.
        with open(voca_out_path.as_posix(), "w", encoding="utf-8") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
    """Prepare the csv_wavs dataset in *inp_dir* and write all artifacts to *out_dir*."""
    if is_finetune:
        # Fine-tuning copies the pretrained vocab, so it must exist up front.
        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
    prepared = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
    sub_result, durations, vocab_set = prepared
    save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
def cli():
    """Command-line entry point: parse arguments and run dataset preparation.

    Exits with status 1 on Ctrl-C, shutting down any live executor first.
    """
    try:
        # ffprobe is optional: get_audio_duration() falls back to torchaudio,
        # so only warn rather than abort.
        if shutil.which("ffprobe") is None:
            print(
                "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
            )

        # Usage examples in help text
        parser = argparse.ArgumentParser(
            description="Prepare and save dataset.",
            epilog="""
            Examples:
                # For fine-tuning (default):
                python prepare_csv_wavs.py /input/dataset/path /output/dataset/path

                # For pre-training:
                python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain

                # With custom worker count:
                python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
            """,
        )
        parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
        parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
        parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
        parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
        parsed = parser.parse_args()

        prepare_and_save_set(
            parsed.inp_dir,
            parsed.out_dir,
            is_finetune=not parsed.pretrain,
            num_workers=parsed.workers,
        )
    except KeyboardInterrupt:
        print("\nOperation cancelled by user. Cleaning up...")
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    cli()