# VitsModelSplit/dataset_features_collector.py
import os
import numpy as np
import torch
import torch.utils.data
from typing import Dict, List, Union
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from VitsModelSplit.feature_extraction import VitsFeatureExtractor
from VitsModelSplit.vits_model import VitsModel
#.............................................
class DataSetFeaturesCollector:
    """Collates raw (audio, text) rows into padded, model-ready batches and
    precomputes text-encoder and posterior-encoder outputs for VITS training."""

    def __init__(self, tokenizer, model, feature_extractor, forward_attention_mask=True) -> None:
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.model = model
        self.forward_attention_mask = forward_attention_mask
#.............................................
    def pad_waveform(self, raw_speech):
        """Pad a single waveform or a batch of waveforms into one (batch, time, 1) tensor."""
        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], (np.ndarray, tuple, list))
        )

        if is_batched:
            # shape each utterance as (num_samples, 1) so the extractor pads along time
            raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
        elif not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        # always return a batch
        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        # convert into the correct format for padding
        batched_speech = BatchFeature({"input_features": raw_speech})
        padded_inputs = self.feature_extractor.pad(
            batched_speech,
            padding=True,
            return_attention_mask=False,
            return_tensors="pt",
        )["input_features"]
        return padded_inputs
#.............................................
    def prepare_dataset(self, batch):
        """Compute per-example features: spectrogram labels, raw waveform, and token ids."""
        sample = batch['audio']
        audio_inputs = self.feature_extractor(
            sample,
            sampling_rate=16000,
            return_attention_mask=False,
            do_normalize=False,
        )

        batch["labels"] = audio_inputs.get("input_features")[0]
        batch["waveform_input_length"] = len(sample)
        batch["waveform"] = batch['audio']
        batch["mel_scaled_input_features"] = audio_inputs.get("mel_scaled_input_features")[0]

        # tokenize the transcript; the tokenizer already returns the attention mask,
        # so no extra pad() call is needed for a single sequence
        textsample = batch['text']
        inputs = self.tokenizer(textsample, return_tensors="pt")
        batch['input_ids'] = inputs.input_ids
        batch['attention_mask'] = inputs.attention_mask
        return batch
#.............................................
    def __call__(self, dataset: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths
        # and need different padding methods
        dataset = Dataset.from_list(dataset)
        features = dataset.map(
            self.prepare_dataset,
            remove_columns=dataset.column_names,
            desc="preprocess",
        )
        features = list(features)

        # pad input tokens
        model_input_name = "input_ids"
        input_ids = [{model_input_name: feature[model_input_name][0]} for feature in features]
        batch = self.tokenizer.pad(input_ids, return_tensors="pt", return_attention_mask=self.forward_attention_mask)

        # pad waveform
        waveforms = [np.array(feature["waveform"]) for feature in features]
        batch["waveform"] = self.pad_waveform(waveforms)

        # pad spectrogram
        label_features = [np.array(feature["labels"]) for feature in features]
        labels_batch = self.feature_extractor.pad(
            {"input_features": [i.T for i in label_features]}, return_tensors="pt", return_attention_mask=True
        )
        labels = labels_batch["input_features"].transpose(1, 2)
        batch["labels"] = labels
        batch["labels_attention_mask"] = labels_batch["attention_mask"]

        # pad mel spectrogram
        mel_scaled_input_features = {
            "input_features": [np.array(feature["mel_scaled_input_features"]).squeeze().T for feature in features]
        }
        mel_scaled_input_features = self.feature_extractor.pad(
            mel_scaled_input_features, return_tensors="pt", return_attention_mask=True
        )["input_features"].transpose(1, 2)
        batch["mel_scaled_input_features"] = mel_scaled_input_features

        batch["speaker_id"] = (
            torch.tensor([feature["speaker_id"] for feature in dataset]) if "speaker_id" in dataset[0] else None
        )

        # precompute encoder outputs so downstream training can skip these forward passes
        with torch.no_grad():
            padding_mask = torch.ones_like(batch['input_ids']).unsqueeze(-1).float()
            text_encoder_output = self.model.text_encoder(
                batch['input_ids'],
                padding_mask=padding_mask,
                attention_mask=batch['attention_mask'],
            )
            batch['text_encoder_output'] = text_encoder_output

            posterior_latents, posterior_means, posterior_log_variances = self.model.posterior_encoder(
                batch['labels'], batch['labels_attention_mask'].unsqueeze(1).float()
            )
            batch['posterior_encode_output'] = {
                'posterior_latents': posterior_latents,
                'posterior_means': posterior_means,
                'posterior_log_variances': posterior_log_variances,
            }
        return batch
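
# Usage sketch (illustrative, not part of the training pipeline): run the collator
# directly on a couple of raw rows. The row fields ('audio' as 16 kHz mono floats,
# 'text') follow what prepare_dataset reads; the actual tokenizer, model, and
# feature extractor must be supplied by the caller.
def example_collate_rows(tokenizer, model, feature_extractor):
    collator = DataSetFeaturesCollector(
        tokenizer=tokenizer,
        model=model,
        feature_extractor=feature_extractor,
        forward_attention_mask=True,
    )
    rows = [
        {"audio": [0.0] * 16_000, "text": "first example"},
        {"audio": [0.0] * 24_000, "text": "a second, longer example"},
    ]
    batch = collator(rows)
    # batch now carries padded 'input_ids', 'waveform', 'labels',
    # 'mel_scaled_input_features', plus precomputed 'text_encoder_output'
    # and 'posterior_encode_output'
    return batch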
#.............................................
def run_dataset_features_collection(
    dataset_dir,
    train_split_name="train",
    eval_split_name="eval",
    full_generation_name="full_generation",
    tokenizer=None,
    model=None,
    feature_extractor=None,
    train_batch_size=1,
    eval_batch_size=1,
    output_dir="dataset_features",
):
    dataset = DatasetDict.load_from_disk(dataset_dir)
    data_collator = DataSetFeaturesCollector(
        tokenizer=tokenizer,
        model=model,
        feature_extractor=feature_extractor,
        forward_attention_mask=True,
    )

    if train_split_name:
        train_dataloader = torch.utils.data.DataLoader(
            dataset[train_split_name],
            shuffle=False,
            collate_fn=data_collator,
            batch_size=train_batch_size,
            sampler=None,
        )
        train_dir = os.path.join(output_dir, "train")
        os.makedirs(train_dir, exist_ok=True)
        for step, batch in enumerate(train_dataloader):
            print(f"Train Dataset - batch {step}, waveform {batch['waveform'].shape}, tokens {batch['input_ids'].shape}")
            fname = os.path.join(train_dir, f"train-batch-{step}.bin")
            with open(fname, "wb") as f:
                torch.save(batch, f)

    if eval_split_name:
        eval_dataloader = torch.utils.data.DataLoader(
            dataset[eval_split_name],
            shuffle=False,
            collate_fn=data_collator,
            batch_size=eval_batch_size,
            sampler=None,
        )
        eval_dir = os.path.join(output_dir, "eval")
        os.makedirs(eval_dir, exist_ok=True)
        for step, batch in enumerate(eval_dataloader):
            print(f"Eval Dataset - batch {step}, waveform {batch['waveform'].shape}, tokens {batch['input_ids'].shape}")
            fname = os.path.join(eval_dir, f"eval-batch-{step}.bin")
            with open(fname, "wb") as f:
                torch.save(batch, f)

    if full_generation_name:
        full_generation_dataloader = torch.utils.data.DataLoader(
            dataset[full_generation_name],
            shuffle=False,
            collate_fn=data_collator,
            batch_size=1,
            sampler=None,
        )
        full_generation_dir = os.path.join(output_dir, "full_generation")
        os.makedirs(full_generation_dir, exist_ok=True)
        for step, batch in enumerate(full_generation_dataloader):
            print(f"Full Generation Dataset - batch {step}, waveform {batch['waveform'].shape}, tokens {batch['input_ids'].shape}")
            fname = os.path.join(full_generation_dir, f"full-generation-batch-{step}.bin")
            with open(fname, "wb") as f:
                torch.save(batch, f)
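
# Usage sketch (assumptions: the checkpoint name mirrors get_dataloader below, and
# "dataset" is a DatasetDict saved with save_to_disk; a default-configured
# VitsFeatureExtractor is used purely for illustration).
def example_run_collection():
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-ara")
    model = VitsModel.from_pretrained("facebook/mms-tts-ara")
    feature_extractor = VitsFeatureExtractor()
    run_dataset_features_collection(
        dataset_dir="dataset",
        tokenizer=tokenizer,
        model=model,
        feature_extractor=feature_extractor,
        train_batch_size=1,
        eval_batch_size=1,
        output_dir="dataset_features",
    )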
#...........................................................................
class FeaturesCollectionDataset(torch.utils.data.Dataset):
    """Loads pre-collated batches saved as .bin files by run_dataset_features_collection."""

    def __init__(self, dataset_dir, device='cpu') -> None:
        self.dataset_dir = dataset_dir
        # note: lexicographic sort, so batch indices beyond 9 are not in numeric order
        self.batchs_path = sorted(
            os.path.join(self.dataset_dir, file)
            for file in os.listdir(dataset_dir)
            if file.endswith('.bin')
        )
        self.device = device

    def __len__(self):
        return len(self.batchs_path)

    def __getitem__(self, idx):
        batch_name = self.batchs_path[idx]
        with open(batch_name, "rb") as f:
            batch = torch.load(f, map_location=torch.device(self.device))
        return batch
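
# Usage sketch: stream back the batches written by run_dataset_features_collection.
# The "dataset_features/train" path assumes the default output_dir above.
def example_read_saved_batches():
    train_batches = FeaturesCollectionDataset("dataset_features/train", device='cpu')
    for batch in train_batches:
        print(batch['input_ids'].shape, batch['labels'].shape)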
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintains similar input lengths within a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> every batch draws only from one group,
    either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
    Samples outside the boundaries are removed.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
    """

    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas
    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # drop empty buckets (and their upper boundaries), scanning from the end
        for i in range(len(buckets) - 1, 0, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        # pad each bucket so it divides evenly across replicas and batches
        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket
    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make the bucket evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:rem % len_bucket]

            # subsample for this replica
            ids_bucket = ids_bucket[self.rank::self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)
def _bisect(self, x, lo=0, hi=None):
if hi is None:
hi = len(self.boundaries) - 1
if hi > lo:
mid = (hi + lo) // 2
if self.boundaries[mid] < x and x <= self.boundaries[mid+1]:
return mid
elif x <= self.boundaries[mid]:
return self._bisect(x, lo, mid)
else:
return self._bisect(x, mid + 1, hi)
else:
return -1
def __len__(self):
return self.num_samples // self.batch_size
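
# Minimal sketch of the bucketing behaviour on synthetic lengths, single replica.
# Any dataset-like object exposing `.lengths` works; the tiny stand-in class here is
# purely illustrative.
def example_bucket_sampler():
    class _LengthsOnly(torch.utils.data.Dataset):
        def __init__(self, lengths):
            self.lengths = lengths

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, idx):
            return idx

    ds = _LengthsOnly([50, 60, 320, 350, 700, 720])
    sampler = DistributedBucketSampler(
        ds, batch_size=2, boundaries=[32, 300, 400, 800], num_replicas=1, rank=0, shuffle=True
    )
    # three buckets -> three batches of two indices each, order depending on the shuffle
    return list(iter(sampler))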
class VitsCollectionDataset(torch.utils.data.Dataset):
    """Thin wrapper exposing per-sample frame lengths (from the 'secs' column) for bucketing."""

    def __init__(self, dataset, hop_length=256, rate=16_000, device='cpu') -> None:
        self.dataset = dataset
        # approximate spectrogram frame count per sample, used by DistributedBucketSampler
        self.lengths = (torch.tensor(dataset['secs']) * rate // (2 * hop_length)).tolist()
        self.device = device

    def __len__(self):
        return self.dataset.num_rows

    def __getitem__(self, idx):
        return self.dataset[idx]
def get_dataloader(dir_db_train, feature_extractor, name_db='train', batch_size=8, num_workers=0):
    dataset = DatasetDict.load_from_disk(dir_db_train)
    db_train = VitsCollectionDataset(dataset[name_db])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VitsModel.from_pretrained("facebook/mms-tts-ara").to(device)
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-ara", cache_dir="./")

    train_sampler = DistributedBucketSampler(
        db_train,
        batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=1,
        rank=0,
        shuffle=True,
    )
    data_collator = DataSetFeaturesCollector(
        tokenizer=tokenizer,
        model=model,
        feature_extractor=feature_extractor,
        forward_attention_mask=True,
    )
    train_dataloader = torch.utils.data.DataLoader(
        db_train,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        collate_fn=data_collator,
        batch_sampler=train_sampler,
    )
    return train_dataloader
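
# Usage sketch (assumption: the DatasetDict at "dataset" has a 'train' split with
# 'audio', 'text', and 'secs' columns; 'secs' feeds the bucket sampler through
# VitsCollectionDataset).
def example_get_dataloader():
    feature_extractor = VitsFeatureExtractor()
    train_dataloader = get_dataloader("dataset", feature_extractor, name_db='train', batch_size=8)
    for batch in train_dataloader:
        print(batch['waveform'].shape, batch['input_ids'].shape)
        break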