# asr-model/echopipeline.py
import pyworld as pw
import os
import math
import logging
import torch
import torchaudio
import torch.nn.functional as F
import numpy as np
from typing import Optional, Dict, Union, List, Tuple, Any
from functools import partial
from datetime import datetime
from datasets import load_dataset, Audio, concatenate_datasets
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
import evaluate
from dataclasses import dataclass
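# Module-level handles populated at runtime. Residual, MultiheadA, and Echo are
# expected to come from the companion model module; they are declared here so
# this pipeline file can be imported on its own before the model is wired in.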
extractor = None
tokenizer = None
optimizer = None
scheduler = None
model = None
Residual = None
MultiheadA = None
Echo = None
metric = evaluate.load(path="wer")
@dataclass
class Dimensions:
vocab: int
text_ctx: int
text_dims: int
text_head: int
text_idx: int
mels: int
aud_ctx: int
aud_dims: int
aud_head: int
aud_idx: int
act: str
debug: List[str]
cross_attn: bool
features: List[str]
f0_rotary: bool
def align_f0(f0, ctx):
    """Resample each f0 track in the batch to `ctx` frames by nearest-neighbor indexing."""
    bat, length = f0.shape
    if length == ctx:
        return f0
    stride = length / ctx
    idx = (torch.arange(ctx, device=f0.device) * stride).long()
    batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
    return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
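# Collates variable-length examples into a batch: each acoustic stream
# (spectrogram, waveform, pitch, envelope, phase) is zero-padded to the batch
# maximum along time, labels become a BOS-shifted decoder input plus an
# EOS-terminated target, and f0 is additionally resampled to the token length
# (stored as "f0d") so it can condition the decoder.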
@dataclass
class DataCollator:
tokenizer: Any
    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        pad_token_id = getattr(self.tokenizer, "pad_token_id", 0)
        bos_token_id = getattr(self.tokenizer, "bos_token_id", 1)
        eos_token_id = getattr(self.tokenizer, "eos_token_id", 2)
batch = {}
if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
spectrogram_list = [f["spectrogram"] for f in features]
max_len_feat = max(f.shape[-1] for f in spectrogram_list)
pad_spectrogram = []
for feat in spectrogram_list:
current_len = feat.shape[-1]
padding = max_len_feat - current_len
if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=0.0)  # zero-pad: these are features, not token ids
else:
pad_feat = feat
pad_spectrogram.append(pad_feat)
batch["spectrogram"] = torch.stack(pad_spectrogram)
if "waveform" in features[0] and features[0]["waveform"] is not None:
waveform_list = [f["waveform"] for f in features]
max_len_wav = max(w.shape[-1] for w in waveform_list)
pad_waveforms = []
for wav in waveform_list:
current_len = wav.shape[-1]
padding = max_len_wav - current_len
if padding > 0:
if wav.ndim == 1:
wav = wav.unsqueeze(0)
                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=0.0)
else:
pad_wav = wav
pad_waveforms.append(pad_wav)
batch["waveform"] = torch.stack(pad_waveforms)
if "label" in features[0] and features[0]["label"] is not None:
labels_list = [f["label"] for f in features]
max_len = max(len(l) for l in labels_list)
all_ids = []
all_labels = []
for label in labels_list:
label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                decoder_input = [bos_token_id] + label_list
                label_eos = label_list + [eos_token_id]  # terminate the target with EOS so the model learns to stop
input_len = max_len + 1 - len(decoder_input)
label_len = max_len + 1 - len(label_eos)
padded_input = decoder_input + [pad_token_id] * input_len
padded_labels = label_eos + [pad_token_id] * label_len
all_ids.append(padded_input)
all_labels.append(padded_labels)
batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
if "pitch" in features[0] and features[0]["pitch"] is not None:
pitch_list = [f["pitch"] for f in features]
max_len_pitch = max(e.shape[-1] for e in pitch_list)
pad_pitch = []
for pitch in pitch_list:
current_len = pitch.shape[-1]
padding = max_len_pitch - current_len
if padding > 0:
                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=0.0)  # 0 Hz marks unvoiced/padded frames
else:
pad_pitch_item = pitch
pad_pitch.append(pad_pitch_item)
batch["pitch"] = torch.stack(pad_pitch)
if "f0" in features[0] and features[0]["f0"] is not None:
input_ids_batch = batch.get("input_ids", None)
if input_ids_batch is not None:
target_length = input_ids_batch.shape[-1]
aligned_list = []
original_list = []
for feature in features:
f0 = feature["f0"]
original_list.append(f0)
if f0.shape[-1] != target_length:
aligned_f0 = align_f0(f0.unsqueeze(0), target_length).squeeze(0)
else:
aligned_f0 = f0
aligned_list.append(aligned_f0)
batch["f0d"] = torch.stack(aligned_list)
batch["f0"] = torch.stack(original_list)
if "envelope" in features[0] and features[0]["envelope"] is not None:
env_list = [f["envelope"] for f in features]
max_len = max(f.shape[-1] for f in env_list)
pad_env = []
for feat in env_list:
current_len = feat.shape[-1]
padding = max_len_feat - current_len
if padding > 0:
pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
else:
pad_feat = feat
pad_env.append(pad_feat)
batch["envelope"] = torch.stack(pad_env)
if "phase" in features[0] and features[0]["phase"] is not None:
ph_list = [f["phase"] for f in features]
max_len = max(f.shape[-1] for f in ph_list)
pad_ph = []
for feat in ph_list:
current_len = feat.shape[-1]
padding = max_len_feat - current_len
if padding > 0:
pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
else:
pad_feat = feat
pad_ph.append(pad_feat)
batch["phase"] = torch.stack(pad_ph)
return batch
def hilbert_transform(x):
    # FFT-based Hilbert transform: build the analytic-signal multiplier
    # (keep DC/Nyquist, double the positive frequencies), apply it over the
    # full spectrum, and take the imaginary part of the inverse FFT.
    N = x.shape[-1]
    xf = torch.fft.fft(x, dim=-1)
    h = torch.zeros(N, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    return torch.fft.ifft(xf * h, dim=-1).imag
def analytic_signal(x):
return x + 1j * hilbert_transform(x)
def hilbert_transform_2d(x, dim=-1):
    # Apply the 1D FFT-based Hilbert transform along an arbitrary dimension by
    # moving that dimension to the end, transforming, and moving it back.
    if dim == -1 or dim == x.ndim - 1:
        return hilbert_transform(x)
    return hilbert_transform(x.transpose(dim, -1)).transpose(dim, -1)
def hilbert_transform_true_2d(x):
    # Spiral-phase (Riesz-style) 2D transform. fftfreq covers the full first
    # axis and rfftfreq the halved last axis, matching the rfft2 output shape.
    xf = torch.fft.rfft2(x)
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2], device=x.device),
        torch.fft.rfftfreq(x.shape[-1], device=x.device),
        indexing='ij')
    denom = h1 + 1j * h2
    h = -1j / (math.pi * denom)
    h[denom == 0] = 0  # zero out the singular DC bin
    return torch.fft.irfft2(xf * h, s=x.shape[-2:])
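# Envelope/phase decomposition: the analytic signal of each mel band yields a
# smooth magnitude envelope and an instantaneous phase track along time.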
def process_spectrogram_with_hilbert(spec):
analytic = spec + 1j * hilbert_transform(spec)
envelope = torch.abs(analytic)
phase = torch.angle(analytic)
return envelope, phase
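# Load audio either from a file path or a datasets-style {"array", "sampling_rate"}
# dict, and resample when the source rate differs from `sample_rate`.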
def load_wave(wave_data, sample_rate):
if isinstance(wave_data, str):
        waveform, sr = torchaudio.load(uri=wave_data, normalize=True)  # float32 samples, so Resample and downstream math work
elif isinstance(wave_data, dict):
waveform = torch.tensor(data=wave_data["array"]).float()
sr = wave_data["sampling_rate"]
else:
raise TypeError("Invalid wave_data format.")
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0)
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
return waveform.flatten()
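# Per-example feature extraction, applied via datasets.map. Depending on the
# flags it produces a Whisper-style log-mel spectrogram, optional Hilbert
# envelope/phase per mel band, the raw waveform, and pyworld (DIO + StoneMask)
# pitch / f0 tracks, then encodes the transcription into label ids.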
def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
norm=None, normalized=False, downsamples=False, period=False, hilbert=False):
    audio = batch["audio"]
    sr = audio["sampling_rate"]
    wav = load_wave(wave_data=audio, sample_rate=sr)
if spectrogram:
transform = torchaudio.transforms.MelSpectrogram(
f_max=fmax,
f_min=fmin,
n_mels=n_mels,
sample_rate=sr,
n_fft=n_fft,
hop_length=hop_length,
norm=norm,
normalized=normalized,
power=power,
center=center,
mel_scale=mel_scale,
window_fn=window_fn,
pad_mode=pad_mode)
mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0  # Whisper-style log-mel scaling
        batch["spectrogram"] = spec
if hilbert:
envelope_list = []
phase_list = []
            for ch_idx in range(spec.shape[0]):  # per mel band, transform along the time axis
envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
envelope_list.append(envelope)
phase_list.append(phase)
batch["envelope"] = torch.stack(envelope_list)
batch["phase"] = torch.stack(phase_list)
wav_1d = wav.unsqueeze(0)
if waveforms:
batch["waveform"] = wav_1d
    if pitch:
        wav_np = wav.numpy().astype(np.float64)
        # DIO coarse f0 estimate refined by StoneMask; frame_period is the hop in ms
        f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sr)
        batch["pitch"] = torch.from_numpy(f0).float().unsqueeze(0)
    if frequency:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sr)
        batch["f0"] = torch.from_numpy(f0).float()
if spectrogram and waveforms and pitch:
spec_mean = batch["spectrogram"].mean()
spec_std = batch["spectrogram"].std() + 1e-6
batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std
wav_mean = batch["waveform"].mean()
wav_std = batch["waveform"].std() + 1e-6
batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std
if batch["pitch"].max() > 1.0:
pitch_min = 50.0
pitch_max = 600.0
batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
return batch
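# Evaluation callback for the trainer: decodes predictions and labels, computes
# word error rate, and derives an "efficiency score" that normalizes accuracy
# (100 - WER) by the number of trainable parameters in millions.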
def compute_metrics(eval_pred, compute_result: bool = True,
print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):
pred_logits = eval_pred.predictions
label_ids = eval_pred.label_ids
if hasattr(pred_logits, "cpu"):
pred_logits = pred_logits.cpu()
if hasattr(label_ids, "cpu"):
label_ids = label_ids.cpu()
if isinstance(pred_logits, tuple):
pred_ids = pred_logits[0]
else:
pred_ids = pred_logits
if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
if not isinstance(pred_ids, torch.Tensor):
pred_ids = torch.tensor(pred_ids)
pred_ids = pred_ids.argmax(dim=-1)
pred_ids = pred_ids.tolist()
if hasattr(label_ids, "tolist"):
label_ids = label_ids.tolist()
    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]  # -100 is the loss ignore index
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
if print_pred:
for i in range(min(num_samples, len(pred_str))):
print(f"Preds: {pred_str[i]}")
print(f"Label: {label_str[i]}")
print(f"preds: {pred_ids[i]}")
print(f"label: {label_ids[i]}")
print("--------------------------------")
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
if model is None:
global global_model
if 'global_model' in globals():
model = global_model
if model is not None:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
if trainable_params > 0:
efficiency_score = (100 - wer) / trainable_params
else:
print("Warning: Zero trainable parameters detected")
efficiency_score = 0.0
else:
print("Warning: Model not available for parameter counting")
trainable_params = 0.0
efficiency_score = 0.0
if hasattr(wer, "item"):
wer = wer.item()
metrics = {
"wer": float(wer),
"trainable_params_M": float(trainable_params),
"efficiency_score": float(efficiency_score),
}
return metrics
logger = logging.getLogger(__name__)
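# Echo is the model class supplied by the companion model module; this helper
# instantiates it on CUDA and logs its parameter counts.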
def create_model(param: Dimensions) -> Echo:
model = Echo(param).to('cuda')
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"Trainable parameters: {trainable_params:,}")
logger.info(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
return model
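# Wraps a plain `tokenizers.Tokenizer` loaded from disk with the small HF-style
# surface this pipeline needs: encode(add_special_tokens=...), batch_decode,
# save_pretrained, and fixed special-token ids (PAD=0, BOS=1, EOS=2).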
def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
orig_encode = tokenizer.encode
def enc(text, add_special_tokens=True):
ids = orig_encode(text).ids
if not add_special_tokens:
sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
ids = [id for id in ids if id not in sp_ids]
return ids
def bdec(ids_list, skip_special_tokens=True):
results = []
for ids in ids_list:
if skip_special_tokens:
ids = [id for id in ids if id not in [0, 1, 2]]
results.append(tokenizer.decode(ids))
return results
def save_pretrained(save_dir):
os.makedirs(save_dir, exist_ok=True)
tokenizer.save(f"{save_dir}/tokenizer.json")
tokenizer.encode = enc
tokenizer.batch_decode = bdec
tokenizer.save_pretrained = save_pretrained
tokenizer.pad_token_id = 0
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
return tokenizer
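# Loads google/fleurs (en_us), keeps transcriptions under 512 characters and
# audio under ~15 s (1500 frames at hop 160), then maps extract_features over
# the train split and a 50-example test slice. With sanity_check=True only ten
# test examples are processed for a quick end-to-end smoke test.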
def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
if dataset_config is None:
dataset_config = {
"spectrogram": True,
"waveforms": True,
"pitch": True,
"frequency": True,
"downsamples": True,
"hop_length": 128,
"fmin": 50,
"fmax": 2000,
"n_mels": 128,
"n_fft": 1024,
"sampling_rate": 16000,
}
dataset = load_dataset(
"google/fleurs",
"en_us",
token=token,
trust_remote_code=True,
streaming=False)
dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
    if sanity_check:
        dataset = dataset["test"].take(10)
        logger.info(f"Sanity dataset size: {dataset.num_rows}")
        print(f"Sanity dataset size: {dataset.num_rows}")
prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
dataset = dataset.map(
function=prepare_fn,
remove_columns=["audio", "transcription"]
).with_format(type="torch")
train_dataset = dataset
test_dataset = dataset
else:
def filter_func(x):
return (0 < len(x["transcription"]) < 512 and
len(x["audio"]["array"]) > 0 and
len(x["audio"]["array"]) < 1500 * 160)
dataset = dataset.filter(filter_func).shuffle(seed=4)
logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
columns_to_remove = list(next(iter(dataset.values())).features)
train_dataset = dataset["train"]
test_dataset = dataset["test"].take(50)
logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")
train_dataset = train_dataset.map(
function=prepare_fn,
remove_columns=columns_to_remove
).with_format(type="torch")
test_dataset = test_dataset.map(
function=prepare_fn,
remove_columns=columns_to_remove
).with_format(type="torch")
return train_dataset, test_dataset
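# Trainer configuration. Note the assumptions baked in here: batch size 1 with
# no gradient accumulation, and bf16/tf32 enabled, which requires an
# Ampere-or-newer NVIDIA GPU.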
def get_training_args(
log_dir: str,
batch_eval_metrics: bool = False,
max_steps: int = 10,
save_steps: int = 1000,
eval_steps: int = 1,
warmup_steps: int = 0,
num_train_epochs: int = 1,
logging_steps: int = 1,
eval_on_start: bool = False,
learning_rate: float = 1e-4,
weight_decay: float = 0.01,
max_grad_norm: float = 1.0,
) -> Seq2SeqTrainingArguments:
return Seq2SeqTrainingArguments(
output_dir=log_dir,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
gradient_accumulation_steps=1,
eval_accumulation_steps=1,
tf32=True,
bf16=True,
eval_strategy="steps",
save_strategy="steps",
max_steps=max_steps,
save_steps=save_steps,
eval_steps=eval_steps,
warmup_steps=warmup_steps,
num_train_epochs=num_train_epochs,
logging_steps=logging_steps,
logging_dir=log_dir,
logging_strategy="steps",
report_to=["tensorboard"],
push_to_hub=False,
disable_tqdm=False,
save_total_limit=1,
label_names=["labels"],
optim="adamw_torch",
lr_scheduler_type="cosine",
learning_rate=learning_rate,
weight_decay=weight_decay,
save_safetensors=False,
eval_on_start=eval_on_start,
batch_eval_metrics=batch_eval_metrics,
max_grad_norm=max_grad_norm,
)
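# End-to-end entry point: build the tokenizer, choose sanity-check or full
# training arguments, construct the model from `Dimensions`, prepare the
# datasets, and run Seq2SeqTrainer with the custom collator and metrics.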
def main():
token = ""
log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H'))
os.makedirs(name=log_dir, exist_ok=True)
tokenizer = setup_tokenizer(token)
    def sanity(sanity_check: bool):
        if sanity_check:
training_args = get_training_args(
log_dir,
batch_eval_metrics = False,
max_steps = 10,
save_steps = 0,
eval_steps = 1,
warmup_steps = 0,
logging_steps = 1,
eval_on_start = False,
learning_rate = 5e-6,
weight_decay = 0.01,
)
else:
training_args = get_training_args(
log_dir,
batch_eval_metrics = False,
max_steps = 1000,
save_steps = 1000,
eval_steps = 100,
warmup_steps = 100,
logging_steps = 10,
eval_on_start = False,
learning_rate = 2.5e-4,
weight_decay = 0.01,
)
return training_args
param = Dimensions(
mels=128,
aud_ctx=1500,
aud_head=4,
aud_dims=512,
aud_idx=4,
vocab=40000,
text_ctx=512,
text_head=4,
text_dims=512,
text_idx=4,
act="swish",
        debug=[],  # e.g. ["encoder", "decoder", "residual", "rotary"]
cross_attn=True,
f0_rotary=False,
features = ["spectrogram"]#, "waveform", "pitch", "f0", "envelope", "phase"],
)
sanity_check = False
training_args = sanity(sanity_check)
dataset_config = {
"spectrogram": True,
"waveforms": False,
"pitch": False,
"downsamples": False,
"frequency": True,
"hilbert": False,
"hop_length": 128,
"fmin": 150,
"fmax": 2000,
"n_mels": 128,
"n_fft": 1024,
"sampling_rate": 16000,
"pad_mode": "constant",
"center": True,
"power": 2.0,
"window_fn": torch.hann_window,
"mel_scale": "htk",
"norm": None,
"normalized": False}
model = create_model(param)
global global_model
global_model = model
metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
tokenizer=tokenizer, model=model)
print(f"{'Sanity check' if sanity_check else 'Training'} mode")
train_dataset, test_dataset = prepare_datasets(
tokenizer=tokenizer,
token=token,
sanity_check=sanity_check,
dataset_config=dataset_config)
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=train_dataset,
eval_dataset=test_dataset,
data_collator=DataCollator(tokenizer=tokenizer),
compute_metrics=metrics_fn,
)
model.init_weights()
trainer.train()
if __name__ == "__main__":
main()