# File: MegaTTS3/tts/modules/ar_dur/commons/nar_tts_modules.py
# Last commit: 593f3bc "first commit for huggingface space" (ZiyueJiang)
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
import torch.nn.functional as F
class LengthRegulator(torch.nn.Module):
    """Expand per-token durations into a frame-to-token index map."""

    def __init__(self, pad_value=0.0):
        super(LengthRegulator, self).__init__()
        # Kept on the instance; forward() in this class does not read it.
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """Map every output frame to the 1-based index of the token covering it.

        Example (no batch dim version):
            1. dur = [2,2,3]
            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7],
               dur_cumsum_prev = [0,2,4]
            3. token_mask = [[1,1,0,0,0,0,0],
                             [0,0,1,1,0,0,0],
                             [0,0,0,0,1,1,1]]
            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
                                         [0,0,2,2,0,0,0],
                                         [0,0,0,0,3,3,3]]
            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]

        :param dur: Batch of durations of each token (B, T_txt)
        :param dur_padding: Optional batch of padding flags (B, T_txt);
            positions where the flag is 1 get their duration forced to 0.
        :param alpha: duration rescale coefficient, must be > 0
        :return:
            mel2token / mel2ph (B, T_speech), dtype long. T_speech is the
            largest total duration in the batch; frames beyond a sample's
            own total duration are 0.
        :raises ValueError: if a scalar ``alpha`` is not strictly positive.
        """
        # The original guard lived inside the docstring and never ran.
        # Only plain Python numbers are checked so a tensor-valued alpha
        # (broadcast against dur) keeps passing through untouched.
        if isinstance(alpha, (int, float)) and alpha <= 0:
            raise ValueError(f"alpha must be > 0, got {alpha}")

        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())

        device = dur.device
        # 1-based token ids, shaped (1, T_txt, 1) to broadcast over frames.
        token_idx = torch.arange(1, dur.shape[1] + 1, device=device)[None, :, None]

        # Each token i covers frames [dur_cumsum_prev[i], dur_cumsum[i]).
        dur_cumsum = torch.cumsum(dur, 1)
        # Exclusive prefix sum: cumsum shifted right by one equals cumsum - dur.
        dur_cumsum_prev = dur_cumsum - dur

        # Frame positions 0..T_speech-1, shaped (1, 1, T_speech).
        num_frames = int(dur.sum(-1).max())
        pos_idx = torch.arange(num_frames, device=device)[None, None]

        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        # At most one token is active per frame, so the sum selects its id.
        mel2token = (token_idx * token_mask.long()).sum(1)
        return mel2token
class PosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions.

    The first half of the last output axis holds sines and the second half
    cosines, over ``dim // 2`` geometrically spaced frequencies from 1 down
    to 1/10000.
    """

    def __init__(self, dim):
        """
        :param dim: requested embedding size; ``dim // 2`` frequencies are
            used, so the output has ``2 * (dim // 2)`` features.
        """
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        # Clamp the denominator: with half_dim == 1 (dim of 2 or 3) the
        # unclamped ``half_dim - 1`` is zero and raised ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        inv_freq = torch.exp(torch.arange(half_dim) * -log_step)
        # Plain attribute rather than a registered buffer: it stays out of
        # state_dict and is not affected by module-level dtype casts.
        self.emb = inv_freq

    def forward(self, x):
        """
        :param x: positions, indexed here as a 2-D tensor (B, T)
        :return: tensor (B, T, 2 * (dim // 2)) laid out as [sin..., cos...]
        """
        # Move the frequency table to x's device on each call since it is
        # not tracked by Module.to().
        freqs = self.emb[None, None, :].to(x.device)
        angles = x[:, :, None] * freqs
        return torch.cat((angles.sin(), angles.cos()), dim=-1)