mrfakename
committed on
Commit
•
648ac03
1
Parent(s):
5dc7366
Sync from GitHub repo
Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo rather than to the Space.
- scripts/prepare_csv_wavs.py +132 -0
scripts/prepare_csv_wavs.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
sys.path.append(os.getcwd())
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
import json
|
6 |
+
import shutil
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import csv
|
10 |
+
import torchaudio
|
11 |
+
from tqdm import tqdm
|
12 |
+
from datasets.arrow_writer import ArrowWriter
|
13 |
+
|
14 |
+
from model.utils import (
|
15 |
+
convert_char_to_pinyin,
|
16 |
+
)
|
17 |
+
|
18 |
+
# Vocab file shipped with the pretrained Emilia_ZH_EN (pinyin) checkpoint;
# copied into fine-tune output dirs so token ids stay aligned with the checkpoint.
PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
19 |
+
|
20 |
+
def is_csv_wavs_format(input_dataset_dir):
    """Return True if *input_dataset_dir* follows the csv_wavs layout.

    The expected layout is a directory containing a ``metadata.csv`` file
    and a ``wavs`` subdirectory.
    """
    root = Path(input_dataset_dir)
    has_metadata = (root / "metadata.csv").is_file()
    has_wav_dir = (root / "wavs").is_dir()
    return has_metadata and has_wav_dir
|
25 |
+
|
26 |
+
|
27 |
+
def prepare_csv_wavs_dir(input_dir):
    """Collect sample records from a csv_wavs-format dataset directory.

    Reads ``metadata.csv``, converts each text to pinyin tokens, and
    measures each referenced audio file. Missing audio files are skipped
    with a console notice.

    Returns:
        (sub_result, durations, vocab_set): list of
        ``{"audio_path", "text", "duration"}`` dicts, the parallel list of
        durations in seconds, and the set of vocabulary tokens seen.
    """
    assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
    dataset_root = Path(input_dir)
    pairs = read_audio_text_pairs((dataset_root / "metadata.csv").as_posix())

    sub_result, durations = [], []
    vocab_set = set()
    for audio_path, raw_text in pairs:
        if not Path(audio_path).exists():
            print(f"audio {audio_path} not found, skipping")
            continue
        seconds = get_audio_duration(audio_path)
        # assume tokenizer = "pinyin" ("pinyin" | "char")
        tokens = convert_char_to_pinyin([raw_text], polyphone=True)[0]
        sub_result.append({"audio_path": audio_path, "text": tokens, "duration": seconds})
        durations.append(seconds)
        vocab_set.update(list(tokens))

    return sub_result, durations, vocab_set
|
48 |
+
|
49 |
+
def get_audio_duration(audio_path):
    """Return the duration of *audio_path* in seconds.

    ``torchaudio.load`` returns a waveform tensor shaped
    ``(num_channels, num_frames)`` plus the sample rate, so the duration is
    ``num_frames / sample_rate`` regardless of channel count.

    Bug fixed: the previous formula divided by ``num_channels`` as well,
    which under-reported the duration of any multi-channel (e.g. stereo)
    file by a factor of the channel count. Mono files were unaffected.
    """
    audio, sample_rate = torchaudio.load(audio_path)
    return audio.shape[1] / sample_rate
|
53 |
+
|
54 |
+
def read_audio_text_pairs(csv_file_path):
    """Parse a pipe-delimited metadata CSV into (audio_path, text) tuples.

    The first row is a header and is skipped. Column 0 is an audio path
    relative to the CSV's directory (resolved here), column 1 is the text;
    rows with fewer than two columns are ignored.
    """
    base_dir = Path(csv_file_path).parent
    pairs = []
    with open(csv_file_path, mode='r', newline='', encoding='utf-8') as fh:
        rows = csv.reader(fh, delimiter='|')
        next(rows)  # header row carries no sample
        for record in rows:
            if len(record) < 2:
                continue
            wav_rel = record[0].strip()
            text = record[1].strip()
            pairs.append(((base_dir / wav_rel).as_posix(), text))
    return pairs
|
69 |
+
|
70 |
+
|
71 |
+
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Persist a prepared dataset to *out_dir*.

    Writes three artifacts:
      - ``raw.arrow``: one record per sample, streamed via ArrowWriter.
      - ``duration.json``: the duration list alone, duplicated for easy
        DynamicBatchSampler consumption.
      - ``vocab.txt``: copied from the pretrained vocab when fine-tuning
        (so token ids line up with the checkpoint), otherwise built from
        *text_vocab_set*, one token per line, sorted.

    Bug fixed: vocab.txt used to be written unconditionally and then
    immediately overwritten (by the copy for fine-tune, or by an identical
    second write for pretrain); it is now written exactly once. The vocab
    write also gained ``encoding="utf-8"`` so non-ASCII tokens (pinyin/CJK)
    don't fail on platforms with a non-UTF-8 default encoding.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    # Dataset.from_dict(...).save_to_disk(...) OOMs on large sets, so stream
    # records straight into an Arrow file instead.
    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    voca_out_path = out_dir / "vocab.txt"
    if is_finetune:
        shutil.copy2(PRETRAINED_VOCAB_PATH.as_posix(), voca_out_path)
    else:
        with open(voca_out_path.as_posix(), "w", encoding="utf-8") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
110 |
+
|
111 |
+
|
112 |
+
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
    """Prepare a csv_wavs dataset from *inp_dir* and write it to *out_dir*.

    When fine-tuning, the pretrained vocab file must exist so it can be
    copied into the output directory.
    """
    if is_finetune:
        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
    records, seconds, vocab = prepare_csv_wavs_dir(inp_dir)
    save_prepped_dataset(out_dir, records, seconds, vocab, is_finetune)
|
117 |
+
|
118 |
+
|
119 |
+
def cli():
    """Command-line entry point.

    Usage:
        finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
        pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
    """
    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
    parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
    parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
    parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
    args = parser.parse_args()

    # --pretrain means "not a fine-tune": build a fresh vocab instead of
    # copying the pretrained one.
    prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)


if __name__ == "__main__":
    cli()
|