# File-viewer residue (not part of the module): martin | initial | 67c46fd | raw | history blame | 6.64 kB
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
import time
class MaskedDiffWithXvec(torch.nn.Module):
    """Speech-token-to-feature model conditioned on a speaker embedding.

    Discrete speech tokens are embedded, passed through ``encoder``,
    projected to ``output_size`` channels, stretched by ``length_regulator``
    and finally handed to ``decoder`` together with a projected speaker
    embedding and an optional prompt-feature condition.

    ``forward`` computes the training loss; ``inference`` generates features
    for a single utterance given a prompt.
    """

    @staticmethod
    def _default_decoder_conf() -> Dict:
        """Return a fresh copy of the default decoder configuration."""
        return {
            "in_channels": 240,
            "out_channel": 80,
            "spk_emb_dim": 80,
            "n_spks": 1,
            "cfm_params": DictConfig(
                {
                    "sigma_min": 1e-06,
                    "solver": "euler",
                    "t_scheduler": "cosine",
                    "training_cfg_rate": 0.2,
                    "inference_cfg_rate": 0.7,
                    "reg_loss_type": "l1",
                }
            ),
            "decoder_params": {
                "channels": [256, 256],
                "dropout": 0.0,
                "attention_head_dim": 64,
                "n_blocks": 4,
                "num_mid_blocks": 12,
                "num_heads": 8,
                "act_fn": "gelu",
            },
        }

    @staticmethod
    def _default_mel_feat_conf() -> Dict:
        """Return a fresh copy of the default mel-feature configuration."""
        return {
            "n_fft": 1024,
            "num_mels": 80,
            "sampling_rate": 22050,
            "hop_size": 256,
            "win_size": 1024,
            "fmin": 0,
            "fmax": 8000,
        }

    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 80,
        spk_embed_dim: int = 192,
        output_type: str = "mel",
        vocab_size: int = 4096,
        input_frame_rate: int = 50,
        only_mask_loss: bool = True,
        encoder: Optional[torch.nn.Module] = None,
        length_regulator: Optional[torch.nn.Module] = None,
        decoder: Optional[torch.nn.Module] = None,
        decoder_conf: Optional[Dict] = None,
        mel_feat_conf: Optional[Dict] = None,
    ):
        """Build the embedding/projection layers and store the sub-modules.

        Args:
            input_size: Dimension of the token embedding fed to ``encoder``.
            output_size: Channel count of the projected encoder output, of
                the projected speaker embedding and of the condition tensor
                built in ``inference``.
            spk_embed_dim: Dimension of the incoming speaker embedding.
            output_type: Stored on the instance; not used inside this class.
            vocab_size: Number of entries in the token embedding table.
            input_frame_rate: Token rate used by ``inference`` to derive the
                number of output frames.
            only_mask_loss: Stored on the instance; not used inside this
                class.
            encoder: Module providing ``output_size()``, ``__call__`` and
                ``inference``. Required in practice: ``output_size()`` is
                called here.
            length_regulator: Module providing ``__call__`` and
                ``inference``.
            decoder: Module providing ``compute_loss`` and ``__call__``.
            decoder_conf: Decoder configuration stored on the instance. When
                ``None`` a fresh default dict is created for this instance.
            mel_feat_conf: Feature-extraction configuration; ``inference``
                reads ``"sampling_rate"`` and ``"hop_size"``. When ``None`` a
                fresh default dict is created for this instance.
        """
        super().__init__()
        # Defaults are built per instance: a dict literal in the signature
        # would be one shared object, so mutating the config of one model
        # would silently change every other model created with defaults.
        if decoder_conf is None:
            decoder_conf = self._default_decoder_conf()
        if mel_feat_conf is None:
            mel_feat_conf = self._default_mel_feat_conf()

        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        # Layer creation order is kept stable so seeded initialisation and
        # parameter ordering do not change.
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the training loss for one batch.

        Args:
            batch: Mapping with the keys ``"speech_token"``,
                ``"speech_token_len"``, ``"speech_feat"``,
                ``"speech_feat_len"`` and ``"embedding"``; every value is
                moved to ``device``.
                NOTE(review): ``speech_feat`` is presumably laid out as
                (batch, frames, channels) - confirm against the data loader.
            device: Device on which the computation runs.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the first value returned by
            ``decoder.compute_loss``.
        """
        token = batch["speech_token"].to(device)
        token_len = batch["speech_token_len"].to(device)
        feat = batch["speech_feat"].to(device)
        feat_len = batch["speech_feat_len"].to(device)
        embedding = batch["embedding"].to(device)

        # Speaker embedding: L2-normalise along dim 1, then project.
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # Embed the tokens (negative ids are clamped to 0) and multiply by
        # the inverted pad mask - presumably zeroing padded positions.
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # Encode, project to output_size and regulate to the feature length.
        h, _ = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, _ = self.length_regulator(h, feat_len)

        # Prompt condition: for each sample, with probability 0.5 keep it all
        # zeros; otherwise copy a prefix of the target features whose length
        # is drawn uniformly from [0, int(0.3 * length)] (both inclusive).
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        # Inverted pad mask over the feature lengths, cast like ``h``.
        mask = (~make_pad_mask(feat_len)).to(h)
        # Nearest-neighbour resize of the target so its last two dimensions
        # match those of ``h``.
        feat = F.interpolate(
            feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest"
        ).squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
        )
        return {"loss": loss}

    @torch.inference_mode()
    def inference(
        self,
        token,
        token_len,
        prompt_token,
        prompt_token_len,
        prompt_feat,
        prompt_feat_len,
        embedding,
    ):
        """Generate features for a single utterance, conditioned on a prompt.

        Args:
            token: Token ids to synthesise; the batch dimension must be 1.
            token_len: Length of ``token``.
            prompt_token: Prompt token ids, concatenated in front of
                ``token`` along dim 1.
            prompt_token_len: Length of ``prompt_token``.
            prompt_feat: Prompt features, written into the first
                ``prompt_feat.shape[1]`` frames of a
                ``(1, frames, output_size)`` condition tensor.
            prompt_feat_len: Accepted for interface compatibility; not read.
            embedding: Speaker embedding, normalised along dim 1.

        Returns:
            The decoder output with the prompt frames removed, i.e. the last
            ``mel_len2`` positions along dim 2.
        """
        assert token.shape[0] == 1

        # Speaker embedding: L2-normalise along dim 1, then project.
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # Remember where the prompt ends before concatenating the sequences.
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token = torch.concat([prompt_token, token], dim=1)
        token_len = prompt_token_len + token_len
        token = self.input_embedding(torch.clamp(token, min=0))

        # Encode the joint sequence and project to output_size.
        h, _ = self.encoder.inference(token, token_len)
        h = self.encoder_proj(h)

        # Frames to generate: token count converted to seconds via the token
        # rate, then to frames via sampling_rate / hop_size (truncated).
        mel_len1 = prompt_feat.shape[1]
        mel_len2 = int(
            token_len2
            / self.input_frame_rate
            * self.mel_feat_conf["sampling_rate"]
            / self.mel_feat_conf["hop_size"]
        )
        h, _ = self.length_regulator.inference(
            h[:, :token_len1],
            h[:, token_len1:],
            mel_len1,
            mel_len2,
        )

        # Condition: prompt features in the first mel_len1 frames, zeros
        # for the frames that are to be generated.
        total_len = mel_len1 + mel_len2
        conds = torch.zeros([1, total_len, self.output_size], device=token.device)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        # Batch size is 1, so nothing is padded and the mask is all ones.
        # NOTE(review): the dtype is hard-coded to bfloat16 rather than
        # following ``h`` - presumably the decoder runs in bf16; confirm.
        mask = torch.ones([1, total_len], device=h.device, dtype=torch.bfloat16)
        feat = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
        )
        # Drop the prompt part; only the newly generated frames are returned.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat