# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F

# Layout of the merged (phone + tone) token ids used by this module.
# Chinese phones occupy ids 3~100 in the original phone dict.
_CH_PHONE_FIRST = 3
_CH_PHONE_LAST = 100
# Each Chinese phone is expanded into one merged id per tone.
_NUM_TONES = 6
# Merged ids are written as 200 + (phone - 3) * 6 + tone_index, tone_index in 1..6,
# so the encoder produces ids in 201~788.
_MERGED_BASE = 200
_MERGED_LAST = 788
_BOS_ID = 798
_EOS_ID = 799
# Timestamps are stored with this offset added, above every token id.
_TIMESTAMP_OFFSET = 800
# Tone id assigned to every token that is not a merged Chinese phone.
_NON_CHINESE_TONE = 3
# tone_dict id -> 1-based tone index (tone_1 is 4, tone_2..tone_6 are 11..15).
_TONE_DICT_TO_INDEX = {4: 1, 11: 2, 12: 3, 13: 4, 14: 5, 15: 6}


def _decode_token(token):
    """Split one merged token id (a Python int) into (phone id, tone_dict id)."""
    # The lower bound is kept inclusive at 200 to match the historical decoder,
    # even though the encoder itself only emits ids starting at 201.
    if _MERGED_BASE <= token <= _MERGED_LAST:
        offset = token - _MERGED_BASE - 1
        phone = offset // _NUM_TONES + _CH_PHONE_FIRST
        tone_index = offset % _NUM_TONES + 1
        # Inverse of _TONE_DICT_TO_INDEX: index 1 -> 4, indices 2..6 -> 11..15.
        tone = 4 if tone_index == 1 else tone_index + 9
        return phone, tone
    return token, _NON_CHINESE_TONE


def map_phone_to_tokendict(item, pad_bos_eos=True):
    """Merge phone ids and tone ids into a single token sequence.

    Args:
        item: Mapping with tensors under the keys ``'txt_token'`` (phone ids)
            and ``'tone'`` (tone_dict ids). Neither tensor is modified.
        pad_bos_eos: When true, prepend the BOS id (798) and append the EOS
            id (799) along the last dimension.

    Returns:
        A new tensor in which every phone id in 3~100 is replaced by
        ``(phone - 3) * 6 + 200 + tone_index`` and every other id is kept.
    """
    # Original dict ends at 173 (ph_dict_size=173); 146~173 are punctuation.
    phone = item['txt_token']
    tone = item['tone']
    merged_phone = phone.clone()

    # Translate tone_dict ids to tone indices 1..6. Masks are taken from the
    # untouched `tone` tensor so the result does not depend on update order.
    tone_index = tone.clone()
    for dict_id, index in _TONE_DICT_TO_INDEX.items():
        tone_index[tone == dict_id] = index
    # NOTE(review): tone ids outside the table are used as-is in the formula
    # below; presumably Chinese phones always carry a tone from the table.

    # Chinese phones lie in 3~100 in the phone dict; map them to 201~788.
    ch_phone_idx = (phone >= _CH_PHONE_FIRST) & (phone <= _CH_PHONE_LAST)
    merged_phone[ch_phone_idx] = (
        (merged_phone[ch_phone_idx] - _CH_PHONE_FIRST) * _NUM_TONES
        + _MERGED_BASE
        + tone_index[ch_phone_idx]
    )

    if pad_bos_eos:
        merged_phone = F.pad(merged_phone, (1, 0), mode='constant', value=_BOS_ID)
        merged_phone = F.pad(merged_phone, (0, 1), mode='constant', value=_EOS_ID)
    return merged_phone


def split_ph_timestamp(ph_timestamp):
    """Split an interleaved token/timestamp sequence into phones, tones and durations.

    Args:
        ph_timestamp: 1-D integer tensor of shape [T]. Even positions hold
            merged token ids, odd positions hold the end timestamp of the
            preceding token (values >= 800 carry an offset of 800). The
            caller's tensor is left unmodified.

    Returns:
        Tuple ``(ph_seq, tone_seq, dur_seq, last)``: three ``LongTensor``s
        with phone ids, tone_dict ids and per-phone durations (differences of
        consecutive timestamps), plus the final element of the sequence after
        the offset has been removed.

    Raises:
        AssertionError: If the number of tokens and timestamps differ
            (i.e. the input has odd length).
    """
    # Work on a copy: the offset removal below is in-place and previously
    # leaked into the tensor owned by the caller.
    ph_timestamp = ph_timestamp.clone()
    # Map the timestamp of each phone back to its original frame-level value.
    ph_timestamp[ph_timestamp >= _TIMESTAMP_OFFSET] -= _TIMESTAMP_OFFSET

    ph_list = []
    tone_list = []
    dur_list = []
    cur_timestamp = 0
    for idx, item in enumerate(ph_timestamp):
        value = int(item)
        if idx % 2 == 0:
            # Map Chinese phones back to the original phone dict.
            ph, tone = _decode_token(value)
            ph_list.append(ph)
            tone_list.append(tone)
        else:
            # Timestamps are cumulative; a duration is the step from the last one.
            dur_list.append(value - cur_timestamp)
            cur_timestamp = value

    assert len(ph_list) == len(dur_list), f"{len(ph_list)}, {len(dur_list)}"
    ph_seq = torch.LongTensor(ph_list)
    tone_seq = torch.LongTensor(tone_list)
    dur_seq = torch.LongTensor(dur_list)
    return ph_seq, tone_seq, dur_seq, ph_timestamp[-1]


def split_ph(ph_seq):
    """Split a merged token sequence back into phone ids and tone ids.

    Args:
        ph_seq: 1-D sequence of merged token ids, shape [T] (a tensor or any
            iterable of integers).

    Returns:
        Tuple ``(ph_seq, tone_seq)`` of ``LongTensor``s. Merged Chinese
        tokens are decoded to their phone id and tone_dict id; every other
        token is passed through with tone id 3.
    """
    ph_list = []
    tone_list = []
    for item in ph_seq:
        # Map Chinese phones back to the original phone dict.
        ph, tone = _decode_token(int(item))
        ph_list.append(ph)
        tone_list.append(tone)
    return torch.LongTensor(ph_list), torch.LongTensor(tone_list)