import os
import math
import logging
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from typing import Optional, Dict, Union, List, Tuple, Any

import numpy as np
import pyworld as pw
import torch
import torch.nn.functional as F
import torchaudio
import evaluate
from datasets import load_dataset, Audio, concatenate_datasets
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments

# Placeholders for training state and for the model classes (Echo and its
# Residual / MultiheadA building blocks), which are defined elsewhere in
# the project.
extractor = None
tokenizer = None
optimizer = None
scheduler = None
model = None
Residual = None
MultiheadA = None
Echo = None

metric = evaluate.load(path="wer")

@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]
    f0_rotary: bool

def align_f0(f0, ctx):
    """Resample a batch of f0 tracks to `ctx` frames by nearest-index selection."""
    bat, length = f0.shape
    if length == ctx:
        return f0
    frames = length / ctx
    idx = torch.arange(ctx, device=f0.device)
    idx = (idx * frames).long()
    batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
    return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]

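# Example (sketch): align_f0 resamples by nearest-index selection, so a
# (2, 300) track aligned to ctx=100 keeps every third frame:
#   f0 = torch.randn(2, 300)
#   align_f0(f0, 100).shape  # torch.Size([2, 100])
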
@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Use the collator's own tokenizer rather than the module-level global.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", 0)
        bos_token_id = getattr(self.tokenizer, "bos_token_id", 1)
        eos_token_id = getattr(self.tokenizer, "eos_token_id", 2)

        batch = {}

        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
            spectrogram_list = [f["spectrogram"] for f in features]
            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
            pad_spectrogram = []
            for feat in spectrogram_list:
                padding = max_len_feat - feat.shape[-1]
                if padding > 0:
                    feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                pad_spectrogram.append(feat)
            batch["spectrogram"] = torch.stack(pad_spectrogram)

        if "waveform" in features[0] and features[0]["waveform"] is not None:
            waveform_list = [f["waveform"] for f in features]
            max_len_wav = max(w.shape[-1] for w in waveform_list)
            pad_waveforms = []
            for wav in waveform_list:
                padding = max_len_wav - wav.shape[-1]
                if padding > 0:
                    if wav.ndim == 1:
                        wav = wav.unsqueeze(0)
                    wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
                pad_waveforms.append(wav)
            batch["waveform"] = torch.stack(pad_waveforms)

        if "label" in features[0] and features[0]["label"] is not None:
            labels_list = [f["label"] for f in features]
            max_len = max(len(l) for l in labels_list)
            all_ids = []
            all_labels = []
            for label in labels_list:
                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                decoder_input = [bos_token_id] + label_list
                # Terminate the target sequence with EOS so the model learns to stop;
                # padding uses pad_token_id (compute_metrics also maps -100 back to it).
                label_eos = label_list + [eos_token_id]
                input_len = max_len + 1 - len(decoder_input)
                label_len = max_len + 1 - len(label_eos)
                padded_input = decoder_input + [pad_token_id] * input_len
                padded_labels = label_eos + [pad_token_id] * label_len
                all_ids.append(padded_input)
                all_labels.append(padded_labels)
            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

        if "pitch" in features[0] and features[0]["pitch"] is not None:
            pitch_list = [f["pitch"] for f in features]
            max_len_pitch = max(e.shape[-1] for e in pitch_list)
            pad_pitch = []
            for pitch in pitch_list:
                padding = max_len_pitch - pitch.shape[-1]
                if padding > 0:
                    pitch = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
                pad_pitch.append(pitch)
            batch["pitch"] = torch.stack(pad_pitch)

        if "f0" in features[0] and features[0]["f0"] is not None:
            input_ids_batch = batch.get("input_ids", None)
            if input_ids_batch is not None:
                target_length = input_ids_batch.shape[-1]
                aligned_list = []
                original_list = []
                for feature in features:
                    f0 = feature["f0"]
                    original_list.append(f0)
                    if f0.shape[-1] != target_length:
                        aligned_f0 = align_f0(f0.unsqueeze(0), target_length).squeeze(0)
                    else:
                        aligned_f0 = f0
                    aligned_list.append(aligned_f0)
                batch["f0d"] = torch.stack(aligned_list)
                # Note: stacking the unaligned tracks assumes equal lengths across
                # the batch (batch size is 1 in this script).
                batch["f0"] = torch.stack(original_list)

        if "envelope" in features[0] and features[0]["envelope"] is not None:
            env_list = [f["envelope"] for f in features]
            max_len = max(f.shape[-1] for f in env_list)
            pad_env = []
            for feat in env_list:
                # Pad to this feature's own batch max (not the spectrogram max).
                padding = max_len - feat.shape[-1]
                if padding > 0:
                    feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                pad_env.append(feat)
            batch["envelope"] = torch.stack(pad_env)

        if "phase" in features[0] and features[0]["phase"] is not None:
            ph_list = [f["phase"] for f in features]
            max_len = max(f.shape[-1] for f in ph_list)
            pad_ph = []
            for feat in ph_list:
                padding = max_len - feat.shape[-1]
                if padding > 0:
                    feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                pad_ph.append(feat)
            batch["phase"] = torch.stack(pad_ph)
        return batch

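# Collation sketch (hypothetical sizes): given two examples with 80 and 100
# spectrogram frames and token lists of length 7 and 9, the collator
# right-pads each modality to its own batch max, so
#   batch["spectrogram"].shape == (2, n_mels, 100)
#   batch["input_ids"].shape == batch["labels"].shape == (2, 10)
# where input_ids are BOS-prefixed and labels EOS-terminated.
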
def hilbert_transform(x):
    N = x.shape[-1]
    xf = torch.fft.rfft(x)
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    return torch.fft.irfft(xf * h, n=N)


def analytic_signal(x):
    return x + 1j * hilbert_transform(x)

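# Minimal sketch of the intended analytic-signal use (the rfft/irfft round
# trip above is an approximation of the classic full-FFT Hilbert construction):
#   x = torch.sin(torch.linspace(0, 2 * math.pi, 64))
#   z = analytic_signal(x)
#   envelope, phase = z.abs(), z.angle()   # instantaneous amplitude / phase
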
def hilbert_transform_2d(x, dim=-1):
    N = x.shape[dim]
    xf = torch.fft.rfft(x, dim=dim)
    # Build the one-sided spectral weights once, then broadcast along `dim`;
    # the original only handled dim == -1 and silently returned zeros otherwise.
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    h[0] = 1
    if N % 2 == 0:
        h[-1] = 1
        h[1:-1] = 2
    else:
        h[1:] = 2
    h_shape = [1] * x.ndim
    h_shape[dim] = N // 2 + 1
    return torch.fft.irfft(xf * h.view(h_shape), n=N, dim=dim)

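# With the broadcastable weights above, the transform can run along either
# axis of a (n_mels, frames) spectrogram, e.g.
#   hilbert_transform_2d(torch.randn(128, 400), dim=-1)  # along time
#   hilbert_transform_2d(torch.randn(128, 400), dim=-2)  # along mel bins
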
def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    # The frequency grid has to match rfft2's output shape: the second-to-last
    # axis keeps its full length (fftfreq); only the last axis is halved (rfftfreq).
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j * h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device), s=x.shape[-2:])

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase

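# For each mel band, |analytic| is the amplitude envelope and angle(analytic)
# the instantaneous phase; these become the optional "envelope" and "phase"
# features padded by DataCollator above.
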
def load_wave(wave_data, sample_rate):
    if isinstance(wave_data, str):
        # normalize=True decodes to float32, which the resampler below requires.
        waveform, sr = torchaudio.load(uri=wave_data, normalize=True)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sr = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    return waveform.flatten()

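# Usage sketch (hypothetical values): HF datasets deliver audio as a dict,
# which load_wave converts, resamples, and flattens to 1-D:
#   wav = load_wave({"array": np.zeros(22050), "sampling_rate": 22050},
#                   sample_rate=16000)
#   wav.shape  # torch.Size([16000])
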
def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
                     hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
                     pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
                     norm=None, normalized=False, downsamples=False, period=False, hilbert=False):

    audio = batch["audio"]
    # The dataset's own rate overrides the default argument.
    sr = audio["sampling_rate"]
    wav = load_wave(wave_data=audio, sample_rate=sr)

    if spectrogram:
        transform = torchaudio.transforms.MelSpectrogram(
            f_max=fmax,
            f_min=fmin,
            n_mels=n_mels,
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            norm=norm,
            normalized=normalized,
            power=power,
            center=center,
            mel_scale=mel_scale,
            window_fn=window_fn,
            pad_mode=pad_mode)

        mel_spectrogram = transform(wav)
        # Whisper-style log-mel scaling: floor at 1e-10, keep the top 8 decades
        # of dynamic range, then shift and scale.
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec

        if hilbert:
            envelope_list = []
            phase_list = []
            for ch_idx in range(spec.shape[0]):
                envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
                envelope_list.append(envelope)
                phase_list.append(phase)
            batch["envelope"] = torch.stack(envelope_list)
            batch["phase"] = torch.stack(phase_list)

    if waveforms:
        batch["waveform"] = wav.unsqueeze(0)

    if pitch:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sr)
        batch["pitch"] = torch.from_numpy(f0).float().unsqueeze(0)

    if frequency:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sr)
        batch["f0"] = torch.from_numpy(f0).float()

    if spectrogram and waveforms and pitch:
        spec_mean = batch["spectrogram"].mean()
        spec_std = batch["spectrogram"].std() + 1e-6
        batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std

        wav_mean = batch["waveform"].mean()
        wav_std = batch["waveform"].std() + 1e-6
        batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std

        if batch["pitch"].max() > 1.0:
            pitch_min = 50.0
            pitch_max = 600.0
            batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)

    batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
    return batch

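# Note: frame_period = hop_length / sr * 1000 samples the DIO pitch track on
# the same frame grid as the mel spectrogram (one value per hop_length
# samples), so "pitch" and "spectrogram" line up frame for frame.
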
def compute_metrics(eval_pred, compute_result: bool = True,
                    print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):

    pred_logits = eval_pred.predictions
    label_ids = eval_pred.label_ids

    if hasattr(pred_logits, "cpu"):
        pred_logits = pred_logits.cpu()
    if hasattr(label_ids, "cpu"):
        label_ids = label_ids.cpu()

    if isinstance(pred_logits, tuple):
        pred_ids = pred_logits[0]
    else:
        pred_ids = pred_logits
    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
        if not isinstance(pred_ids, torch.Tensor):
            pred_ids = torch.tensor(pred_ids)
        pred_ids = pred_ids.argmax(dim=-1)
    if hasattr(pred_ids, "tolist"):
        pred_ids = pred_ids.tolist()

    if hasattr(label_ids, "tolist"):
        label_ids = label_ids.tolist()

    # Map the loss-masking sentinel back to the pad id before decoding.
    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]

    if print_pred:
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
        for i in range(min(num_samples, len(pred_str))):
            print(f"Preds: {pred_str[i]}")
            print(f"Label: {label_str[i]}")
            print(f"preds: {pred_ids[i]}")
            print(f"label: {label_ids[i]}")
            print("--------------------------------")

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    if model is None:
        global global_model
        if 'global_model' in globals():
            model = global_model

    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
        else:
            print("Warning: Zero trainable parameters detected")
            efficiency_score = 0.0
    else:
        print("Warning: Model not available for parameter counting")
        trainable_params = 0.0
        efficiency_score = 0.0

    if hasattr(wer, "item"):
        wer = wer.item()

    metrics = {
        "wer": float(wer),
        "trainable_params_M": float(trainable_params),
        "efficiency_score": float(efficiency_score),
    }

    return metrics

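# efficiency_score trades accuracy off against model size: a model with 20%
# WER and 50 M trainable parameters scores (100 - 20) / 50 = 1.6.
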
logger = logging.getLogger(__name__)


def create_model(param: Dimensions) -> Echo:
    model = Echo(param).to('cuda')
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")
    return model

def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(os.path.join(local_tokenizer_path, "tokenizer.json"))
    orig_encode = tokenizer.encode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [i for i in ids if i not in sp_ids]
        return ids

    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                ids = [i for i in ids if i not in [0, 1, 2]]
            results.append(tokenizer.decode(ids))
        return results

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    # Monkey-patch a minimal transformers-style interface onto the raw tokenizer.
    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

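# Usage sketch (the path is the machine-specific default above):
#   tok = setup_tokenizer(token="", local_tokenizer_path="D:/newmodel/model/tokenn/")
#   ids = tok.encode("hello world", add_special_tokens=False)
#   tok.batch_decode([ids])  # decoded text without <PAD>/<BOS>/<EOS>
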
def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
    if dataset_config is None:
        dataset_config = {
            "spectrogram": True,
            "waveforms": True,
            "pitch": True,
            "frequency": True,
            "downsamples": True,
            "hop_length": 128,
            "fmin": 50,
            "fmax": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sampling_rate": 16000,
        }

    dataset = load_dataset(
        "google/fleurs",
        "en_us",
        token=token,
        trust_remote_code=True,
        streaming=False)

    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])

    if sanity_check:
        dataset = dataset["test"].take(10)
        logger.info(f"Sanity dataset size: {dataset.num_rows}")
        print(f"Sanity dataset size: {dataset.num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)

        dataset = dataset.map(
            function=prepare_fn,
            remove_columns=["audio", "transcription"]
        ).with_format(type="torch")
        # In sanity mode, train and evaluate on the same ten examples.
        train_dataset = dataset
        test_dataset = dataset
    else:
        def filter_func(x):
            # Keep non-empty transcripts under 512 characters and clips shorter
            # than 1500 * 160 samples (15 s at 16 kHz).
            return (0 < len(x["transcription"]) < 512 and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < 1500 * 160)

        dataset = dataset.filter(filter_func).shuffle(seed=4)
        logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        columns_to_remove = list(next(iter(dataset.values())).features)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"].take(50)
        logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")

        train_dataset = train_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

        test_dataset = test_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

    return train_dataset, test_dataset

def get_training_args(
    log_dir: str,
    batch_eval_metrics: bool = False,
    max_steps: int = 10,
    save_steps: int = 1000,
    eval_steps: int = 1,
    warmup_steps: int = 0,
    num_train_epochs: int = 1,
    logging_steps: int = 1,
    eval_on_start: bool = False,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
) -> Seq2SeqTrainingArguments:

    return Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,
        eval_accumulation_steps=1,
        tf32=True,
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        max_steps=max_steps,
        save_steps=save_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs,
        logging_steps=logging_steps,
        logging_dir=log_dir,
        logging_strategy="steps",
        report_to=["tensorboard"],
        push_to_hub=False,
        disable_tqdm=False,
        save_total_limit=1,
        label_names=["labels"],
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        save_safetensors=False,
        eval_on_start=eval_on_start,
        batch_eval_metrics=batch_eval_metrics,
        max_grad_norm=max_grad_norm,
    )

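# Note: tf32=True and bf16=True assume an Ampere-or-newer NVIDIA GPU; on
# older cards drop both flags (or swap bf16 for fp16).
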
def main():
    token = ""
    log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H'))
    os.makedirs(log_dir, exist_ok=True)
    tokenizer = setup_tokenizer(token)

    def sanity(sanity: bool):
        if sanity:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10,
                save_steps=0,
                eval_steps=1,
                warmup_steps=0,
                logging_steps=1,
                eval_on_start=False,
                learning_rate=5e-6,
                weight_decay=0.01,
            )
        else:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=1000,
                save_steps=1000,
                eval_steps=100,
                warmup_steps=100,
                logging_steps=10,
                eval_on_start=False,
                learning_rate=2.5e-4,
                weight_decay=0.01,
            )
        return training_args

    param = Dimensions(
        mels=128,
        aud_ctx=1500,
        aud_head=4,
        aud_dims=512,
        aud_idx=4,
        vocab=40000,
        text_ctx=512,
        text_head=4,
        text_dims=512,
        text_idx=4,
        act="swish",
        debug=[],
        cross_attn=True,
        f0_rotary=False,
        features=["spectrogram"],
    )

    sanity_check = False
    training_args = sanity(sanity_check)
    dataset_config = {
        "spectrogram": True,
        "waveforms": False,
        "pitch": False,
        "downsamples": False,
        "frequency": True,
        "hilbert": False,
        "hop_length": 128,
        "fmin": 150,
        "fmax": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sampling_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 2.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}

    model = create_model(param)

    # Expose the model globally so compute_metrics can fall back to it.
    global global_model
    global_model = model

    metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
                         tokenizer=tokenizer, model=model)

    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer,
        token=token,
        sanity_check=sanity_check,
        dataset_config=dataset_config)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
    )

    model.init_weights()
    trainer.train()


if __name__ == "__main__":
    main()