import os
from typing import Tuple, Any, Union, Dict
import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
"""Instantiates a class with the given args and init.
Args:
args: Positional arguments required for instantiation.
init: Dict of the form {"class_path":...,"init_args":...}.
Returns:
The instantiated class object.
"""
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
args = (args,)
class_module, class_name = init["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
return args_class(*args, **kwargs)
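# A minimal usage sketch for instantiate_class; the class path and arguments below are
# illustrative, not taken from a shipped config:
#
#   init = {
#       "class_path": "inspiremusic.wavtokenizer.decoder.heads.ISTFTHead",
#       "init_args": {"dim": 512, "n_fft": 1280, "hop_length": 320},
#   }
#   head = instantiate_class(args=(), init=init)  # imports the module, then calls the class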
class WavTokenizer(nn.Module):
"""
    WavTokenizer is a Fourier-based neural vocoder for audio synthesis and tokenization,
    adapted from Vocos. This class is primarily designed for inference, with support for
    loading pretrained model checkpoints. It consists of three main components: a feature
    extractor, a backbone, and a head.
"""
def __init__(
self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
):
super().__init__()
self.feature_extractor = feature_extractor
self.backbone = backbone
self.head = head
@classmethod
    def from_hparams(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from hyperparameters stored
        in a yaml configuration file.
        """
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
backbone = instantiate_class(args=(), init=config["backbone"])
head = instantiate_class(args=(), init=config["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
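    # A sketch of the flat config.yaml layout this loader expects; the class paths and
    # init_args below are illustrative, not copied from a real config:
    #
    #   feature_extractor:
    #     class_path: inspiremusic.wavtokenizer.decoder.feature_extractors.EncodecFeatures
    #     init_args: {...}
    #   backbone:
    #     class_path: inspiremusic.wavtokenizer.decoder.models.VocosBackbone
    #     init_args: {...}
    #   head:
    #     class_path: inspiremusic.wavtokenizer.decoder.heads.ISTFTHead
    #     init_args: {...}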
@classmethod
    def from_pretrained(cls, repo_id: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from a pretrained model
        stored on the Hugging Face model hub.
        """
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        model = cls.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu")
        if isinstance(model.feature_extractor, EncodecFeatures):
            # The checkpoint does not carry the EnCodec weights, so inject them from the
            # freshly instantiated feature extractor before loading.
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in model.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
model.load_state_dict(state_dict)
model.eval()
return model
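    # Usage sketch (the repo id is illustrative, not an endorsement of a specific upload):
    #
    #   model = WavTokenizer.from_pretrained("novateur/WavTokenizer")
    #   audio_out = model(audio_in)  # copy-synthesis, see forward()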
@classmethod
    def from_hparams_feat(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from a yaml training
        configuration file whose components are nested under model.init_args.
        """
with open(config_path, "r") as f:
config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config["model"]["init_args"]["feature_extractor"])
        backbone = instantiate_class(args=(), init=config["model"]["init_args"]["backbone"])
        head = instantiate_class(args=(), init=config["model"]["init_args"]["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
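    # Unlike from_hparams, this reads a Lightning-style training config where the components
    # sit under model.init_args (layout sketch; keys illustrative):
    #
    #   model:
    #     init_args:
    #       feature_extractor: {class_path: ..., init_args: {...}}
    #       backbone: {class_path: ..., init_args: {...}}
    #       head: {class_path: ..., init_args: {...}}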
@classmethod
    def from_pretrained_feat(cls, config_path: str, model_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from a local training
        checkpoint, keeping only the feature extractor, backbone, and head weights.
        """
        model = cls.from_hparams_feat(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")["state_dict"]
        # Keep only the inference components; drop any training-only entries.
        state_dict = {
            k: v
            for k, v in state_dict_raw.items()
            if k.startswith(("backbone.", "head.", "feature_extractor."))
        }
model.load_state_dict(state_dict)
model.eval()
return model
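    # Usage sketch with local files (paths illustrative):
    #
    #   model = WavTokenizer.from_pretrained_feat("config.yaml", "last.ckpt")
    #   features, codes = model.encode(audio_in)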
@classmethod
    def estimator(cls, config_path: str, model_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from a local training
        checkpoint. Currently identical in behavior to `from_pretrained_feat`.
        """
        model = cls.from_hparams_feat(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")["state_dict"]
        # Keep only the inference components; drop any training-only entries.
        state_dict = {
            k: v
            for k, v in state_dict_raw.items()
            if k.startswith(("backbone.", "head.", "feature_extractor."))
        }
model.load_state_dict(state_dict)
model.eval()
return model
@classmethod
    def from_pretrained0911(cls, config_path: str, model_folder_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance by averaging the weights
        of the best checkpoints found in a local folder.
        """
        model = cls.from_hparams_feat(config_path)
models = os.listdir(model_folder_path)
val_loss = []
for item in models:
if not item.startswith('vocos_'):
continue
            # The six characters before the file extension encode the validation loss.
            val_loss.append(item[-11:-5])
        val_loss.sort()
        val_loss = val_loss[:3]  # average the 3 checkpoints with the lowest validation loss
        state_dict = dict()
        state_dicts = []
        for item in models:
            if not item.startswith('vocos_'):
                continue
            if item[-11:-5] not in val_loss:
                continue
            model_path = os.path.join(model_folder_path, item)
            state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
            # Keep only the inference components of this checkpoint.
            state_dict_single = {
                k: v
                for k, v in state_dict_raw.items()
                if k.startswith(('backbone.', 'head.', 'feature_extractor.'))
            }
            state_dicts.append(state_dict_single)
        # Average the selected checkpoints parameter-wise. Clone first so the in-place
        # accumulation does not mutate the first state dict.
        for kk in state_dicts[0].keys():
            vv = state_dicts[0][kk].clone()
            for i in range(1, len(state_dicts)):
                vv += state_dicts[i][kk]
            state_dict[kk] = vv / len(state_dicts)
model.load_state_dict(state_dict)
model.eval()
return model
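    # Checkpoint naming this method assumes (filename illustrative): the six characters
    # before the '.ckpt' extension encode the validation loss, e.g.
    #
    #   vocos_epoch=12_val_loss=0.1234.ckpt  ->  item[-11:-5] == '0.1234'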
@torch.inference_mode()
def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
        Method to run copy-synthesis from an audio waveform. The feature extractor first processes
        the audio input, which is then passed through the backbone and the head to reconstruct the
        audio output.
        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
        features, _, _ = self.feature_extractor(audio_input, **kwargs)
audio_output = self.decode(features, **kwargs)
return audio_output
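    # Copy-synthesis sketch (the bandwidth_id kwarg is illustrative; any kwargs are simply
    # forwarded to the feature extractor and backbone):
    #
    #   audio_out = model(audio_in, bandwidth_id=torch.tensor([0]))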
    @torch.inference_mode()
    def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes an audio waveform into continuous features and discrete codes."""
        features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
        return features, discrete_codes
    @torch.inference_mode()
    def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes an audio waveform via the feature extractor's inference path."""
        features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
        return features, discrete_codes
@torch.inference_mode()
    def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Tokenizes an audio waveform into a sequence of discrete codes."""
        _, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs)
        # Clamp token ids into the expected codebook range [0, 16383].
        discrete_codes = discrete_codes.clamp(min=0, max=16383)
        return discrete_codes
@torch.inference_mode()
def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to decode audio waveform from already calculated features. The features input is passed through
the backbone and the head to reconstruct the audio output.
Args:
features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
C denotes the feature dimension, and L is the sequence length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
x = self.backbone(features_input, **kwargs)
audio_output = self.head(x)
return audio_output
@torch.inference_mode()
def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
"""
Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
codebook weights.
Args:
codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
where K is the number of codebooks, B is the batch size and L is the sequence length.
Returns:
Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
and L is the sequence length.
"""
assert isinstance(
self.feature_extractor, EncodecFeatures
), "Feature extractor should be an instance of EncodecFeatures"
if codes.dim() == 2:
codes = codes.unsqueeze(1)
n_bins = self.feature_extractor.encodec.quantizer.bins
offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
embeddings_idxs = codes + offsets.view(-1, 1, 1)
        # Concatenate the per-quantizer codebooks into one embedding table, look up each
        # code, and sum across codebooks (residual vector quantization).
        codebook_weights = torch.cat(
            [vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0
        )
        features = torch.nn.functional.embedding(embeddings_idxs, codebook_weights).sum(dim=0)
features = features.transpose(1, 2)
return features
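    # End-to-end tokenizer round trip (a sketch; shapes follow the docstrings above):
    #
    #   _, codes = model.encode(audio_in)        # discrete codes, (K, B, L)
    #   feats = model.codes_to_features(codes)   # features, (B, C, L)
    #   audio_out = model.decode(feats)          # waveform, (B, T)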