Ricecake123 committed on
Commit
e79b770
1 Parent(s): 298e47d

first commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. AR/__init__.py +0 -0
  2. AR/__pycache__/__init__.cpython-310.pyc +0 -0
  3. AR/__pycache__/__init__.cpython-39.pyc +0 -0
  4. AR/data/__init__.py +0 -0
  5. AR/data/__pycache__/__init__.cpython-310.pyc +0 -0
  6. AR/data/__pycache__/__init__.cpython-39.pyc +0 -0
  7. AR/data/__pycache__/bucket_sampler.cpython-310.pyc +0 -0
  8. AR/data/__pycache__/bucket_sampler.cpython-39.pyc +0 -0
  9. AR/data/__pycache__/data_module.cpython-310.pyc +0 -0
  10. AR/data/__pycache__/data_module.cpython-39.pyc +0 -0
  11. AR/data/__pycache__/dataset.cpython-310.pyc +0 -0
  12. AR/data/__pycache__/dataset.cpython-39.pyc +0 -0
  13. AR/data/bucket_sampler.py +157 -0
  14. AR/data/data_module.py +66 -0
  15. AR/data/dataset.py +302 -0
  16. AR/exps/__init__.py +0 -0
  17. AR/exps/beats/BEATs.py +179 -0
  18. AR/exps/beats/README.md +127 -0
  19. AR/exps/beats/Tokenizers.py +172 -0
  20. AR/exps/beats/__init__.py +2 -0
  21. AR/exps/beats/backbone.py +791 -0
  22. AR/exps/beats/config.py +19 -0
  23. AR/exps/beats/modules.py +220 -0
  24. AR/exps/beats/ontology.json +0 -0
  25. AR/exps/beats/quantizer.py +235 -0
  26. AR/exps/get_beats_librilight.py +321 -0
  27. AR/exps/get_phones.py +232 -0
  28. AR/exps/get_phones_librilight.py +198 -0
  29. AR/exps/get_txt_librilight.py +255 -0
  30. AR/exps/split_train_val.py +35 -0
  31. AR/exps/t2s.py +197 -0
  32. AR/exps/test.py +139 -0
  33. AR/exps/text.txt +10 -0
  34. AR/exps/train.py +103 -0
  35. AR/exps/train_librilight_6k.py +170 -0
  36. AR/models/__init__.py +0 -0
  37. AR/models/__pycache__/__init__.cpython-310.pyc +0 -0
  38. AR/models/__pycache__/__init__.cpython-39.pyc +0 -0
  39. AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc +0 -0
  40. AR/models/__pycache__/t2s_lightning_module.cpython-39.pyc +0 -0
  41. AR/models/__pycache__/t2s_model.cpython-310.pyc +0 -0
  42. AR/models/__pycache__/t2s_model.cpython-39.pyc +0 -0
  43. AR/models/__pycache__/utils.cpython-310.pyc +0 -0
  44. AR/models/__pycache__/utils.cpython-39.pyc +0 -0
  45. AR/models/t2s_lightning_module.py +128 -0
  46. AR/models/t2s_model.py +298 -0
  47. AR/models/utils.py +164 -0
  48. AR/modules/__init__.py +0 -0
  49. AR/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  50. AR/modules/__pycache__/__init__.cpython-39.pyc +0 -0
AR/__init__.py ADDED
File without changes
AR/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (135 Bytes)
AR/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (138 Bytes)
AR/data/__init__.py ADDED
File without changes
AR/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (140 Bytes)
AR/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (143 Bytes)
AR/data/__pycache__/bucket_sampler.cpython-310.pyc ADDED
Binary file (4.42 kB)
AR/data/__pycache__/bucket_sampler.cpython-39.pyc ADDED
Binary file (4.39 kB)
AR/data/__pycache__/data_module.cpython-310.pyc ADDED
Binary file (2.27 kB)
AR/data/__pycache__/data_module.cpython-39.pyc ADDED
Binary file (2.29 kB)
AR/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (6.58 kB)
AR/data/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (6.57 kB)
AR/data/bucket_sampler.py ADDED
@@ -0,0 +1,157 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
2
+ import itertools
3
+ import math
4
+ import random
5
+ from random import shuffle
6
+ from typing import Iterator
7
+ from typing import Optional
8
+ from typing import TypeVar
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch.utils.data import Dataset
13
+ from torch.utils.data import Sampler
14
+
15
+ __all__ = [
16
+ "DistributedBucketSampler",
17
+ ]
18
+
19
+ T_co = TypeVar('T_co', covariant=True)
20
+
21
+
22
+ class DistributedBucketSampler(Sampler[T_co]):
23
+ r"""
24
+ sort the dataset wrt. input length
25
+ divide samples into buckets
26
+ sort within buckets
27
+ divide buckets into batches
28
+ sort batches
29
+ """
30
+
31
+ def __init__(self,
32
+ dataset: Dataset,
33
+ num_replicas: Optional[int]=None,
34
+ rank: Optional[int]=None,
35
+ shuffle: bool=True,
36
+ seed: int=0,
37
+ drop_last: bool=False,
38
+ batch_size: int=32) -> None:
39
+ if num_replicas is None:
40
+ if not dist.is_available():
41
+ raise RuntimeError(
42
+ "Requires distributed package to be available")
43
+ num_replicas = dist.get_world_size()
44
+ if rank is None:
45
+ if not dist.is_available():
46
+ raise RuntimeError(
47
+ "Requires distributed package to be available")
48
+ rank = dist.get_rank()
49
+ torch.cuda.set_device(rank)
50
+ if rank >= num_replicas or rank < 0:
51
+ raise ValueError("Invalid rank {}, rank should be in the interval"
52
+ " [0, {}]".format(rank, num_replicas - 1))
53
+ self.dataset = dataset
54
+ self.num_replicas = num_replicas
55
+ self.rank = rank
56
+ self.epoch = 0
57
+ self.drop_last = drop_last
58
+ # If the dataset length is evenly divisible by # of replicas, then there
59
+ # is no need to drop any data, since the dataset will be split equally.
60
+ if self.drop_last and len(
61
+ self.
62
+ dataset) % self.num_replicas != 0: # type: ignore[arg-type]
63
+ # Split to nearest available length that is evenly divisible.
64
+ # This is to ensure each rank receives the same amount of data when
65
+ # using this Sampler.
66
+ self.num_samples = math.ceil(
67
+ (len(self.dataset) - self.num_replicas) /
68
+ self.num_replicas # type: ignore[arg-type]
69
+ )
70
+ else:
71
+ self.num_samples = math.ceil(
72
+ len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
73
+ self.total_size = self.num_samples * self.num_replicas
74
+ self.shuffle = shuffle
75
+ self.seed = seed
76
+ self.batch_size = batch_size
77
+ self.id_with_length = self._get_sample_lengths()
78
+ self.id_buckets = self.make_buckets(bucket_width=2.0)
79
+
80
+ def _get_sample_lengths(self):
81
+ id_with_lengths = []
82
+ for i in range(len(self.dataset)):
83
+ id_with_lengths.append((i, self.dataset.get_sample_length(i)))
84
+ id_with_lengths.sort(key=lambda x: x[1])
85
+ return id_with_lengths
86
+
87
+ def make_buckets(self, bucket_width: float=2.0):
88
+ buckets = []
89
+ cur = []
90
+ max_sec = bucket_width
91
+ for id, sec in self.id_with_length:
92
+ if sec < max_sec:
93
+ cur.append(id)
94
+ else:
95
+ buckets.append(cur)
96
+ cur = [id]
97
+ max_sec += bucket_width
98
+ if len(cur) > 0:
99
+ buckets.append(cur)
100
+ return buckets
101
+
102
+ def __iter__(self) -> Iterator[T_co]:
103
+ if self.shuffle:
104
+ # deterministically shuffle based on epoch and seed
105
+ g = torch.Generator()
106
+ g.manual_seed(self.seed + self.epoch)
107
+ random.seed(self.epoch + self.seed)
108
+ shuffled_bucket = []
109
+ for buc in self.id_buckets:
110
+ buc_copy = buc.copy()
111
+ shuffle(buc_copy)
112
+ shuffled_bucket.append(buc_copy)
113
+ grouped_batch_size = self.batch_size * self.num_replicas
114
+ shuffled_bucket = list(itertools.chain(*shuffled_bucket))
115
+ n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
116
+ batches = [
117
+ shuffled_bucket[b * grouped_batch_size:(b + 1) *
118
+ grouped_batch_size] for b in range(n_batch)
119
+ ]
120
+ shuffle(batches)
121
+ indices = list(itertools.chain(*batches))
122
+ else:
123
+ # type: ignore[arg-type]
124
+ indices = list(range(len(self.dataset)))
125
+
126
+ if not self.drop_last:
127
+ # add extra samples to make it evenly divisible
128
+ padding_size = self.total_size - len(indices)
129
+ if padding_size <= len(indices):
130
+ indices += indices[:padding_size]
131
+ else:
132
+ indices += (indices * math.ceil(padding_size /
133
+ len(indices)))[:padding_size]
134
+ else:
135
+ # remove tail of data to make it evenly divisible.
136
+ indices = indices[:self.total_size]
137
+ assert len(indices) == self.total_size
138
+
139
+ # subsample
140
+ indices = indices[self.rank:self.total_size:self.num_replicas]
141
+ assert len(indices) == self.num_samples
142
+
143
+ return iter(indices)
144
+
145
+ def __len__(self) -> int:
146
+ return self.num_samples
147
+
148
+ def set_epoch(self, epoch: int) -> None:
149
+ r"""
150
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
151
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
152
+ sampler will yield the same ordering.
153
+
154
+ Args:
155
+ epoch (int): Epoch number.
156
+ """
157
+ self.epoch = epoch
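
A minimal sketch of exercising `DistributedBucketSampler` outside distributed training (the toy dataset, the explicit single-replica arguments, and all numbers are assumptions for illustration, not part of this commit); the sampler only needs `__len__` and `get_sample_length` on the dataset:

```python
# Sketch only: a toy dataset and a single-replica sampler, purely to illustrate
# the bucketing-by-duration behaviour implemented above.
import random
from AR.data.bucket_sampler import DistributedBucketSampler


class ToyDataset:
    """Provides the two methods the sampler relies on: __len__ and get_sample_length."""

    def __init__(self, n: int = 100):
        # hypothetical per-sample durations in seconds
        self.lengths = [random.uniform(0.5, 20.0) for _ in range(n)]

    def __len__(self) -> int:
        return len(self.lengths)

    def get_sample_length(self, idx: int) -> float:
        return self.lengths[idx]


ds = ToyDataset()
# passing num_replicas/rank explicitly avoids the torch.distributed calls
sampler = DistributedBucketSampler(ds, num_replicas=1, rank=0, batch_size=8)
sampler.set_epoch(0)
indices = list(iter(sampler))
assert len(indices) == len(sampler)  # every sample appears once on this single replica
```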
AR/data/data_module.py ADDED
@@ -0,0 +1,66 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
2
+ from pytorch_lightning import LightningDataModule
3
+ from AR.data.bucket_sampler import DistributedBucketSampler
4
+ from AR.data.dataset import Text2SemanticDataset
5
+ from torch.utils.data import DataLoader
6
+
7
+
8
+ class Text2SemanticDataModule(LightningDataModule):
9
+ def __init__(self, config, train_semantic_path, train_phoneme_path, dev_semantic_path=None, dev_phoneme_path=None):
10
+ super().__init__()
11
+ self.config = config
12
+ self.train_semantic_path = train_semantic_path
13
+ self.train_phoneme_path = train_phoneme_path
14
+ self.dev_semantic_path = dev_semantic_path
15
+ self.dev_phoneme_path = dev_phoneme_path
16
+ self.num_workers = self.config['data']['num_workers']
17
+
18
+ def prepare_data(self):
19
+ pass
20
+
21
+ def setup(self, stage=None, output_logs=False):
22
+ self._train_dataset = Text2SemanticDataset(
23
+ phoneme_path=self.train_phoneme_path,
24
+ semantic_path=self.train_semantic_path,
25
+ max_sec=self.config['data']['max_sec'],
26
+ pad_val=self.config['data']['pad_val'])
27
+ self._dev_dataset = self._train_dataset
28
+ # self._dev_dataset = Text2SemanticDataset(
29
+ # phoneme_path=self.dev_phoneme_path,
30
+ # semantic_path=self.dev_semantic_path,
31
+ # max_sample=self.config['data']['max_eval_sample'],
32
+ # max_sec=self.config['data']['max_sec'],
33
+ # pad_val=self.config['data']['pad_val'])
34
+
35
+ def train_dataloader(self):
36
+ batch_size = self.config['train']['batch_size']
37
+ sampler = DistributedBucketSampler(
38
+ self._train_dataset, batch_size=batch_size)
39
+ return DataLoader(
40
+ self._train_dataset,
41
+ batch_size=batch_size,
42
+ sampler=sampler,
43
+ collate_fn=self._train_dataset.collate,
44
+ num_workers=self.num_workers,
45
+ persistent_workers=True,
46
+ prefetch_factor=16
47
+ )
48
+
49
+ def val_dataloader(self):
50
+ return DataLoader(
51
+ self._dev_dataset,
52
+ batch_size=1,
53
+ shuffle=False,
54
+ collate_fn=self._train_dataset.collate,
55
+ num_workers=max(self.num_workers,12),
56
+ persistent_workers=True,
57
+ prefetch_factor=16
58
+ )
59
+
60
+ # is this dataloader ever actually used?
61
+ def test_dataloader(self):
62
+ return DataLoader(
63
+ self._dev_dataset,
64
+ batch_size=1,
65
+ shuffle=False,
66
+ collate_fn=self._train_dataset.collate)
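
For context, a hedged sketch of constructing this DataModule (the config values and file paths below are placeholders, not taken from this commit); it uses only the config keys the class actually reads (`data.num_workers`, `data.max_sec`, `data.pad_val`, `train.batch_size`):

```python
# Sketch only: wiring Text2SemanticDataModule with the minimal config it expects.
from AR.data.data_module import Text2SemanticDataModule

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 8},
}

dm = Text2SemanticDataModule(
    config,
    train_semantic_path="logs/exp/6-name2semantic.tsv",  # hypothetical path
    train_phoneme_path="logs/exp/2-name2text.txt",       # hypothetical path
)

# With real files in place, a pytorch_lightning Trainer drives it:
#   dm.setup()                      -> builds Text2SemanticDataset
#   loader = dm.train_dataloader()  -> DataLoader using DistributedBucketSampler
# Note the default sampler arguments assume torch.distributed is initialized,
# which Lightning handles during DDP training.
```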
AR/data/dataset.py ADDED
@@ -0,0 +1,302 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
2
+ import pdb
3
+ import sys
4
+ # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
5
+ import traceback,os
6
+ from typing import Dict
7
+ from typing import List
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch,json
12
+ from torch.utils.data import DataLoader
13
+ from torch.utils.data import Dataset
14
+ from transformers import AutoTokenizer
15
+
16
+ from text import cleaned_text_to_sequence
17
+ # from config import exp_dir
18
+
19
+ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
20
+ seq = sequences[0]
21
+ ndim = seq.ndim
22
+ if axis < 0:
23
+ axis += ndim
24
+ dtype = seq.dtype
25
+ pad_value = dtype.type(pad_value)
26
+ seq_lengths = [seq.shape[axis] for seq in sequences]
27
+ max_length = np.max(seq_lengths)
28
+
29
+ padded_sequences = []
30
+ for seq, length in zip(sequences, seq_lengths):
31
+ padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
32
+ ndim - axis - 1)
33
+ padded_seq = np.pad(
34
+ seq, padding, mode='constant', constant_values=pad_value)
35
+ padded_sequences.append(padded_seq)
36
+ batch = np.stack(padded_sequences)
37
+ return batch
38
+
39
+ class Text2SemanticDataset(Dataset):
40
+ """dataset class for text tokens to semantic model training."""
41
+
42
+ def __init__(self,
43
+ phoneme_path: str,
44
+ semantic_path: str,
45
+ max_sample: int = None,
46
+ max_sec: int = 100,
47
+ pad_val: int = 1024,
48
+ # min value of phoneme/sec
49
+ min_ps_ratio: int = 3,
50
+ # max value of phoneme/sec
51
+ max_ps_ratio: int = 25) -> None:
52
+ super().__init__()
53
+
54
+ self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
55
+ # get dict
56
+ self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
57
+ self.path3="%s/3-bert"%(os.path.dirname(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
58
+ self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
59
+ assert os.path.exists(self.path2)
60
+ assert os.path.exists(self.path6)
61
+ self.phoneme_data={}
62
+ with open(self.path2,"r",encoding="utf8")as f:
63
+ lines=f.read().strip("\n").split("\n")
64
+
65
+ for line in lines:
66
+ tmp=line.split("\t")
67
+ if(len(tmp)!=4):continue
68
+ self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
69
+
70
+ # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
71
+ # pad for semantic tokens
72
+ self.PAD: int = pad_val
73
+ # self.hz = 25
74
+ # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
75
+ # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
76
+ # self.hz=int(data[:-2])#
77
+ self.hz=int(os.environ.get("hz","25hz")[:-2])
78
+
79
+ # max seconds of semantic token
80
+ self.max_sec = max_sec
81
+ self.min_ps_ratio = min_ps_ratio
82
+ self.max_ps_ratio = max_ps_ratio
83
+
84
+ if max_sample is not None:
85
+ self.semantic_data = self.semantic_data[:max_sample]
86
+
87
+ # {idx: (semantic, phoneme)}
88
+ # semantic list, phoneme list
89
+ self.semantic_phoneme = []
90
+ self.item_names = []
91
+
92
+ self.inited = False
93
+
94
+ if not self.inited:
95
+ # run the one-time initialization
96
+ self.init_batch()
97
+ self.inited = True
98
+ del self.semantic_data
99
+ del self.phoneme_data
100
+ # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
101
+ # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
102
+
103
+
104
+ def init_batch(self):
105
+ semantic_data_len = len(self.semantic_data)
106
+ phoneme_data_len = len(self.phoneme_data.keys())
107
+ print("semantic_data_len:", semantic_data_len)
108
+ print("phoneme_data_len:", phoneme_data_len)
109
+ idx = 0
110
+ num_not_in = 0
111
+ num_deleted_bigger = 0
112
+ num_deleted_ps = 0
113
+ for i in range(semantic_data_len):
114
+ # iterate over the semantic data in order
115
+ # get str
116
+ item_name = self.semantic_data['item_name'][i]
117
+ # print(self.phoneme_data)
118
+ try:
119
+ phoneme, word2ph, text = self.phoneme_data[item_name]
120
+ except Exception:
121
+ traceback.print_exc()
122
+ # print(f"{item_name} not in self.phoneme_data !")
123
+ num_not_in += 1
124
+ continue
125
+
126
+ semantic_str = self.semantic_data['semantic_audio'][i]
127
+ # get token list
128
+ semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
129
+ # (T); no need to reshape to (1, T), since we only need len()
130
+ # filter out samples that are too long
131
+ if len(semantic_ids) > self.max_sec * self.hz:  # 1: filter out overly long samples by duration estimated from the token count (max_sec comes from the config; e.g. 40 s * 25 Hz = 1000 tokens)
132
+ num_deleted_bigger += 1
133
+ continue
134
+ # (T,); this is fast enough to do up front, so there is no need to handle it per sample in __getitem__
135
+ phoneme = phoneme.split(' ')
136
+
137
+ try:
138
+ phoneme_ids = cleaned_text_to_sequence(phoneme)
139
+ except:
140
+ traceback.print_exc()
141
+ # print(f"{item_name} not in self.phoneme_data !")
142
+ num_not_in += 1
143
+ continue
144
+ # if len(phoneme_ids) > 400:  # 2: replaced with a constant cap of semantic length / 2.5
145
+ if len(phoneme_ids) > self.max_sec * self.hz / 2.5:  # 2: cap the phoneme length at semantic length / 2.5
146
+ num_deleted_ps += 1
147
+ continue
148
+ # if len(semantic_ids) > 1000:###########3
149
+ # num_deleted_bigger += 1
150
+ # continue
151
+
152
+ ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
153
+
154
+ if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  # 4: phonemes per second must stay within [min_ps_ratio, max_ps_ratio] (3~25)
155
+ num_deleted_ps += 1
156
+ # print(item_name)
157
+ continue
158
+
159
+ self.semantic_phoneme.append((semantic_ids, phoneme_ids))
160
+ idx += 1
161
+ self.item_names.append(item_name)
162
+
163
+ min_num = 100  # at 20 we would not duplicate at all; at 30, even after duplicating, no checkpoint gets saved
164
+ leng =len(self.semantic_phoneme)
165
+ if(leng<min_num):
166
+ tmp1=self.semantic_phoneme
167
+ tmp2=self.item_names
168
+ self.semantic_phoneme=[]
169
+ self.item_names=[]
170
+ for _ in range(max(2,int(min_num/leng))):
171
+ self.semantic_phoneme+=tmp1
172
+ self.item_names+=tmp2
173
+ if num_not_in > 0:
174
+ print(f"there are {num_not_in} semantic datas not in phoneme datas")
175
+ if num_deleted_bigger > 0:
176
+ print(
177
+ f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
178
+ )
179
+ if num_deleted_ps > 0:
180
+ # 4702 for LibriTTS; LibriTTS is labeled data, does it still need filtering? => yes, there are extreme values of 100
181
+ print(
182
+ f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
183
+ )
184
+ '''
185
+ there are 31 semantic datas not in phoneme datas
186
+ deleted 34 audios who's duration are bigger than 54 seconds
187
+ deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
188
+ dataset.__len__(): 366463
189
+
190
+ '''
191
+ # 345410 for LibriTTS
192
+ print("dataset.__len__():", self.__len__())
193
+
194
+ def __get_item_names__(self) -> List[str]:
195
+ return self.item_names
196
+
197
+ def __len__(self) -> int:
198
+ return len(self.semantic_phoneme)
199
+
200
+ def __getitem__(self, idx: int) -> Dict:
201
+ semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
202
+ item_name = self.item_names[idx]
203
+ phoneme_ids_len = len(phoneme_ids)
204
+ # semantic tokens target
205
+ semantic_ids_len = len(semantic_ids)
206
+
207
+ flag=0
208
+ path_bert = "%s/%s.pt" % (self.path3, item_name)
209
+ if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
210
+ else:flag=1
211
+ if(flag==1):
212
+ # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
213
+ bert_feature=None
214
+ else:
215
+ assert bert_feature.shape[-1] == len(phoneme_ids)
216
+ return {
217
+ 'idx': idx,
218
+ 'phoneme_ids': phoneme_ids,
219
+ 'phoneme_ids_len': phoneme_ids_len,
220
+ 'semantic_ids': semantic_ids,
221
+ 'semantic_ids_len': semantic_ids_len,
222
+ 'bert_feature': bert_feature,
223
+ }
224
+
225
+ def get_sample_length(self, idx: int):
226
+ semantic_ids = self.semantic_phoneme[idx][0]
227
+ sec = 1.0 * len(semantic_ids) / self.hz
228
+ return sec
229
+
230
+ def collate(self, examples: List[Dict]) -> Dict:
231
+ sample_index: List[int] = []
232
+ phoneme_ids: List[torch.Tensor] = []
233
+ phoneme_ids_lens: List[int] = []
234
+ semantic_ids: List[torch.Tensor] = []
235
+ semantic_ids_lens: List[int] = []
236
+ # return
237
+
238
+
239
+ for item in examples:
240
+ sample_index.append(item["idx"])
241
+ phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
242
+ semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
243
+ phoneme_ids_lens.append(item["phoneme_ids_len"])
244
+ semantic_ids_lens.append(item["semantic_ids_len"])
245
+
246
+ # pad 0
247
+ phoneme_ids = batch_sequences(phoneme_ids)
248
+ semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
249
+
250
+ # # convert each batch to torch.tensor
251
+ phoneme_ids = torch.tensor(phoneme_ids)
252
+ semantic_ids = torch.tensor(semantic_ids)
253
+ phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
254
+ semantic_ids_lens = torch.tensor(semantic_ids_lens)
255
+ bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
256
+ bert_padded.zero_()
257
+
258
+ for idx, item in enumerate(examples):
259
+ bert = item['bert_feature']
260
+ if bert is not None:
261
+ bert_padded[idx, :, :bert.shape[-1]] = bert
262
+
263
+ return {
264
+ # List[int]
265
+ "ids": sample_index,
266
+ # torch.Tensor (B, max_phoneme_length)
267
+ "phoneme_ids": phoneme_ids,
268
+ # torch.Tensor (B)
269
+ "phoneme_ids_len": phoneme_ids_lens,
270
+ # torch.Tensor (B, max_semantic_ids_length)
271
+ "semantic_ids": semantic_ids,
272
+ # torch.Tensor (B)
273
+ "semantic_ids_len": semantic_ids_lens,
274
+ # torch.Tensor (B, 1024, max_phoneme_length)
275
+ "bert_feature": bert_padded,
276
+ }
277
+
278
+
279
+ if __name__ == '__main__':
280
+ root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
281
+ dataset = Text2SemanticDataset(
282
+ phoneme_path=root_dir + 'phoneme_train.npy',
283
+ semantic_path=root_dir + 'semantic_train.tsv')
284
+
285
+ batch_size = 12
286
+ dataloader = DataLoader(
287
+ dataset,
288
+ batch_size=batch_size,
289
+ collate_fn=dataset.collate,
290
+ shuffle=False)
291
+ for i, batch in enumerate(dataloader):
292
+ if(i%1000==0):print(i)
293
+ # if i == 0:
294
+ # print('batch["ids"]:', batch["ids"])
295
+ # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
296
+ # batch["phoneme_ids"].shape)
297
+ # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
298
+ # batch["phoneme_ids_len"].shape)
299
+ # print('batch["semantic_ids"]:', batch["semantic_ids"],
300
+ # batch["semantic_ids"].shape)
301
+ # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
302
+ # batch["semantic_ids_len"].shape)
AR/exps/__init__.py ADDED
File without changes
AR/exps/beats/BEATs.py ADDED
@@ -0,0 +1,179 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import logging
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+ from torch.nn import LayerNorm
16
+
17
+ from .backbone import TransformerEncoder
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class BEATsConfig:
23
+ def __init__(self, cfg=None):
24
+ self.input_patch_size: int = -1 # patch size of patch embedding
25
+ self.embed_dim: int = 512 # patch embedding dimension
26
+ self.conv_bias: bool = False # include bias in conv encoder
27
+
28
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
29
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
30
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
31
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
32
+ self.activation_fn: str = "gelu" # activation function to use
33
+
34
+ self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
35
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
36
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
37
+
38
+ # dropouts
39
+ self.dropout: float = 0.1 # dropout probability for the transformer
40
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
41
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
42
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
43
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
44
+
45
+ # positional embeddings
46
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
47
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
48
+
49
+ # relative position embedding
50
+ self.relative_position_embedding: bool = False # apply relative position embedding
51
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
52
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
53
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
54
+
55
+ # label predictor
56
+ self.finetuned_model: bool = False # whether the model is a fine-tuned model.
57
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor
58
+ self.predictor_class: int = 527 # target class number for the predictor
59
+
60
+ if cfg is not None:
61
+ self.update(cfg)
62
+
63
+ def update(self, cfg: dict):
64
+ self.__dict__.update(cfg)
65
+
66
+
67
+ class BEATs(nn.Module):
68
+ def __init__(
69
+ self,
70
+ cfg: BEATsConfig, ) -> None:
71
+ super().__init__()
72
+ logger.info(f"BEATs Config: {cfg.__dict__}")
73
+
74
+ self.cfg = cfg
75
+
76
+ self.embed = cfg.embed_dim
77
+ self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
78
+ if self.embed != cfg.encoder_embed_dim else
79
+ None)
80
+
81
+ self.input_patch_size = cfg.input_patch_size
82
+ self.patch_embedding = nn.Conv2d(
83
+ 1,
84
+ self.embed,
85
+ kernel_size=self.input_patch_size,
86
+ stride=self.input_patch_size,
87
+ bias=cfg.conv_bias)
88
+
89
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
90
+
91
+ assert not cfg.deep_norm or not cfg.layer_norm_first
92
+ self.encoder = TransformerEncoder(cfg)
93
+ self.layer_norm = LayerNorm(self.embed)
94
+
95
+ if cfg.finetuned_model:
96
+ self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
97
+ self.predictor = nn.Linear(cfg.encoder_embed_dim,
98
+ cfg.predictor_class)
99
+ else:
100
+ self.predictor = None
101
+
102
+ def forward_padding_mask(
103
+ self,
104
+ features: torch.Tensor,
105
+ padding_mask: torch.Tensor, ) -> torch.Tensor:
106
+ extra = padding_mask.size(1) % features.size(1)
107
+ if extra > 0:
108
+ padding_mask = padding_mask[:, :-extra]
109
+ padding_mask = padding_mask.view(
110
+ padding_mask.size(0), features.size(1), -1)
111
+ padding_mask = padding_mask.all(-1)
112
+ return padding_mask
113
+
114
+ def preprocess(
115
+ self,
116
+ source: torch.Tensor,
117
+ fbank_mean: float=15.41663,
118
+ fbank_std: float=6.55582, ) -> torch.Tensor:
119
+ fbanks = []
120
+ for waveform in source:
121
+ waveform = waveform.unsqueeze(0) * 2**15
122
+ fbank = ta_kaldi.fbank(
123
+ waveform,
124
+ num_mel_bins=128,
125
+ sample_frequency=16000,
126
+ frame_length=25,
127
+ frame_shift=10)
128
+ fbanks.append(fbank)
129
+ fbank = torch.stack(fbanks, dim=0)
130
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
131
+ return fbank
132
+
133
+ def extract_features(
134
+ self,
135
+ source: torch.Tensor,
136
+ padding_mask: Optional[torch.Tensor]=None,
137
+ fbank_mean: float=15.41663,
138
+ fbank_std: float=6.55582, ):
139
+ fbank = self.preprocess(
140
+ source, fbank_mean=fbank_mean, fbank_std=fbank_std)
141
+
142
+ if padding_mask is not None:
143
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
144
+
145
+ fbank = fbank.unsqueeze(1)
146
+ features = self.patch_embedding(fbank)
147
+ features = features.reshape(features.shape[0], features.shape[1], -1)
148
+ features = features.transpose(1, 2)
149
+ features = self.layer_norm(features)
150
+
151
+ if padding_mask is not None:
152
+ padding_mask = self.forward_padding_mask(features, padding_mask)
153
+
154
+ if self.post_extract_proj is not None:
155
+ features = self.post_extract_proj(features)
156
+
157
+ x = self.dropout_input(features)
158
+
159
+ x, layer_results = self.encoder(
160
+ x,
161
+ padding_mask=padding_mask, )
162
+
163
+ if self.predictor is not None:
164
+ x = self.predictor_dropout(x)
165
+ logits = self.predictor(x)
166
+
167
+ if padding_mask is not None and padding_mask.any():
168
+ logits[padding_mask] = 0
169
+ logits = logits.sum(dim=1)
170
+ logits = logits / (~padding_mask).sum(
171
+ dim=1).unsqueeze(-1).expand_as(logits)
172
+ else:
173
+ logits = logits.mean(dim=1)
174
+
175
+ lprobs = torch.sigmoid(logits)
176
+
177
+ return lprobs, padding_mask
178
+ else:
179
+ return x, padding_mask
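
A hedged sketch of calling `extract_features` to illustrate its two return modes; it builds an untrained model from a bare config purely for shape inspection (real use loads a checkpoint as in the README below), and `input_patch_size=16` is an assumption, not taken from this commit:

```python
# Sketch only: shape semantics of BEATs.extract_features.
import torch
from AR.exps.beats.BEATs import BEATs, BEATsConfig

cfg = BEATsConfig({"input_patch_size": 16})  # untrained config, illustration only
model = BEATs(cfg).eval()

audio_16khz = torch.randn(2, 16000)          # two 1-second clips at 16 kHz
padding_mask = torch.zeros(2, 16000).bool()  # True would mark padded samples

with torch.no_grad():
    out, out_mask = model.extract_features(audio_16khz, padding_mask=padding_mask)

# predictor is None here (cfg.finetuned_model=False), so `out` is the frame-level
# representation of shape (batch, num_patches, encoder_embed_dim); with a
# fine-tuned checkpoint the same call instead returns per-clip sigmoid class
# probabilities of shape (batch, predictor_class).
```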
AR/exps/beats/README.md ADDED
@@ -0,0 +1,127 @@
1
+
2
+ # BEATs
3
+
4
+ [**BEATs**](https://arxiv.org/abs/2212.09058): **Audio Pre-Training with Acoustic Tokenizers**
5
+
6
+ Official PyTorch implementation and pretrained models of BEATs
7
+
8
+ ## Pre-Trained and Fine-Tuned Tokenizers and Models
9
+ Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2
10
+ |---|---|---|---|---
11
+ Iter1 | Random Projection | [BEATs_iter1](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
12
+ Iter2 | [Tokenizer_iter2](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter2](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
13
+ Iter3 | [Tokenizer_iter3](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
14
+ Iter3+ | [Tokenizer_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
15
+ Iter3+ | [Tokenizer_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
16
+
17
+
18
+ ### Load Tokenizers
19
+
20
+ ```python
21
+ import torch
22
+ from Tokenizers import TokenizersConfig, Tokenizers
23
+
24
+ # load the pre-trained checkpoints
25
+ checkpoint = torch.load('/path/to/tokenizer.pt')
26
+
27
+ cfg = TokenizersConfig(checkpoint['cfg'])
28
+ BEATs_tokenizer = Tokenizers(cfg)
29
+ BEATs_tokenizer.load_state_dict(checkpoint['model'])
30
+ BEATs_tokenizer.eval()
31
+
32
+ # tokenize the audio and generate the labels
33
+ audio_input_16khz = torch.randn(1, 10000)
34
+ padding_mask = torch.zeros(1, 10000).bool()
35
+
36
+ labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
37
+ ```
38
+
39
+
40
+ ### Load Pre-Trained Models
41
+
42
+ ```python
43
+ import torch
44
+ from BEATs import BEATs, BEATsConfig
45
+
46
+ # load the pre-trained checkpoints
47
+ checkpoint = torch.load('/path/to/model.pt')
48
+
49
+ cfg = BEATsConfig(checkpoint['cfg'])
50
+ BEATs_model = BEATs(cfg)
51
+ BEATs_model.load_state_dict(checkpoint['model'])
52
+ BEATs_model.eval()
53
+
54
+ # extract the audio representation
55
+ audio_input_16khz = torch.randn(1, 10000)
56
+ padding_mask = torch.zeros(1, 10000).bool()
57
+
58
+ representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
59
+ ```
60
+
61
+
62
+ ### Load Fine-tuned Models
63
+
64
+ ```python
65
+ import torch
66
+ from BEATs import BEATs, BEATsConfig
67
+
68
+ # load the fine-tuned checkpoints
69
+ checkpoint = torch.load('/path/to/model.pt')
70
+
71
+ cfg = BEATsConfig(checkpoint['cfg'])
72
+ BEATs_model = BEATs(cfg)
73
+ BEATs_model.load_state_dict(checkpoint['model'])
74
+ BEATs_model.eval()
75
+
76
+ # predict the classification probability of each class
77
+ audio_input_16khz = torch.randn(3, 10000)
78
+ padding_mask = torch.zeros(3, 10000).bool()
79
+
80
+ probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
81
+
82
+ for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
83
+ top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
84
+ print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')
85
+ ```
86
+
87
+ ## Evaluation Results
88
+
89
+ ### Comparing with the SOTA Single Models
90
+ ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png)
91
+
92
+
93
+ ### Comparing with the SOTA Ensemble Models
94
+ ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png)
95
+
96
+
97
+ ### Comparing Different BEATS Tokenizers
98
+ ![alt text](Evaluation_Results/Comparing_Different_BEATS_Tokenizers.png)
99
+
100
+
101
+ ### Comparing Different Pre-Training Targets
102
+ ![alt text](Evaluation_Results/Comparing_Different_Pre-Training_Targets.png)
103
+
104
+
105
+ ## License
106
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
107
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [VQGAN](https://github.com/CompVis/taming-transformers) projects.
108
+
109
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
110
+
111
+
112
+ ### Reference
113
+ If you find our work useful in your research, please cite the following paper:
114
+ ``` latex
115
+ @article{Chen2022beats,
116
+ title = {BEATs: Audio Pre-Training with Acoustic Tokenizers},
117
+ author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
118
+ eprint={2212.09058},
119
+ archivePrefix={arXiv},
120
+ year={2022}
121
+ }
122
+ ```
123
+ ### Contact Information
124
+
125
+ For help or issues using BEATs models, please submit a GitHub issue.
126
+
127
+ For other communications related to BEATs, please contact Yu Wu (`[email protected]`).
AR/exps/beats/Tokenizers.py ADDED
@@ -0,0 +1,172 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import logging
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+ from backbone import (
16
+ TransformerEncoder, )
17
+ from quantizer import (
18
+ NormEMAVectorQuantizer, )
19
+ from torch.nn import LayerNorm
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class TokenizersConfig:
25
+ def __init__(self, cfg=None):
26
+ self.input_patch_size: int = -1 # patch size of patch embedding
27
+ self.embed_dim: int = 512 # patch embedding dimension
28
+ self.conv_bias: bool = False # include bias in conv encoder
29
+
30
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
31
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
32
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
33
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
34
+ self.activation_fn: str = "gelu" # activation function to use
35
+
36
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
37
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
38
+
39
+ # dropouts
40
+ self.dropout: float = 0.1 # dropout probability for the transformer
41
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
42
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
43
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
44
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
45
+
46
+ # positional embeddings
47
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
48
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
49
+
50
+ # relative position embedding
51
+ self.relative_position_embedding: bool = False # apply relative position embedding
52
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
53
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
54
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
55
+
56
+ # quantizer
57
+ self.quant_n: int = 1024 # codebook number in quantizer
58
+ self.quant_dim: int = 256 # codebook dimension in quantizer
59
+
60
+ if cfg is not None:
61
+ self.update(cfg)
62
+
63
+ def update(self, cfg: dict):
64
+ self.__dict__.update(cfg)
65
+
66
+
67
+ class Tokenizers(nn.Module):
68
+ def __init__(
69
+ self,
70
+ cfg: TokenizersConfig, ) -> None:
71
+ super().__init__()
72
+ logger.info(f"Tokenizers Config: {cfg.__dict__}")
73
+
74
+ self.cfg = cfg
75
+
76
+ self.embed = cfg.embed_dim
77
+ self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
78
+ if self.embed != cfg.encoder_embed_dim else
79
+ None)
80
+
81
+ self.input_patch_size = cfg.input_patch_size
82
+ self.patch_embedding = nn.Conv2d(
83
+ 1,
84
+ self.embed,
85
+ kernel_size=self.input_patch_size,
86
+ stride=self.input_patch_size,
87
+ bias=cfg.conv_bias)
88
+
89
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
90
+
91
+ assert not cfg.deep_norm or not cfg.layer_norm_first
92
+ self.encoder = TransformerEncoder(cfg)
93
+ self.layer_norm = LayerNorm(self.embed)
94
+
95
+ self.quantize = NormEMAVectorQuantizer(
96
+ n_embed=cfg.quant_n,
97
+ embedding_dim=cfg.quant_dim,
98
+ beta=1.0,
99
+ kmeans_init=True,
100
+ decay=0.99, )
101
+ self.quant_n = cfg.quant_n
102
+ self.quantize_layer = nn.Sequential(
103
+ nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
104
+ nn.Tanh(),
105
+ nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
106
+ )
107
+
108
+ def forward_padding_mask(
109
+ self,
110
+ features: torch.Tensor,
111
+ padding_mask: torch.Tensor, ) -> torch.Tensor:
112
+ extra = padding_mask.size(1) % features.size(1)
113
+ if extra > 0:
114
+ padding_mask = padding_mask[:, :-extra]
115
+ padding_mask = padding_mask.view(
116
+ padding_mask.size(0), features.size(1), -1)
117
+ padding_mask = padding_mask.all(-1)
118
+ return padding_mask
119
+
120
+ def preprocess(
121
+ self,
122
+ source: torch.Tensor,
123
+ fbank_mean: float=15.41663,
124
+ fbank_std: float=6.55582, ) -> torch.Tensor:
125
+ fbanks = []
126
+ for waveform in source:
127
+ waveform = waveform.unsqueeze(0) * 2**15
128
+ fbank = ta_kaldi.fbank(
129
+ waveform,
130
+ num_mel_bins=128,
131
+ sample_frequency=16000,
132
+ frame_length=25,
133
+ frame_shift=10)
134
+ fbanks.append(fbank)
135
+ fbank = torch.stack(fbanks, dim=0)
136
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
137
+ return fbank
138
+
139
+ def extract_labels(
140
+ self,
141
+ source: torch.Tensor,
142
+ padding_mask: Optional[torch.Tensor]=None,
143
+ fbank_mean: float=15.41663,
144
+ fbank_std: float=6.55582, ):
145
+ fbank = self.preprocess(
146
+ source, fbank_mean=fbank_mean, fbank_std=fbank_std)
147
+
148
+ if padding_mask is not None:
149
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
150
+
151
+ fbank = fbank.unsqueeze(1)
152
+ features = self.patch_embedding(fbank)
153
+ features = features.reshape(features.shape[0], features.shape[1], -1)
154
+ features = features.transpose(1, 2)
155
+ features = self.layer_norm(features)
156
+
157
+ if padding_mask is not None:
158
+ padding_mask = self.forward_padding_mask(features, padding_mask)
159
+
160
+ if self.post_extract_proj is not None:
161
+ features = self.post_extract_proj(features)
162
+
163
+ x = self.dropout_input(features)
164
+
165
+ x, layer_results = self.encoder(
166
+ x,
167
+ padding_mask=padding_mask, )
168
+
169
+ quantize_input = self.quantize_layer(x)
170
+ quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
171
+
172
+ return embed_ind
AR/exps/beats/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # this folder is modified from https://github.com/microsoft/unilm/tree/master/beats
2
+ # ontology.json is from https://github.com/audioset/ontology/
AR/exps/beats/backbone.py ADDED
@@ -0,0 +1,791 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import math
10
+ from typing import Dict
11
+ from typing import Optional
12
+ from typing import Tuple
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+ from torch import Tensor
19
+ from torch.nn import LayerNorm
20
+ from torch.nn import Parameter
21
+
22
+ from .modules import get_activation_fn
23
+ from .modules import GLU_Linear
24
+ from .modules import GradMultiply
25
+ from .modules import quant_noise
26
+ from .modules import SamePad
27
+
28
+
29
+ class TransformerEncoder(nn.Module):
30
+ def __init__(self, args):
31
+ super().__init__()
32
+
33
+ self.dropout = args.dropout
34
+ self.embedding_dim = args.encoder_embed_dim
35
+
36
+ self.pos_conv = nn.Conv1d(
37
+ self.embedding_dim,
38
+ self.embedding_dim,
39
+ kernel_size=args.conv_pos,
40
+ padding=args.conv_pos // 2,
41
+ groups=args.conv_pos_groups, )
42
+ dropout = 0
43
+ std = math.sqrt(
44
+ (4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
45
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
46
+ nn.init.constant_(self.pos_conv.bias, 0)
47
+
48
+ self.pos_conv = nn.utils.weight_norm(
49
+ self.pos_conv, name="weight", dim=2)
50
+ self.pos_conv = nn.Sequential(self.pos_conv,
51
+ SamePad(args.conv_pos), nn.GELU())
52
+
53
+ if hasattr(args, "relative_position_embedding"):
54
+ self.relative_position_embedding = args.relative_position_embedding
55
+ self.num_buckets = args.num_buckets
56
+ self.max_distance = args.max_distance
57
+ else:
58
+ self.relative_position_embedding = False
59
+ self.num_buckets = 0
60
+ self.max_distance = 0
61
+
62
+ self.layers = nn.ModuleList([
63
+ TransformerSentenceEncoderLayer(
64
+ embedding_dim=self.embedding_dim,
65
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
66
+ num_attention_heads=args.encoder_attention_heads,
67
+ dropout=self.dropout,
68
+ attention_dropout=args.attention_dropout,
69
+ activation_dropout=args.activation_dropout,
70
+ activation_fn=args.activation_fn,
71
+ layer_norm_first=args.layer_norm_first,
72
+ deep_norm=args.deep_norm,
73
+ has_relative_attention_bias=self.relative_position_embedding,
74
+ num_buckets=self.num_buckets,
75
+ max_distance=self.max_distance,
76
+ gru_rel_pos=args.gru_rel_pos,
77
+ encoder_layers=args.encoder_layers, )
78
+ for i in range(args.encoder_layers)
79
+ ])
80
+ if self.relative_position_embedding:
81
+ for i in range(1, args.encoder_layers):
82
+ del self.layers[i].self_attn.relative_attention_bias
83
+ self.layers[i].self_attn.relative_attention_bias = self.layers[
84
+ 0].self_attn.relative_attention_bias
85
+
86
+ self.layer_norm_first = args.layer_norm_first
87
+ self.layer_norm = LayerNorm(self.embedding_dim)
88
+ self.layerdrop = args.encoder_layerdrop
89
+
90
+ self.apply(init_bert_params)
91
+
92
+ if args.deep_norm:
93
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
94
+ for i in range(args.encoder_layers):
95
+ nn.init.xavier_normal_(
96
+ self.layers[i].self_attn.k_proj.weight, gain=1)
97
+ nn.init.xavier_normal_(
98
+ self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
99
+ nn.init.xavier_normal_(
100
+ self.layers[i].self_attn.q_proj.weight, gain=1)
101
+ nn.init.xavier_normal_(
102
+ self.layers[i].self_attn.out_proj.weight,
103
+ gain=deep_norm_beta)
104
+ nn.init.xavier_normal_(
105
+ self.layers[i].fc1.weight, gain=deep_norm_beta)
106
+ nn.init.xavier_normal_(
107
+ self.layers[i].fc2.weight, gain=deep_norm_beta)
108
+
109
+ self.layer_wise_gradient_decay_ratio = getattr(
110
+ args, "layer_wise_gradient_decay_ratio", 1)
111
+
112
+ def forward(self, x, padding_mask=None, layer=None):
113
+ x, layer_results = self.extract_features(x, padding_mask, layer)
114
+
115
+ if self.layer_norm_first and layer is None:
116
+ x = self.layer_norm(x)
117
+
118
+ return x, layer_results
119
+
120
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
121
+
122
+ if padding_mask is not None:
123
+ x[padding_mask] = 0
124
+
125
+ x_conv = self.pos_conv(x.transpose(1, 2))
126
+ x_conv = x_conv.transpose(1, 2)
127
+ x = x + x_conv
128
+
129
+ if not self.layer_norm_first:
130
+ x = self.layer_norm(x)
131
+
132
+ x = F.dropout(x, p=self.dropout, training=self.training)
133
+
134
+ # B x T x C -> T x B x C
135
+ x = x.transpose(0, 1)
136
+
137
+ layer_results = []
138
+ z = None
139
+ if tgt_layer is not None:
140
+ layer_results.append((x, z))
141
+ r = None
142
+ pos_bias = None
143
+ for i, layer in enumerate(self.layers):
144
+ if self.layer_wise_gradient_decay_ratio != 1.0:
145
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
146
+ dropout_probability = np.random.random()
147
+ if not self.training or (dropout_probability > self.layerdrop):
148
+ x, z, pos_bias = layer(
149
+ x,
150
+ self_attn_padding_mask=padding_mask,
151
+ need_weights=False,
152
+ pos_bias=pos_bias)
153
+ if tgt_layer is not None:
154
+ layer_results.append((x, z))
155
+ if i == tgt_layer:
156
+ r = x
157
+ break
158
+
159
+ if r is not None:
160
+ x = r
161
+
162
+ # T x B x C -> B x T x C
163
+ x = x.transpose(0, 1)
164
+
165
+ return x, layer_results
166
+
167
+
168
+ class TransformerSentenceEncoderLayer(nn.Module):
169
+ def __init__(
170
+ self,
171
+ embedding_dim: float=768,
172
+ ffn_embedding_dim: float=3072,
173
+ num_attention_heads: float=8,
174
+ dropout: float=0.1,
175
+ attention_dropout: float=0.1,
176
+ activation_dropout: float=0.1,
177
+ activation_fn: str="relu",
178
+ layer_norm_first: bool=False,
179
+ deep_norm: bool=False,
180
+ has_relative_attention_bias: bool=False,
181
+ num_buckets: int=0,
182
+ max_distance: int=0,
183
+ rescale_init: bool=False,
184
+ gru_rel_pos: bool=False,
185
+ encoder_layers: int=0, ) -> None:
186
+
187
+ super().__init__()
188
+ self.embedding_dim = embedding_dim
189
+ self.dropout = dropout
190
+ self.activation_dropout = activation_dropout
191
+
192
+ self.activation_name = activation_fn
193
+ self.activation_fn = get_activation_fn(activation_fn)
194
+ self.self_attn = MultiheadAttention(
195
+ self.embedding_dim,
196
+ num_attention_heads,
197
+ dropout=attention_dropout,
198
+ self_attention=True,
199
+ has_relative_attention_bias=has_relative_attention_bias,
200
+ num_buckets=num_buckets,
201
+ max_distance=max_distance,
202
+ rescale_init=rescale_init,
203
+ gru_rel_pos=gru_rel_pos, )
204
+
205
+ self.dropout1 = nn.Dropout(dropout)
206
+ self.dropout2 = nn.Dropout(self.activation_dropout)
207
+ self.dropout3 = nn.Dropout(dropout)
208
+
209
+ self.layer_norm_first = layer_norm_first
210
+
211
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
212
+
213
+ if self.activation_name == "glu":
214
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim,
215
+ "swish")
216
+ else:
217
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
218
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
219
+
220
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
221
+
222
+ self.deep_norm = deep_norm
223
+ if self.deep_norm:
224
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
225
+ else:
226
+ self.deep_norm_alpha = 1
227
+
228
+ def forward(self,
229
+ x: torch.Tensor,
230
+ self_attn_mask: torch.Tensor=None,
231
+ self_attn_padding_mask: torch.Tensor=None,
232
+ need_weights: bool=False,
233
+ pos_bias=None):
234
+ residual = x
235
+
236
+ if self.layer_norm_first:
237
+ x = self.self_attn_layer_norm(x)
238
+ x, attn, pos_bias = self.self_attn(
239
+ query=x,
240
+ key=x,
241
+ value=x,
242
+ key_padding_mask=self_attn_padding_mask,
243
+ need_weights=False,
244
+ attn_mask=self_attn_mask,
245
+ position_bias=pos_bias)
246
+ x = self.dropout1(x)
247
+ x = residual + x
248
+
249
+ residual = x
250
+ x = self.final_layer_norm(x)
251
+ if self.activation_name == "glu":
252
+ x = self.fc1(x)
253
+ else:
254
+ x = self.activation_fn(self.fc1(x))
255
+ x = self.dropout2(x)
256
+ x = self.fc2(x)
257
+ x = self.dropout3(x)
258
+ x = residual + x
259
+ else:
260
+ x, attn, pos_bias = self.self_attn(
261
+ query=x,
262
+ key=x,
263
+ value=x,
264
+ key_padding_mask=self_attn_padding_mask,
265
+ need_weights=need_weights,
266
+ attn_mask=self_attn_mask,
267
+ position_bias=pos_bias)
268
+
269
+ x = self.dropout1(x)
270
+ x = residual * self.deep_norm_alpha + x
271
+
272
+ x = self.self_attn_layer_norm(x)
273
+
274
+ residual = x
275
+ if self.activation_name == "glu":
276
+ x = self.fc1(x)
277
+ else:
278
+ x = self.activation_fn(self.fc1(x))
279
+ x = self.dropout2(x)
280
+ x = self.fc2(x)
281
+ x = self.dropout3(x)
282
+ x = residual * self.deep_norm_alpha + x
283
+ x = self.final_layer_norm(x)
284
+
285
+ return x, attn, pos_bias
286
+
287
+
288
+ class MultiheadAttention(nn.Module):
289
+ """Multi-headed attention.
290
+
291
+ See "Attention Is All You Need" for more details.
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ embed_dim,
297
+ num_heads,
298
+ kdim=None,
299
+ vdim=None,
300
+ dropout=0.0,
301
+ bias=True,
302
+ add_bias_kv=False,
303
+ add_zero_attn=False,
304
+ self_attention=False,
305
+ encoder_decoder_attention=False,
306
+ q_noise=0.0,
307
+ qn_block_size=8,
308
+ has_relative_attention_bias=False,
309
+ num_buckets=32,
310
+ max_distance=128,
311
+ gru_rel_pos=False,
312
+ rescale_init=False, ):
313
+ super().__init__()
314
+ self.embed_dim = embed_dim
315
+ self.kdim = kdim if kdim is not None else embed_dim
316
+ self.vdim = vdim if vdim is not None else embed_dim
317
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
318
+
319
+ self.num_heads = num_heads
320
+ self.dropout_module = nn.Dropout(dropout)
321
+
322
+ self.has_relative_attention_bias = has_relative_attention_bias
323
+ self.num_buckets = num_buckets
324
+ self.max_distance = max_distance
325
+ if self.has_relative_attention_bias:
326
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
327
+
328
+ self.head_dim = embed_dim // num_heads
329
+ self.q_head_dim = self.head_dim
330
+ self.k_head_dim = self.head_dim
331
+ assert (self.head_dim * num_heads == self.embed_dim
332
+ ), "embed_dim must be divisible by num_heads"
333
+ self.scaling = self.head_dim**-0.5
334
+
335
+ self.self_attention = self_attention
336
+ self.encoder_decoder_attention = encoder_decoder_attention
337
+
338
+ assert not self.self_attention or self.qkv_same_dim, (
339
+ "Self-attention requires query, key and "
340
+ "value to be of the same size")
341
+
342
+ k_bias = True
343
+ if rescale_init:
344
+ k_bias = False
345
+
346
+ k_embed_dim = embed_dim
347
+ q_embed_dim = embed_dim
348
+
349
+ self.k_proj = quant_noise(
350
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise,
351
+ qn_block_size)
352
+ self.v_proj = quant_noise(
353
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
354
+ self.q_proj = quant_noise(
355
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise,
356
+ qn_block_size)
357
+
358
+ self.out_proj = quant_noise(
359
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
360
+
361
+ if add_bias_kv:
362
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
363
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
364
+ else:
365
+ self.bias_k = self.bias_v = None
366
+
367
+ self.add_zero_attn = add_zero_attn
368
+
369
+ self.gru_rel_pos = gru_rel_pos
370
+ if self.gru_rel_pos:
371
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
372
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
373
+
374
+ self.reset_parameters()
375
+
376
+ def reset_parameters(self):
377
+ if self.qkv_same_dim:
378
+ # Empirically observed the convergence to be much better with
379
+ # the scaled initialization
380
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
381
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
382
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
383
+ else:
384
+ nn.init.xavier_uniform_(self.k_proj.weight)
385
+ nn.init.xavier_uniform_(self.v_proj.weight)
386
+ nn.init.xavier_uniform_(self.q_proj.weight)
387
+
388
+ nn.init.xavier_uniform_(self.out_proj.weight)
389
+ if self.out_proj.bias is not None:
390
+ nn.init.constant_(self.out_proj.bias, 0.0)
391
+ if self.bias_k is not None:
392
+ nn.init.xavier_normal_(self.bias_k)
393
+ if self.bias_v is not None:
394
+ nn.init.xavier_normal_(self.bias_v)
395
+ if self.has_relative_attention_bias:
396
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
397
+
398
+ def _relative_positions_bucket(self, relative_positions,
399
+ bidirectional=True):
400
+ num_buckets = self.num_buckets
401
+ max_distance = self.max_distance
402
+ relative_buckets = 0
403
+
404
+ if bidirectional:
405
+ num_buckets = num_buckets // 2
406
+ relative_buckets += (
407
+ relative_positions > 0).to(torch.long) * num_buckets
408
+ relative_positions = torch.abs(relative_positions)
409
+ else:
410
+ relative_positions = -torch.min(
411
+ relative_positions, torch.zeros_like(relative_positions))
412
+
413
+ max_exact = num_buckets // 2
414
+ is_small = relative_positions < max_exact
415
+
416
+ relative_postion_if_large = max_exact + (
417
+ torch.log(relative_positions.float() / max_exact) / math.log(
418
+ max_distance / max_exact) *
419
+ (num_buckets - max_exact)).to(torch.long)
420
+ relative_postion_if_large = torch.min(
421
+ relative_postion_if_large,
422
+ torch.full_like(relative_postion_if_large, num_buckets - 1))
423
+
424
+ relative_buckets += torch.where(is_small, relative_positions,
425
+ relative_postion_if_large)
426
+ return relative_buckets
427
+
428
+ def compute_bias(self, query_length, key_length):
429
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
430
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
431
+ relative_position = memory_position - context_position
432
+ relative_position_bucket = self._relative_positions_bucket(
433
+ relative_position, bidirectional=True)
434
+ relative_position_bucket = relative_position_bucket.to(
435
+ self.relative_attention_bias.weight.device)
436
+ values = self.relative_attention_bias(relative_position_bucket)
437
+ values = values.permute([2, 0, 1])
438
+ return values
439
+
440
+ def forward(self,
441
+ query,
442
+ key: Optional[Tensor],
443
+ value: Optional[Tensor],
444
+ key_padding_mask: Optional[Tensor]=None,
445
+ incremental_state: Optional[Dict[str, Dict[str, Optional[
446
+ Tensor]]]]=None,
447
+ need_weights: bool=True,
448
+ static_kv: bool=False,
449
+ attn_mask: Optional[Tensor]=None,
450
+ before_softmax: bool=False,
451
+ need_head_weights: bool=False,
452
+ position_bias: Optional[Tensor]=None
453
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
454
+ """Input shape: Time x Batch x Channel
455
+
456
+ Args:
457
+ key_padding_mask (ByteTensor, optional): mask to exclude
458
+ keys that are pads, of shape `(batch, src_len)`, where
459
+ padding elements are indicated by 1s.
460
+ need_weights (bool, optional): return the attention weights,
461
+ averaged over heads (default: False).
462
+ attn_mask (ByteTensor, optional): typically used to
463
+ implement causal attention, where the mask prevents the
464
+ attention from looking forward in time (default: None).
465
+ before_softmax (bool, optional): return the raw attention
466
+ weights and values before the attention softmax.
467
+ need_head_weights (bool, optional): return the attention
468
+ weights for each head. Implies *need_weights*. Default:
469
+ return the average attention weights over all heads.
470
+ """
471
+ if need_head_weights:
472
+ need_weights = True
473
+
474
+ is_tpu = query.device.type == "xla"
475
+
476
+ tgt_len, bsz, embed_dim = query.size()
477
+ src_len = tgt_len
478
+ assert embed_dim == self.embed_dim
479
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
480
+ if key is not None:
481
+ src_len, key_bsz, _ = key.size()
482
+ if not torch.jit.is_scripting():
483
+ assert key_bsz == bsz
484
+ assert value is not None
485
+ assert src_len, bsz == value.shape[:2]
486
+
487
+ if self.has_relative_attention_bias and position_bias is None:
488
+ position_bias = self.compute_bias(tgt_len, src_len)
489
+ position_bias = position_bias.unsqueeze(0).repeat(
490
+ bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
491
+
492
+ if incremental_state is not None:
493
+ saved_state = self._get_input_buffer(incremental_state)
494
+ if saved_state is not None and "prev_key" in saved_state:
495
+ # previous time steps are cached - no need to recompute
496
+ # key and value if they are static
497
+ if static_kv:
498
+ assert self.encoder_decoder_attention and not self.self_attention
499
+ key = value = None
500
+ else:
501
+ saved_state = None
502
+
503
+ if self.self_attention:
504
+ q = self.q_proj(query)
505
+ k = self.k_proj(query)
506
+ v = self.v_proj(query)
507
+ elif self.encoder_decoder_attention:
508
+ # encoder-decoder attention
509
+ q = self.q_proj(query)
510
+ if key is None:
511
+ assert value is None
512
+ k = v = None
513
+ else:
514
+ k = self.k_proj(key)
515
+ v = self.v_proj(key)
516
+
517
+ else:
518
+ assert key is not None and value is not None
519
+ q = self.q_proj(query)
520
+ k = self.k_proj(key)
521
+ v = self.v_proj(value)
522
+ q *= self.scaling
523
+ alpha = 32
524
+ q *= 1 / alpha
525
+
526
+ if self.bias_k is not None:
527
+ assert self.bias_v is not None
528
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
529
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
530
+ if attn_mask is not None:
531
+ attn_mask = torch.cat(
532
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
533
+ dim=1)
534
+ if key_padding_mask is not None:
535
+ key_padding_mask = torch.cat(
536
+ [
537
+ key_padding_mask,
538
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
539
+ ],
540
+ dim=1, )
541
+
542
+ q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim)
543
+ .transpose(0, 1))
544
+ if k is not None:
545
+ k = (k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim)
546
+ .transpose(0, 1))
547
+ if v is not None:
548
+ v = (v.contiguous().view(-1, bsz * self.num_heads, self.head_dim)
549
+ .transpose(0, 1))
550
+
551
+ if saved_state is not None:
552
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
553
+ if "prev_key" in saved_state:
554
+ _prev_key = saved_state["prev_key"]
555
+ assert _prev_key is not None
556
+ prev_key = _prev_key.view(bsz * self.num_heads, -1,
557
+ self.head_dim)
558
+ if static_kv:
559
+ k = prev_key
560
+ else:
561
+ assert k is not None
562
+ k = torch.cat([prev_key, k], dim=1)
563
+ src_len = k.size(1)
564
+ if "prev_value" in saved_state:
565
+ _prev_value = saved_state["prev_value"]
566
+ assert _prev_value is not None
567
+ prev_value = _prev_value.view(bsz * self.num_heads, -1,
568
+ self.head_dim)
569
+ if static_kv:
570
+ v = prev_value
571
+ else:
572
+ assert v is not None
573
+ v = torch.cat([prev_value, v], dim=1)
574
+ prev_key_padding_mask: Optional[Tensor] = None
575
+ if "prev_key_padding_mask" in saved_state:
576
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
577
+ assert k is not None and v is not None
578
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
579
+ key_padding_mask=key_padding_mask,
580
+ prev_key_padding_mask=prev_key_padding_mask,
581
+ batch_size=bsz,
582
+ src_len=k.size(1),
583
+ static_kv=static_kv, )
584
+
585
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
586
+ self.head_dim)
587
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
588
+ self.head_dim)
589
+ saved_state["prev_key_padding_mask"] = key_padding_mask
590
+ # In this branch incremental_state is never None
591
+ assert incremental_state is not None
592
+ incremental_state = self._set_input_buffer(incremental_state,
593
+ saved_state)
594
+ assert k is not None
595
+ assert k.size(1) == src_len
596
+
597
+ # This is part of a workaround to get around fork/join parallelism
598
+ # not supporting Optional types.
599
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
600
+ key_padding_mask = None
601
+
602
+ if key_padding_mask is not None:
603
+ assert key_padding_mask.size(0) == bsz
604
+ assert key_padding_mask.size(1) == src_len
605
+
606
+ if self.add_zero_attn:
607
+ assert v is not None
608
+ src_len += 1
609
+ k = torch.cat(
610
+ [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
611
+ v = torch.cat(
612
+ [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
613
+ if attn_mask is not None:
614
+ attn_mask = torch.cat(
615
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
616
+ dim=1)
617
+ if key_padding_mask is not None:
618
+ key_padding_mask = torch.cat(
619
+ [
620
+ key_padding_mask,
621
+ torch.zeros(key_padding_mask.size(0),
622
+ 1).type_as(key_padding_mask),
623
+ ],
624
+ dim=1, )
625
+
626
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
627
+ attn_weights = (
628
+ attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
629
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
630
+ bsz)
631
+
632
+ assert list(
633
+ attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
634
+
635
+ if attn_mask is not None:
636
+ attn_mask = attn_mask.unsqueeze(0)
637
+ attn_weights += attn_mask
638
+
639
+ if key_padding_mask is not None:
640
+ # don't attend to padding symbols
641
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
642
+ src_len)
643
+ if not is_tpu:
644
+ attn_weights = attn_weights.masked_fill(
645
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
646
+ float("-inf"), )
647
+ else:
648
+ attn_weights = attn_weights.transpose(0, 2)
649
+ attn_weights = attn_weights.masked_fill(key_padding_mask,
650
+ float("-inf"))
651
+ attn_weights = attn_weights.transpose(0, 2)
652
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
653
+ src_len)
654
+
655
+ if before_softmax:
656
+ return attn_weights, v, position_bias
657
+
658
+ if position_bias is not None:
659
+ attn_mask_rel_pos = position_bias
660
+ if self.gru_rel_pos == 1:
661
+ query_layer = q.view(bsz, self.num_heads, tgt_len,
662
+ self.q_head_dim) * alpha / self.scaling
663
+ _B, _H, _L, __ = query_layer.size()
664
+ gate_a, gate_b = torch.sigmoid(
665
+ self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(
666
+ -1, keepdim=False)).chunk(
667
+ 2, dim=-1)
668
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
669
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len,
670
+ 1) * position_bias
671
+
672
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
673
+
674
+ attn_weights = attn_weights + attn_mask_rel_pos
675
+
676
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
677
+ attn_weights = attn_weights_float.type_as(attn_weights)
678
+ attn_probs = self.dropout_module(attn_weights)
679
+
680
+ assert v is not None
681
+ attn = torch.bmm(attn_probs, v)
682
+ assert list(
683
+ attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
684
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
685
+ attn = self.out_proj(attn)
686
+ attn_weights: Optional[Tensor] = None
687
+ if need_weights:
688
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
689
+ src_len).transpose(1, 0)
690
+ if not need_head_weights:
691
+ # average attention weights over heads
692
+ attn_weights = attn_weights.mean(dim=0)
693
+
694
+ return attn, attn_weights, position_bias
695
+
696
+ @staticmethod
697
+ def _append_prev_key_padding_mask(
698
+ key_padding_mask: Optional[Tensor],
699
+ prev_key_padding_mask: Optional[Tensor],
700
+ batch_size: int,
701
+ src_len: int,
702
+ static_kv: bool, ) -> Optional[Tensor]:
703
+ # saved key padding masks have shape (bsz, seq_len)
704
+ if prev_key_padding_mask is not None and static_kv:
705
+ new_key_padding_mask = prev_key_padding_mask
706
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
707
+ new_key_padding_mask = torch.cat(
708
+ [prev_key_padding_mask.float(), key_padding_mask.float()],
709
+ dim=1)
710
+ # During incremental decoding, as the padding token enters and
711
+ # leaves the frame, there will be a time when prev or current
712
+ # is None
713
+ elif prev_key_padding_mask is not None:
714
+ if src_len > prev_key_padding_mask.size(1):
715
+ filler = torch.zeros(
716
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
717
+ device=prev_key_padding_mask.device, )
718
+ new_key_padding_mask = torch.cat(
719
+ [prev_key_padding_mask.float(), filler.float()], dim=1)
720
+ else:
721
+ new_key_padding_mask = prev_key_padding_mask.float()
722
+ elif key_padding_mask is not None:
723
+ if src_len > key_padding_mask.size(1):
724
+ filler = torch.zeros(
725
+ (batch_size, src_len - key_padding_mask.size(1)),
726
+ device=key_padding_mask.device, )
727
+ new_key_padding_mask = torch.cat(
728
+ [filler.float(), key_padding_mask.float()], dim=1)
729
+ else:
730
+ new_key_padding_mask = key_padding_mask.float()
731
+ else:
732
+ new_key_padding_mask = prev_key_padding_mask
733
+ return new_key_padding_mask
734
+
735
+ def _get_input_buffer(
736
+ self,
737
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
738
+ ) -> Dict[str, Optional[Tensor]]:
739
+ result = self.get_incremental_state(incremental_state, "attn_state")
740
+ if result is not None:
741
+ return result
742
+ else:
743
+ empty_result: Dict[str, Optional[Tensor]] = {}
744
+ return empty_result
745
+
746
+ def _set_input_buffer(
747
+ self,
748
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
749
+ buffer: Dict[str, Optional[Tensor]], ):
750
+ return self.set_incremental_state(incremental_state, "attn_state",
751
+ buffer)
752
+
753
+ def apply_sparse_mask(self,
754
+ attn_weights,
755
+ tgt_len: int,
756
+ src_len: int,
757
+ bsz: int):
758
+ return attn_weights
759
+
760
+
761
+ def init_bert_params(module):
762
+ """
763
+ Initialize the weights specific to the BERT Model.
764
+ This overrides the default initializations depending on the specified arguments.
765
+ 1. If normal_init_linear_weights is set then weights of linear
766
+ layer will be initialized using the normal distribution and
767
+ bias will be set to the specified value.
768
+ 2. If normal_init_embed_weights is set then weights of embedding
769
+ layer will be initialized using the normal distribution.
770
+ 3. If normal_init_proj_weights is set then weights of
771
+ in_project_weight for MultiHeadAttention initialized using
772
+ the normal distribution (to be validated).
773
+ """
774
+
775
+ def normal_(data):
776
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
777
+ # so that the RNG is consistent with and without FSDP
778
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
779
+
780
+ if isinstance(module, nn.Linear):
781
+ normal_(module.weight.data)
782
+ if module.bias is not None:
783
+ module.bias.data.zero_()
784
+ if isinstance(module, nn.Embedding):
785
+ normal_(module.weight.data)
786
+ if module.padding_idx is not None:
787
+ module.weight.data[module.padding_idx].zero_()
788
+ if isinstance(module, MultiheadAttention):
789
+ normal_(module.q_proj.weight.data)
790
+ normal_(module.k_proj.weight.data)
791
+ normal_(module.v_proj.weight.data)
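
init_bert_params above is meant to be applied recursively with nn.Module.apply. A minimal, hedged sketch of the usual wiring, assuming the function above is in scope (the toy encoder below is a made-up stand-in, not a module from this repo):

import torch.nn as nn

# hypothetical stand-in; any nn.Module tree containing nn.Linear / nn.Embedding
# (or the MultiheadAttention defined above) is handled the same way
encoder = nn.Sequential(
    nn.Embedding(100, 768, padding_idx=0),
    nn.Linear(768, 768), )

# nn.Module.apply visits every submodule, so init_bert_params re-initializes
# Linear and Embedding weights with N(0, 0.02), zeroes Linear biases and the
# Embedding padding row
encoder.apply(init_bert_params)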
AR/exps/beats/config.py ADDED
@@ -0,0 +1,19 @@
+ import json
+ import os
+
+ # directory that contains the current script
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+
+ # file name of the JSON ontology
+ json_filename = "ontology.json"
+
+ # build the full path to the JSON file
+ json_path = os.path.join(script_dir, json_filename)
+
+ id_name_dict = {}
+
+ with open(json_path, 'r') as f:
+     json_items = json.load(f)
+     # e.g. '/m/0dgw9r' -> 'Human sounds'
+     for item in json_items:
+         id_name_dict[item['id']] = item['name']
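
id_name_dict maps AudioSet ontology ids to human-readable names; get_beats_librilight.py uses it to turn BEATs label ids into tag names. A small usage sketch, reusing the id from the comment above:

from AR.exps.beats.config import id_name_dict

# look up the readable name of an AudioSet ontology id
# ('/m/0dgw9r' is the 'Human sounds' example from the comment above)
print(id_name_dict.get('/m/0dgw9r'))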
AR/exps/beats/modules.py ADDED
@@ -0,0 +1,220 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+
17
+ class GradMultiply(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, x, scale):
20
+ ctx.scale = scale
21
+ res = x.new(x)
22
+ return res
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad):
26
+ return grad * ctx.scale, None
27
+
28
+
29
+ class SamePad(nn.Module):
30
+ def __init__(self, kernel_size, causal=False):
31
+ super().__init__()
32
+ if causal:
33
+ self.remove = kernel_size - 1
34
+ else:
35
+ self.remove = 1 if kernel_size % 2 == 0 else 0
36
+
37
+ def forward(self, x):
38
+ if self.remove > 0:
39
+ x = x[:, :, :-self.remove]
40
+ return x
41
+
42
+
43
+ class Swish(nn.Module):
44
+ def __init__(self):
45
+ super(Swish, self).__init__()
46
+ self.act = torch.nn.Sigmoid()
47
+
48
+ def forward(self, x):
49
+ return x * self.act(x)
50
+
51
+
52
+ class GLU_Linear(nn.Module):
53
+ def __init__(self,
54
+ input_dim,
55
+ output_dim,
56
+ glu_type="sigmoid",
57
+ bias_in_glu=True):
58
+ super(GLU_Linear, self).__init__()
59
+
60
+ self.glu_type = glu_type
61
+ self.output_dim = output_dim
62
+
63
+ if glu_type == "sigmoid":
64
+ self.glu_act = torch.nn.Sigmoid()
65
+ elif glu_type == "swish":
66
+ self.glu_act = Swish()
67
+ elif glu_type == "relu":
68
+ self.glu_act = torch.nn.ReLU()
69
+ elif glu_type == "gelu":
70
+ self.glu_act = torch.nn.GELU()
71
+
72
+ if bias_in_glu:
73
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
74
+ else:
75
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
76
+
77
+ def forward(self, x):
78
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
79
+ x = self.linear(x)
80
+
81
+ if self.glu_type == "bilinear":
82
+ x = (x[:, :, 0:self.output_dim] *
83
+ x[:, :, self.output_dim:self.output_dim * 2])
84
+ else:
85
+ x = (x[:, :, 0:self.output_dim] *
86
+ self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
87
+
88
+ return x
89
+
90
+
91
+ def gelu_accurate(x):
92
+ if not hasattr(gelu_accurate, "_a"):
93
+ gelu_accurate._a = math.sqrt(2 / math.pi)
94
+ return (0.5 * x * (1 + torch.tanh(gelu_accurate._a *
95
+ (x + 0.044715 * torch.pow(x, 3)))))
96
+
97
+
98
+ def gelu(x: torch.Tensor) -> torch.Tensor:
99
+ return torch.nn.functional.gelu(x.float()).type_as(x)
100
+
101
+
102
+ def get_activation_fn(activation: str):
103
+ """Returns the activation function corresponding to `activation`"""
104
+
105
+ if activation == "relu":
106
+ return F.relu
107
+ elif activation == "gelu":
108
+ return gelu
109
+ elif activation == "gelu_fast":
110
+ warnings.warn(
111
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate")
112
+ return gelu_accurate
113
+ elif activation == "gelu_accurate":
114
+ return gelu_accurate
115
+ elif activation == "tanh":
116
+ return torch.tanh
117
+ elif activation == "linear":
118
+ return lambda x: x
119
+ elif activation == "glu":
120
+ return lambda x: x
121
+ else:
122
+ raise RuntimeError(
123
+ "--activation-fn {} not supported".format(activation))
124
+
125
+
126
+ def quant_noise(module, p, block_size):
127
+ """
128
+ Wraps modules and applies quantization noise to the weights for
129
+ subsequent quantization with Iterative Product Quantization as
130
+ described in "Training with Quantization Noise for Extreme Model Compression"
131
+
132
+ Args:
133
+ - module: nn.Module
134
+ - p: amount of Quantization Noise
135
+ - block_size: size of the blocks for subsequent quantization with iPQ
136
+
137
+ Remarks:
138
+ - Module weights must have the right sizes wrt the block size
139
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
140
+ - For more detail on how to quantize by blocks with convolutional weights,
141
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
142
+ - We implement the simplest form of noise here as stated in the paper
143
+ which consists in randomly dropping blocks
144
+ """
145
+
146
+ # if no quantization noise, don't register hook
147
+ if p <= 0:
148
+ return module
149
+
150
+ # supported modules
151
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
152
+
153
+ # test whether module.weight has the right sizes wrt block_size
154
+ is_conv = module.weight.ndim == 4
155
+
156
+ # 2D matrix
157
+ if not is_conv:
158
+ assert (
159
+ module.weight.size(1) %
160
+ block_size == 0), "Input features must be a multiple of block sizes"
161
+
162
+ # 4D matrix
163
+ else:
164
+ # 1x1 convolutions
165
+ if module.kernel_size == (1, 1):
166
+ assert (module.in_channels % block_size == 0
167
+ ), "Input channels must be a multiple of block sizes"
168
+ # regular convolutions
169
+ else:
170
+ k = module.kernel_size[0] * module.kernel_size[1]
171
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
172
+
173
+ def _forward_pre_hook(mod, input):
174
+ # no noise for evaluation
175
+ if mod.training:
176
+ if not is_conv:
177
+ # gather weight and sizes
178
+ weight = mod.weight
179
+ in_features = weight.size(1)
180
+ out_features = weight.size(0)
181
+
182
+ # split weight matrix into blocks and randomly drop selected blocks
183
+ mask = torch.zeros(
184
+ in_features // block_size * out_features,
185
+ device=weight.device)
186
+ mask.bernoulli_(p)
187
+ mask = mask.repeat_interleave(block_size, -1).view(-1,
188
+ in_features)
189
+
190
+ else:
191
+ # gather weight and sizes
192
+ weight = mod.weight
193
+ in_channels = mod.in_channels
194
+ out_channels = mod.out_channels
195
+
196
+ # split weight matrix into blocks and randomly drop selected blocks
197
+ if mod.kernel_size == (1, 1):
198
+ mask = torch.zeros(
199
+ int(in_channels // block_size * out_channels),
200
+ device=weight.device, )
201
+ mask.bernoulli_(p)
202
+ mask = mask.repeat_interleave(block_size, -1).view(
203
+ -1, in_channels)
204
+ else:
205
+ mask = torch.zeros(
206
+ weight.size(0), weight.size(1), device=weight.device)
207
+ mask.bernoulli_(p)
208
+ mask = (
209
+ mask.unsqueeze(2).unsqueeze(3)
210
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
211
+
212
+ # scale weights and apply mask
213
+ mask = mask.to(
214
+ torch.
215
+ bool) # x.bool() is not currently supported in TorchScript
216
+ s = 1 / (1 - p)
217
+ mod.weight.data = s * weight.masked_fill(mask, 0)
218
+
219
+ module.register_forward_pre_hook(_forward_pre_hook)
220
+ return module
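
quant_noise is what wraps the q/k/v and output projections in the attention module above. A hedged sketch of the wrapping in isolation, assuming quant_noise from this file is in scope; the sizes and probability are illustrative, not taken from a BEATs config:

import torch
import torch.nn as nn

# p is the probability of dropping a weight block, block_size its granularity;
# 768 is divisible by 8, as the assertion inside quant_noise requires
proj = quant_noise(nn.Linear(768, 768), p=0.1, block_size=8)

proj.train()
x = torch.randn(4, 768)
y = proj(x)       # the forward pre-hook masks random weight blocks and rescales by 1 / (1 - p)

proj.eval()
y_eval = proj(x)  # no noise is applied at evaluation time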
AR/exps/beats/ontology.json ADDED
The diff for this file is too large to render. See raw diff
 
AR/exps/beats/quantizer.py ADDED
@@ -0,0 +1,235 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------'
9
+ import torch
10
+ import torch.distributed as distributed
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from einops import rearrange, repeat
16
+ except ImportError:
17
+ pass
18
+
19
+
20
+ def l2norm(t):
21
+ return F.normalize(t, p=2, dim=-1)
22
+
23
+
24
+ def ema_inplace(moving_avg, new, decay):
25
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
26
+
27
+
28
+ def sample_vectors(samples, num):
29
+ num_samples, device = samples.shape[0], samples.device
30
+
31
+ if num_samples >= num:
32
+ indices = torch.randperm(num_samples, device=device)[:num]
33
+ else:
34
+ indices = torch.randint(0, num_samples, (num, ), device=device)
35
+
36
+ return samples[indices]
37
+
38
+
39
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
40
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
41
+
42
+ means = sample_vectors(samples, num_clusters)
43
+
44
+ for _ in range(num_iters):
45
+ if use_cosine_sim:
46
+ dists = samples @ means.t()
47
+ else:
48
+ diffs = rearrange(samples, 'n d -> n () d') \
49
+ - rearrange(means, 'c d -> () c d')
50
+ dists = -(diffs**2).sum(dim=-1)
51
+
52
+ buckets = dists.max(dim=-1).indices
53
+ bins = torch.bincount(buckets, minlength=num_clusters)
54
+ zero_mask = bins == 0
55
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
56
+
57
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
58
+ new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
59
+ new_means = new_means / bins_min_clamped[..., None]
60
+
61
+ if use_cosine_sim:
62
+ new_means = l2norm(new_means)
63
+
64
+ means = torch.where(zero_mask[..., None], means, new_means)
65
+
66
+ return means, bins
67
+
68
+
69
+ class EmbeddingEMA(nn.Module):
70
+ def __init__(self,
71
+ num_tokens,
72
+ codebook_dim,
73
+ decay=0.99,
74
+ eps=1e-5,
75
+ kmeans_init=True,
76
+ codebook_init_path=''):
77
+ super().__init__()
78
+ self.num_tokens = num_tokens
79
+ self.codebook_dim = codebook_dim
80
+ self.decay = decay
81
+ self.eps = eps
82
+ if codebook_init_path == '':
83
+ if not kmeans_init:
84
+ weight = torch.randn(num_tokens, codebook_dim)
85
+ weight = l2norm(weight)
86
+ else:
87
+ weight = torch.zeros(num_tokens, codebook_dim)
88
+ self.register_buffer('initted', torch.Tensor([not kmeans_init]))
89
+ else:
90
+ print(f"load init codebook weight from {codebook_init_path}")
91
+ codebook_ckpt_weight = torch.load(
92
+ codebook_init_path, map_location='cpu')
93
+ weight = codebook_ckpt_weight.clone()
94
+ self.register_buffer('initted', torch.Tensor([True]))
95
+
96
+ self.weight = nn.Parameter(weight, requires_grad=False)
97
+ self.cluster_size = nn.Parameter(
98
+ torch.zeros(num_tokens), requires_grad=False)
99
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
100
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
101
+ self.update = True
102
+
103
+ @torch.jit.ignore
104
+ def init_embed_(self, data):
105
+ if self.initted:
106
+ return
107
+ print("Performing kmeans init for codebook")
108
+ embed, cluster_size = kmeans(
109
+ data, self.num_tokens, 10, use_cosine_sim=True)
110
+ self.weight.data.copy_(embed)
111
+ self.cluster_size.data.copy_(cluster_size)
112
+ self.initted.data.copy_(torch.Tensor([True]))
113
+
114
+ def forward(self, embed_id):
115
+ return F.embedding(embed_id, self.weight)
116
+
117
+ def cluster_size_ema_update(self, new_cluster_size):
118
+ self.cluster_size.data.mul_(self.decay).add_(
119
+ new_cluster_size, alpha=1 - self.decay)
120
+
121
+ def embed_avg_ema_update(self, new_embed_avg):
122
+ self.embed_avg.data.mul_(self.decay).add_(
123
+ new_embed_avg, alpha=1 - self.decay)
124
+
125
+ def weight_update(self, num_tokens):
126
+ n = self.cluster_size.sum()
127
+ smoothed_cluster_size = (
128
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n)
129
+ # normalize embedding average with smoothed cluster size
130
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
131
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
132
+ self.weight.data.copy_(embed_normalized)
133
+
134
+
135
+ def norm_ema_inplace(moving_avg, new, decay):
136
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
137
+ moving_avg.data.copy_(l2norm(moving_avg.data))
138
+
139
+
140
+ class NormEMAVectorQuantizer(nn.Module):
141
+ def __init__(self,
142
+ n_embed,
143
+ embedding_dim,
144
+ beta,
145
+ decay=0.99,
146
+ eps=1e-5,
147
+ statistic_code_usage=True,
148
+ kmeans_init=False,
149
+ codebook_init_path=''):
150
+ super().__init__()
151
+ self.codebook_dim = embedding_dim
152
+ self.num_tokens = n_embed
153
+ self.beta = beta
154
+ self.decay = decay
155
+
156
+ # learnable = True if orthogonal_reg_weight > 0 else False
157
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay,
158
+ eps, kmeans_init, codebook_init_path)
159
+
160
+ self.statistic_code_usage = statistic_code_usage
161
+ if statistic_code_usage:
162
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
163
+ if distributed.is_available() and distributed.is_initialized():
164
+ print(
165
+ "ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!"
166
+ )
167
+ self.all_reduce_fn = distributed.all_reduce
168
+ else:
169
+ self.all_reduce_fn = nn.Identity()
170
+
171
+ def reset_cluster_size(self, device):
172
+ if self.statistic_code_usage:
173
+ self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
174
+ self.cluster_size = self.cluster_size.to(device)
175
+
176
+ def forward(self, z):
177
+ # reshape z -> (batch, height, width, channel) and flatten
178
+ # z, 'b c h w -> b h w c'
179
+ # z = rearrange(z, 'b c h w -> b h w c')
180
+ # z = z.transpose(1, 2)
181
+ z = l2norm(z)
182
+ z_flattened = z.reshape(-1, self.codebook_dim)
183
+
184
+ self.embedding.init_embed_(z_flattened)
185
+
186
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
187
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
188
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
189
+
190
+ encoding_indices = torch.argmin(d, dim=1)
191
+
192
+ z_q = self.embedding(encoding_indices).view(z.shape)
193
+
194
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
195
+
196
+ if not self.training:
197
+ with torch.no_grad():
198
+ cluster_size = encodings.sum(0)
199
+ self.all_reduce_fn(cluster_size)
200
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
201
+
202
+ if self.training and self.embedding.update:
203
+ # EMA cluster size
204
+
205
+ bins = encodings.sum(0)
206
+ self.all_reduce_fn(bins)
207
+
208
+ # self.embedding.cluster_size_ema_update(bins)
209
+ ema_inplace(self.cluster_size, bins, self.decay)
210
+
211
+ zero_mask = (bins == 0)
212
+ bins = bins.masked_fill(zero_mask, 1.)
213
+
214
+ embed_sum = z_flattened.t() @ encodings
215
+ self.all_reduce_fn(embed_sum)
216
+
217
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
218
+ embed_normalized = l2norm(embed_normalized)
219
+
220
+ embed_normalized = torch.where(
221
+ zero_mask[..., None], self.embedding.weight, embed_normalized)
222
+ norm_ema_inplace(self.embedding.weight, embed_normalized,
223
+ self.decay)
224
+
225
+ # compute loss for embedding
226
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
227
+
228
+ # preserve gradients
229
+ z_q = z + (z_q - z).detach()
230
+
231
+ # reshape back to match original input shape
232
+ # z_q, 'b h w c -> b c h w'
233
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
234
+ # z_q = z_q.transpose(1, 2)
235
+ return z_q, loss, encoding_indices
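
NormEMAVectorQuantizer expects features with the channel in the last dimension, l2-normalizes them, and returns the quantized features, the commitment loss and the code indices. A hedged usage sketch, assuming the class above is in scope; the hyper-parameters and shapes are illustrative, not values from a BEATs tokenizer config:

import torch

quantizer = NormEMAVectorQuantizer(
    n_embed=1024,        # codebook size (illustrative)
    embedding_dim=256,   # code dimension (illustrative)
    beta=1.0,            # commitment-loss weight (illustrative)
    kmeans_init=False)   # random l2-normalized codebook, no kmeans warm-up

# (batch, time, channel) features; forward() flattens them to (batch * time, channel)
z = torch.randn(8, 100, 256)
z_q, loss, indices = quantizer(z)
print(z_q.shape, loss.item(), indices.shape)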
AR/exps/get_beats_librilight.py ADDED
@@ -0,0 +1,321 @@
1
+ # Use the AudioTag tool BEATs to filter out audios whose top-1 tag is not 'speech'
2
+ # non_speech.npy stores a python dict with the tags of the non-speech audios; it is smaller and faster to load and search
3
+ # the audio_tag directory stores {utt_id}.txt; the first line is the lowercase top-1 tag
4
+ import argparse
5
+ import os
6
+ import time
7
+ import traceback
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ from pathlib import Path
10
+
11
+ import librosa
12
+ import numpy as np
13
+ import torch
14
+ import tqdm
15
+ from AR.exps.beats.BEATs import BEATs
16
+ from AR.exps.beats.BEATs import BEATsConfig
17
+ from AR.exps.beats.config import id_name_dict
18
+ from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
19
+ from soundstorm.utils import check_txt_file
20
+
21
+
22
+ def get_BEATs_top1(wav,
23
+ BEATs_model,
24
+ BEATs_label_dict,
25
+ device: str='cpu',
26
+ topk: int=1):
27
+ wav = torch.tensor(wav).unsqueeze(0).to(device)
28
+ padding_mask = torch.zeros(wav.shape).bool().to(device)
29
+ probs = BEATs_model.extract_features(wav, padding_mask=padding_mask)[0]
30
+ # single-utterance inference
31
+ probs = probs[0]
32
+ topk_label_prob, topk_label_idx = probs.topk(k=topk)
33
+ topk_label = [
34
+ BEATs_label_dict[label_idx.item()] for label_idx in topk_label_idx
35
+ ]
36
+ topk_label_name = [id_name_dict[label] for label in topk_label]
37
+ top1_label = topk_label_name[0]
38
+ return top1_label
39
+
40
+
41
+ def process_sentence(args,
42
+ fp: Path,
43
+ train_dump_dir: Path,
44
+ dev_dump_dir: Path,
45
+ test_dump_dir: Path,
46
+ VAD_dict,
47
+ BEATs_model,
48
+ BEATs_label_dict,
49
+ device: str='cpu'):
50
+ utt_id = fp.stem
51
+ sr = args.sr
52
+ record = []
53
+ train_audio_tag_dir = train_dump_dir / "audio_tag"
54
+ train_audio_tag_dir.mkdir(parents=True, exist_ok=True)
55
+
56
+ dev_audio_tag_dir = dev_dump_dir / "audio_tag"
57
+ dev_audio_tag_dir.mkdir(parents=True, exist_ok=True)
58
+
59
+ test_audio_tag_dir = test_dump_dir / "audio_tag"
60
+ test_audio_tag_dir.mkdir(parents=True, exist_ok=True)
61
+
62
+ try:
63
+ # get info for path
64
+ wav_path_list = str(fp).strip().split('/')
65
+ sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
66
+ -3], wav_path_list[-2]
67
+ wav_name = wav_path_list[-1][:-5]
68
+ assert wav_name == utt_id
69
+ # key_name for big wav
70
+ key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
71
+ # handle the case where this audio has no entry in the VAD dict
72
+ if key_name not in VAD_dict.keys():
73
+ print(key_name, 'not in VAD_dict !')
74
+ return record
75
+ wav = None
76
+ sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
77
+ len_dict = len(sorted_split_VAD_dict)
78
+ for index, item in enumerate(sorted_split_VAD_dict):
79
+ split_name, value = item
80
+ start, end = value
81
+ # train | dev | test
82
+ if index == len_dict - 1:
83
+ subset = 'test'
84
+ audio_tag_path = test_audio_tag_dir / (split_name + ".txt")
85
+ elif index == len_dict - 2:
86
+ subset = 'dev'
87
+ audio_tag_path = dev_audio_tag_dir / (split_name + ".txt")
88
+ else:
89
+ subset = 'train'
90
+ audio_tag_path = train_audio_tag_dir / (split_name + ".txt")
91
+
92
+ if os.path.exists(audio_tag_path) and check_txt_file(
93
+ audio_tag_path):
94
+ # print(audio_tag_path, 'exists!')
95
+ pass
96
+ else:
97
+ # this check ensures the big wav is loaded only once inside the sub-wav loop
98
+ if wav is None:
99
+ # load big wav
100
+ # loading at the outermost level would waste load time when all sub-wav features already exist
101
+ wav, _ = librosa.load(str(fp), sr=sr)
102
+ sub_wav = wav[int(start * sr):int(end * sr)]
103
+ audio_tag_top1 = get_BEATs_top1(
104
+ wav=sub_wav,
105
+ BEATs_model=BEATs_model,
106
+ BEATs_label_dict=BEATs_label_dict,
107
+ device=device)
108
+
109
+ with open(audio_tag_path, 'w') as f:
110
+ f.write(audio_tag_top1)
111
+
112
+ sub_record = {
113
+ "utt_id": split_name,
114
+ "audio_tag_path": audio_tag_path,
115
+ "subset": subset
116
+ }
117
+ # record becomes a List of Dict
118
+ record.append(sub_record)
119
+ except Exception:
120
+ print("Exception occurred")
121
+ traceback.print_exc()
122
+ # record may be an incomplete List
123
+ return record
124
+ return record
125
+
126
+
127
+ def process_sentences(args,
128
+ fps: Path,
129
+ train_dump_dir: Path,
130
+ dev_dump_dir: Path,
131
+ test_dump_dir: Path,
132
+ VAD_dict,
133
+ BEATs_model,
134
+ BEATs_label_dict,
135
+ device: str='cpu',
136
+ nprocs: int=1):
137
+ print("nprocs:", nprocs)
138
+ if nprocs == 1:
139
+ results = []
140
+ for fp in tqdm.tqdm(fps, total=len(fps)):
141
+ record = process_sentence(
142
+ args=args,
143
+ fp=fp,
144
+ train_dump_dir=train_dump_dir,
145
+ dev_dump_dir=dev_dump_dir,
146
+ test_dump_dir=test_dump_dir,
147
+ VAD_dict=VAD_dict,
148
+ BEATs_model=BEATs_model,
149
+ BEATs_label_dict=BEATs_label_dict,
150
+ device=device)
151
+ if record:
152
+ results.append(record)
153
+ else:
154
+ with ThreadPoolExecutor(nprocs) as pool:
155
+ futures = []
156
+ with tqdm.tqdm(total=len(fps)) as progress:
157
+ for fp in fps:
158
+ future = pool.submit(process_sentence, args, fp,
159
+ train_dump_dir, dev_dump_dir,
160
+ test_dump_dir, VAD_dict, BEATs_model,
161
+ BEATs_label_dict, device)
162
+ future.add_done_callback(lambda p: progress.update())
163
+ futures.append(future)
164
+
165
+ results = []
166
+ for ft in futures:
167
+ record = ft.result()
168
+ if record:
169
+ results.append(record)
170
+
171
+ # gather the non-speech tags per subset and save them with np.save()
172
+ non_speech_dict = dict()
173
+ non_speech_dict['train'] = {}
174
+ non_speech_dict['dev'] = {}
175
+ non_speech_dict['test'] = {}
176
+ # record is a List of Dict: one record per big wav, one sub_record per sub wav
177
+ print(f"start to save {args.rank}_{args.nshard}.npy ...")
178
+ save_start_time = time.time()
179
+ for record in tqdm.tqdm(results, total=len(results), colour='green'):
180
+ for sub_record in record:
181
+ # wrap in try because the txt file may be corrupted
182
+ try:
183
+ utt_id = sub_record["utt_id"]
184
+ subset = sub_record["subset"]
185
+ audio_tag_top1 = check_txt_file(sub_record["audio_tag_path"])
186
+ if audio_tag_top1 is not False:
187
+ if 'speech' not in audio_tag_top1.lower():
188
+ non_speech_dict[subset][utt_id] = audio_tag_top1
189
+ else:
190
+ # print(f'audio tag result of {utt_id} is speech')
191
+ pass
192
+ else:
193
+ print(f'audio tag result of {utt_id} is False')
194
+ except Exception:
195
+ print(f"Exception occurred for {utt_id}")
196
+ traceback.print_exc()
197
+ continue
198
+
199
+ train_filename = train_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
200
+ dev_filename = dev_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
201
+ test_filename = test_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
202
+ np.save(train_filename, non_speech_dict['train'])
203
+ print(f"npy file '{train_filename}' write down")
204
+
205
+ np.save(dev_filename, non_speech_dict['dev'])
206
+ print(f"npy file '{dev_filename}' write down")
207
+
208
+ np.save(test_filename, non_speech_dict['test'])
209
+ print(f"npy file '{test_filename}' write down")
210
+ print('time of save stage:', time.time() - save_start_time)
211
+
212
+
213
+ def main():
214
+ # parse config and args
215
+ parser = argparse.ArgumentParser(
216
+ description="Use the AudioTag tool BEATs to filter out audios whose top-1 tag is not 'speech'."
217
+ )
218
+
219
+ parser.add_argument(
220
+ "--data_dir", default=None, type=str, help="directory to dataset.")
221
+
222
+ parser.add_argument(
223
+ "--dump_dir",
224
+ type=str,
225
+ required=True,
226
+ help="directory to dump feature files.")
227
+
228
+ parser.add_argument(
229
+ "--num-cpu", type=int, default=1, help="number of process.")
230
+
231
+ parser.add_argument(
232
+ '--sr', type=int, default=16000, help='sample rate of model')
233
+
234
+ # For LibriLight dataset
235
+ parser.add_argument(
236
+ "--sub_dataset",
237
+ default="small",
238
+ type=str,
239
+ help="name of sub dataset of LibriLight",
240
+ choices=['small', 'medium', 'large', 'duplicate'], )
241
+ parser.add_argument(
242
+ "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
243
+ parser.add_argument("--nshard", type=int, default=3)
244
+ parser.add_argument("--rank", type=int, default=0)
245
+
246
+ # for BEATs
247
+ parser.add_argument(
248
+ "--BEATs_ckpt_path",
249
+ type=str,
250
+ default='./pretrained_model/BEATs_iter1_finetuned_on_AS2M_cpt1.pt')
251
+
252
+ args = parser.parse_args()
253
+
254
+ data_dir = Path(args.data_dir).expanduser()
255
+ dump_dir = Path(args.dump_dir).expanduser()
256
+ # use absolute path
257
+ dump_dir = dump_dir.resolve()
258
+ dump_dir.mkdir(parents=True, exist_ok=True)
259
+
260
+ assert data_dir.is_dir()
261
+
262
+ # sub_dataset here
263
+ sub_dataset_dir = data_dir / args.sub_dataset
264
+ # only spk_id in list, sorted in lexicographical order
265
+ speaker_list = sorted(os.listdir(sub_dataset_dir))
266
+ start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
267
+ # speaker_list for this rank
268
+ speaker_list = speaker_list[start:end]
269
+
270
+ all_wav_files = []
271
+
272
+ for speaker in speaker_list:
273
+ wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
274
+ # filter out ._*.flac
275
+ wav_files = [
276
+ file for file in wav_files if not file.name.startswith('._')
277
+ ]
278
+ all_wav_files += wav_files
279
+
280
+ print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
281
+ # get VAD info
282
+ VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
283
+
284
+ sub_dataset_dump_dir = dump_dir / args.sub_dataset
285
+ sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
286
+ train_dump_dir = sub_dataset_dump_dir / "train"
287
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
288
+ dev_dump_dir = sub_dataset_dump_dir / "dev"
289
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
290
+ test_dump_dir = sub_dataset_dump_dir / "test"
291
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
292
+
293
+ BEATs_ckpt = torch.load(args.BEATs_ckpt_path)
294
+
295
+ BEATs_cfg = BEATsConfig(BEATs_ckpt['cfg'])
296
+ BEATs_model = BEATs(BEATs_cfg)
297
+ BEATs_model.load_state_dict(BEATs_ckpt['model'])
298
+ BEATs_model.eval()
299
+ # cpu or cuda
300
+ device = 'cpu'
301
+ BEATs_model.to(device)
302
+
303
+ BEATs_label_dict = BEATs_ckpt['label_dict']
304
+
305
+ # each big wav contributes one dev and one test clip; the split ratio is roughly 96:2:2
306
+ if all_wav_files:
307
+ process_sentences(
308
+ args=args,
309
+ fps=all_wav_files,
310
+ train_dump_dir=train_dump_dir,
311
+ dev_dump_dir=dev_dump_dir,
312
+ test_dump_dir=test_dump_dir,
313
+ VAD_dict=VAD_dict,
314
+ BEATs_model=BEATs_model,
315
+ BEATs_label_dict=BEATs_label_dict,
316
+ device=device,
317
+ nprocs=args.num_cpu)
318
+
319
+
320
+ if __name__ == "__main__":
321
+ main()
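
Each rank of this script writes one non_speech_{rank}_{nshard}.npy per subset: a numpy-pickled dict mapping a split utt_id to its top-1 BEATs tag, kept only when that tag is not speech. A hedged sketch of reading a shard back, following the same np.load(...).item() convention used for the other .npy dicts in this repo (the path below is illustrative):

import numpy as np

# illustrative path: <dump_dir>/<sub_dataset>/train/non_speech_<rank>_<nshard>.npy
non_speech = np.load('dump/small/train/non_speech_0_3.npy',
                     allow_pickle=True).item()

# keys are split utt_ids, values are top-1 BEATs tag names
for utt_id, tag in list(non_speech.items())[:5]:
    print(utt_id, tag)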
AR/exps/get_phones.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ 1. read text of dataset
3
+ 2. text -> IPA by GruutPhonemizer
4
+ 3. save out a *.npy dict for all text
5
+ my_dict = {"utt_id1": text1, "utt_id2": text2}
6
+ np.save(output_filename, my_dict)
7
+ my_dict = np.load(output_filename, allow_pickle=True).item()
8
+ """
9
+ import argparse
10
+ import os
+ import traceback
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from operator import itemgetter
13
+ from pathlib import Path
14
+ from typing import List
15
+
16
+ import numpy as np
17
+ import tqdm
18
+ from AR.text_processing.phonemizer import GruutPhonemizer
19
+
20
+
21
+ def read_txt(txt_file):
22
+ utt_name = txt_file.stem
23
+ utt_id = utt_name.split('.')[0]
24
+ try:
25
+ with open(txt_file, 'r') as file:
26
+ txt = file.readline()
27
+ record = {"utt_id": utt_id, "txt": txt}
28
+ except Exception:
29
+ print("Exception occurred")
30
+ traceback.print_exc()
31
+ return None
32
+ return record
33
+
34
+
35
+ def read_txts(txt_files: List[Path], nprocs: int=1):
36
+ if nprocs == 1:
37
+ results = []
38
+ for txt_file in tqdm.tqdm(txt_files, total=len(txt_files)):
39
+ record = read_txt(txt_file=txt_file)
40
+ if record:
41
+ results.append(record)
42
+ else:
43
+ with ThreadPoolExecutor(nprocs) as pool:
44
+ futures = []
45
+ with tqdm.tqdm(total=len(txt_files)) as progress:
46
+ for txt_file in txt_files:
47
+ future = pool.submit(read_txt, txt_file)
48
+ future.add_done_callback(lambda p: progress.update())
49
+ futures.append(future)
50
+
51
+ results = []
52
+ for ft in futures:
53
+ record = ft.result()
54
+ if record:
55
+ results.append(record)
56
+
57
+ results.sort(key=itemgetter("utt_id"))
58
+ return_list = []
59
+ for item in results:
60
+ return_list.append((item["utt_id"], item["txt"]))
61
+ return return_list
62
+
63
+
64
+ def process_sentence(item, phonemizer):
65
+ utt_id, text = item
66
+ try:
67
+ phonemes = phonemizer.phonemize(text, espeak=False)
68
+ record = {"utt_id": utt_id, "phonemes": phonemes}
69
+ except Exception:
70
+ print("Exception occurred")
71
+ traceback.print_exc()
72
+ return None
73
+ return record
74
+
75
+
76
+ def process_sentences(items, phonemizer, output_dir, nprocs: int=1):
77
+ if nprocs == 1:
78
+ results = []
79
+ for item in tqdm.tqdm(items, total=len(items)):
80
+ record = process_sentence(item=item, phonemizer=phonemizer)
81
+ if record:
82
+ results.append(record)
83
+ else:
84
+ with ThreadPoolExecutor(nprocs) as pool:
85
+ futures = []
86
+ with tqdm.tqdm(total=len(items)) as progress:
87
+ for item in items:
88
+ future = pool.submit(process_sentence, item, phonemizer)
89
+ future.add_done_callback(lambda p: progress.update())
90
+ futures.append(future)
91
+
92
+ results = []
93
+ for ft in futures:
94
+ record = ft.result()
95
+ if record:
96
+ results.append(record)
97
+ results.sort(key=itemgetter("utt_id"))
98
+ npy_dict = {}
99
+ for item in results:
100
+ utt_id = item["utt_id"]
101
+ phonemes = item["phonemes"]
102
+ npy_dict[utt_id] = phonemes
103
+ filename = output_dir / 'phonemes.npy'
104
+ np.save(filename, npy_dict)
105
+ print(f"npy file '{filename}' write down")
106
+
107
+
108
+ def main():
109
+ # parse config and args
110
+ parser = argparse.ArgumentParser(description="Get phones for datasets")
111
+
112
+ parser.add_argument(
113
+ "--dataset",
114
+ default="ljspeech",
115
+ type=str,
116
+ help="name of dataset, should be in {ljspeech, libritts} now")
117
+
118
+ parser.add_argument(
119
+ "--data_dir", default=None, type=str, help="directory to dataset.")
120
+
121
+ parser.add_argument(
122
+ "--dump_dir",
123
+ type=str,
124
+ required=True,
125
+ help="directory to dump feature files.")
126
+ parser.add_argument(
127
+ "--num-cpu", type=int, default=1, help="number of process.")
128
+
129
+ args = parser.parse_args()
130
+
131
+ data_dir = Path(args.data_dir).expanduser()
132
+ dump_dir = Path(args.dump_dir).expanduser()
133
+ # use absolute path
134
+ dump_dir = dump_dir.resolve()
135
+ dump_dir.mkdir(parents=True, exist_ok=True)
136
+
137
+ assert data_dir.is_dir()
138
+
139
+ if args.dataset == "ljspeech":
140
+ data_dict = {}
141
+ text_path = data_dir / 'metadata.csv'
142
+ with open(text_path, 'r') as rf:
143
+ for line in rf:
144
+ line_list = line.strip().split('|')
145
+ utt_id = line_list[0]
146
+ raw_text = line_list[-1]
147
+ data_dict[utt_id] = raw_text
148
+
149
+ sorted_dict = sorted(data_dict.items())
150
+
151
+ num_train = 12900
152
+ num_dev = 100
153
+ # (utt_id, txt)
154
+ train_txts = sorted_dict[:num_train]
155
+ dev_txts = sorted_dict[num_train:num_train + num_dev]
156
+ test_txts = sorted_dict[num_train + num_dev:]
157
+
158
+ elif args.dataset == "libritts":
159
+ '''
160
+ we use train-clean-100, train-clean-360 and train-other-500 here
+ and split dev and test from them; we don't use the test-* and dev-* subsets because their speakers are disjoint from the training speakers
+ the file structure is LibriTTS_R/train-clean-100/spkid/*/*.wav
+ there are about 2311 speakers in these subsets; we split 1 dev and 1 test wav out from each speaker
164
+ '''
165
+ txt_files = []
166
+ train_txt_files = []
167
+ dev_txt_files = []
168
+ test_txt_files = []
169
+ sub_num_dev = 1
170
+ for sub_dataset_name in {
171
+ "train-clean-100", "train-clean-360", "train-other-500"
172
+ }:
173
+ sub_dataset_dir = data_dir / sub_dataset_name
174
+ # filter out hidden files
175
+ speaker_list = [
176
+ file for file in os.listdir(sub_dataset_dir)
177
+ if not file.startswith('.')
178
+ ]
179
+ for speaker in speaker_list:
180
+ txt_files = sorted(
181
+ list((sub_dataset_dir / speaker).rglob(
182
+ "*/*.normalized.txt")))
183
+ # filter out ._*.wav
184
+ txt_files = [
185
+ file for file in txt_files if not file.name.startswith('._')
186
+ ]
187
+ train_txt_files += txt_files[:-sub_num_dev * 2]
188
+ dev_txt_files += txt_files[-sub_num_dev * 2:-sub_num_dev]
189
+ test_txt_files += txt_files[-sub_num_dev:]
190
+ print("len(train_txt_files):", len(train_txt_files))
191
+ print("len(dev_txt_files):", len(dev_txt_files))
192
+ print("len(test_txt_files):", len(test_txt_files))
193
+
194
+ train_txts = read_txts(train_txt_files)
195
+ dev_txts = read_txts(dev_txt_files)
196
+ test_txts = read_txts(test_txt_files)
197
+
198
+ else:
199
+ print("dataset should be in {ljspeech, libritts} now!")
200
+
201
+ train_dump_dir = dump_dir / "train"
202
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
203
+ dev_dump_dir = dump_dir / "dev"
204
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
205
+ test_dump_dir = dump_dir / "test"
206
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
207
+
208
+ phonemizer = GruutPhonemizer(language='en-us')
209
+
210
+ # process for the 3 sections
211
+ if train_txts:
212
+ process_sentences(
213
+ items=train_txts,
214
+ output_dir=train_dump_dir,
215
+ phonemizer=phonemizer,
216
+ nprocs=args.num_cpu)
217
+ if dev_txts:
218
+ process_sentences(
219
+ items=dev_txts,
220
+ output_dir=dev_dump_dir,
221
+ phonemizer=phonemizer,
222
+ nprocs=args.num_cpu)
223
+ if test_txts:
224
+ process_sentences(
225
+ items=test_txts,
226
+ output_dir=test_dump_dir,
227
+ phonemizer=phonemizer,
228
+ nprocs=args.num_cpu)
229
+
230
+
231
+ if __name__ == "__main__":
232
+ main()
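
get_phones.py writes one phonemes.npy per split, keyed by utt_id and holding the IPA string produced by GruutPhonemizer. A hedged sketch of consuming it, together with the phonemize call the script relies on (the path and the sample sentence are illustrative):

import numpy as np
from AR.text_processing.phonemizer import GruutPhonemizer

# read back a split written by this script (illustrative path)
phoneme_dict = np.load('dump/train/phonemes.npy', allow_pickle=True).item()

# the same call the script uses to turn raw text into IPA phonemes
phonemizer = GruutPhonemizer(language='en-us')
ipa = phonemizer.phonemize("An illustrative sentence for the phonemizer.", espeak=False)
print(ipa)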
AR/exps/get_phones_librilight.py ADDED
@@ -0,0 +1,198 @@
1
+ """
2
+ 1. read text of dataset; for LibriLight read txt_*.npy -> needs to be organized as a list of (utt_id, txt)
3
+ 2. text -> IPA by GruutPhonemizer
4
+ 3. save out a *.npy dict for all text
5
+ 4. each LibriLight split is processed separately
6
+ my_dict = {"utt_id1": text1, "utt_id2": text2}
7
+ np.save(output_filename, my_dict)
8
+ my_dict = np.load(output_filename, allow_pickle=True).item()
9
+ """
10
+ import argparse
11
+ import os
12
+ import time
13
+ import traceback
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from operator import itemgetter
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import tqdm
20
+ from AR.text_processing.phonemizer import GruutPhonemizer
21
+ from soundstorm.utils import check_txt_file
22
+
23
+
24
+ def read_txts(txt_file: Path, nprocs: int=1):
25
+ '''
26
+ txt_file: path of npy dict, {"utt_id1": text1, "utt_id2": text2}
27
+ '''
28
+ txt_dict = np.load(txt_file, allow_pickle=True).item()
29
+ #[(utt_id, txt), ...]
30
+ return_list = list(txt_dict.items())
31
+ return return_list
32
+
33
+
34
+ def process_sentence(item, phonemizer, output_dir):
35
+ utt_id, text = item
36
+ phonemes_dir = output_dir / "phonemes"
37
+ phonemes_dir.mkdir(parents=True, exist_ok=True)
38
+ phonemes_path = phonemes_dir / (utt_id + ".txt")
39
+ try:
40
+ if os.path.exists(phonemes_path) and check_txt_file(phonemes_path):
41
+ # print(phonemes_path, 'exists!')
42
+ pass
43
+ else:
44
+ phonemes = phonemizer.phonemize(text, espeak=False)
45
+ with open(phonemes_path, 'w') as f:
46
+ f.write(phonemes)
47
+ record = {"utt_id": utt_id, "phonemes_path": phonemes_path}
48
+ except Exception:
49
+ print("Exception occurred")
50
+ traceback.print_exc()
51
+ return None
52
+ return record
53
+
54
+
55
+ def process_sentences(args, items, phonemizer, output_dir, nprocs: int=1):
56
+ print("nprocs:", nprocs)
57
+ if nprocs == 1:
58
+ results = []
59
+ for item in tqdm.tqdm(items, total=len(items)):
60
+ record = process_sentence(
61
+ item=item, phonemizer=phonemizer, output_dir=output_dir)
62
+ if record:
63
+ results.append(record)
64
+ else:
65
+ with ThreadPoolExecutor(nprocs) as pool:
66
+ futures = []
67
+ with tqdm.tqdm(total=len(items)) as progress:
68
+ for item in items:
69
+ future = pool.submit(process_sentence, item, phonemizer,
70
+ output_dir)
71
+ future.add_done_callback(lambda p: progress.update())
72
+ futures.append(future)
73
+
74
+ results = []
75
+ for ft in futures:
76
+ record = ft.result()
77
+ if record:
78
+ results.append(record)
79
+
80
+ results.sort(key=itemgetter("utt_id"))
81
+
82
+ npy_dict = {}
83
+ print(f"start to save {args.rank}_{args.nshard}.npy ...")
84
+ save_start_time = time.time()
85
+ for item in tqdm.tqdm(results, total=len(results), colour='green'):
86
+ # wrap in try because the txt file may be corrupted
87
+ try:
88
+ utt_id = item["utt_id"]
89
+ phonemes = check_txt_file(item["phonemes_path"])
90
+ if phonemes is not False:
91
+ npy_dict[utt_id] = phonemes
92
+ else:
93
+ print(f'phonemes of {utt_id} is False')
94
+ except Exception:
95
+ print(f"{utt_id} occur Exception")
96
+ traceback.print_exc()
97
+ continue
98
+
99
+ filename = output_dir / f'phonemes_{args.rank}_{args.nshard}.npy'
100
+ np.save(filename, npy_dict)
101
+ print(f"npy file '{filename}' write down")
102
+ print('time of save stage:', time.time() - save_start_time)
103
+
104
+
105
+ def main():
106
+ # parse config and args
107
+ parser = argparse.ArgumentParser(
108
+ description="Get phones for LibriLight dataset from txt_*.npy")
109
+
110
+ parser.add_argument(
111
+ "--dump_dir",
112
+ type=str,
113
+ required=True,
114
+ help="directory to dump feature files.")
115
+ parser.add_argument(
116
+ "--num-cpu", type=int, default=1, help="number of process.")
117
+
118
+ parser.add_argument(
119
+ '--train_txt_dir',
120
+ type=str,
121
+ default='dump/small/train/',
122
+ help='dir of train txt files')
123
+ parser.add_argument(
124
+ '--dev_txt_dir',
125
+ type=str,
126
+ default='dump/small/dev/',
127
+ help='dir of dev txt files')
128
+ parser.add_argument(
129
+ '--test_txt_dir',
130
+ type=str,
131
+ default='dump/small/test/',
132
+ help='dir of test txt files')
133
+
134
+ parser.add_argument(
135
+ "--sub_dataset",
136
+ default="small",
137
+ type=str,
138
+ help="name of sub dataset of LibriLight",
139
+ choices=['small', 'medium', 'large', 'duplicate'], )
140
+ parser.add_argument("--nshard", type=int, default=3)
141
+ parser.add_argument("--rank", type=int, default=0)
142
+
143
+ args = parser.parse_args()
144
+ print(f"nshard: {args.nshard}, rank: {args.rank}")
145
+
146
+ train_txt_dir = Path(args.train_txt_dir)
147
+ dev_txt_dir = Path(args.dev_txt_dir)
148
+ test_txt_dir = Path(args.test_txt_dir)
149
+
150
+ dump_dir = Path(args.dump_dir).expanduser()
151
+ # use absolute path
152
+ dump_dir = dump_dir.resolve()
153
+ dump_dir.mkdir(parents=True, exist_ok=True)
154
+
155
+ train_txt_file = train_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
156
+ dev_txt_file = dev_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
157
+ test_txt_file = test_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
158
+
159
+ train_txts = read_txts(train_txt_file)
160
+ dev_txts = read_txts(dev_txt_file)
161
+ test_txts = read_txts(test_txt_file)
162
+
163
+ sub_dataset_dump_dir = dump_dir / args.sub_dataset
164
+ sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
165
+ train_dump_dir = sub_dataset_dump_dir / "train"
166
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
167
+ dev_dump_dir = sub_dataset_dump_dir / "dev"
168
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
169
+ test_dump_dir = sub_dataset_dump_dir / "test"
170
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
171
+ phonemizer = GruutPhonemizer(language='en-us')
172
+
173
+ # process for the 3 sections
174
+ if train_txts:
175
+ process_sentences(
176
+ args=args,
177
+ items=train_txts,
178
+ output_dir=train_dump_dir,
179
+ phonemizer=phonemizer,
180
+ nprocs=args.num_cpu)
181
+ if dev_txts:
182
+ process_sentences(
183
+ args=args,
184
+ items=dev_txts,
185
+ output_dir=dev_dump_dir,
186
+ phonemizer=phonemizer,
187
+ nprocs=args.num_cpu)
188
+ if test_txts:
189
+ process_sentences(
190
+ args=args,
191
+ items=test_txts,
192
+ output_dir=test_dump_dir,
193
+ phonemizer=phonemizer,
194
+ nprocs=args.num_cpu)
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
AR/exps/get_txt_librilight.py ADDED
@@ -0,0 +1,255 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ import traceback
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from pathlib import Path
7
+
8
+ import librosa
9
+ import numpy as np
10
+ import tqdm
11
+ import whisper
12
+ from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
13
+ from soundstorm.utils import check_txt_file
14
+
15
+
16
+ def process_sentence(args,
17
+ fp: Path,
18
+ train_dump_dir: Path,
19
+ dev_dump_dir: Path,
20
+ test_dump_dir: Path,
21
+ VAD_dict):
22
+ asr_model = whisper.load_model("tiny.en")
23
+ utt_id = fp.stem
24
+ sr = args.sr
25
+ record = []
26
+ train_txt_dir = train_dump_dir / "txt"
27
+ train_txt_dir.mkdir(parents=True, exist_ok=True)
28
+
29
+ dev_txt_dir = dev_dump_dir / "txt"
30
+ dev_txt_dir.mkdir(parents=True, exist_ok=True)
31
+
32
+ test_txt_dir = test_dump_dir / "txt"
33
+ test_txt_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ try:
36
+ # get info for path
37
+ wav_path_list = str(fp).strip().split('/')
38
+ sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
39
+ -3], wav_path_list[-2]
40
+ wav_name = wav_path_list[-1][:-5]
41
+ assert wav_name == utt_id
42
+ # key_name for big wav
43
+ key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
44
+ # handle the case where this audio has no entry in the VAD dict
45
+ if key_name not in VAD_dict.keys():
46
+ print(key_name, 'not in VAD_dict !')
47
+ return record
48
+ wav = None
49
+ sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
50
+ len_dict = len(sorted_split_VAD_dict)
51
+ for index, item in enumerate(sorted_split_VAD_dict):
52
+ split_name, value = item
53
+ start, end = value
54
+ # train | dev | test
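+ # the last sub-clip of each big wav goes to test, the second-to-last to dev, and the rest to train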
55
+ if index == len_dict - 1:
56
+ subset = 'test'
57
+ txt_path = test_txt_dir / (split_name + ".txt")
58
+ elif index == len_dict - 2:
59
+ subset = 'dev'
60
+ txt_path = dev_txt_dir / (split_name + ".txt")
61
+ else:
62
+ subset = 'train'
63
+ txt_path = train_txt_dir / (split_name + ".txt")
64
+
65
+ if os.path.exists(txt_path) and check_txt_file(txt_path):
66
+ # print(txt_path, 'exists!')
67
+ pass
68
+ else:
69
+ # this check ensures the big wav is loaded only once inside the sub-wav loop
70
+ if wav is None:
71
+ # load big wav
72
+ # loading at the outermost level would waste time when all sub-wav features already exist
73
+ wav, _ = librosa.load(str(fp), sr=sr)
74
+ sub_wav = wav[int(start * sr):int(end * sr)]
75
+ asr_result = asr_model.transcribe(sub_wav)["text"]
76
+ with open(txt_path, 'w') as f:
77
+ f.write(asr_result)
78
+
79
+ sub_record = {
80
+ "utt_id": split_name,
81
+ "txt_path": txt_path,
82
+ "subset": subset
83
+ }
84
+ # record becomes a List of Dict
85
+ record.append(sub_record)
86
+ except Exception:
87
+ print("occur Exception")
88
+ traceback.print_exc()
89
+ # record may be an incomplete List
90
+ return record
91
+ return record
92
+
93
+
94
+ def process_sentences(args,
95
+ fps: Path,
96
+ train_dump_dir: Path,
97
+ dev_dump_dir: Path,
98
+ test_dump_dir: Path,
99
+ VAD_dict,
100
+ nprocs: int=1):
101
+ print("nprocs:", nprocs)
102
+ if nprocs == 1:
103
+ results = []
104
+ for fp in tqdm.tqdm(fps, total=len(fps)):
105
+ record = process_sentence(
106
+ args=args,
107
+ fp=fp,
108
+ train_dump_dir=train_dump_dir,
109
+ dev_dump_dir=dev_dump_dir,
110
+ test_dump_dir=test_dump_dir,
111
+ VAD_dict=VAD_dict)
112
+ if record:
113
+ results.append(record)
114
+ else:
115
+ with ThreadPoolExecutor(nprocs) as pool:
116
+ futures = []
117
+ with tqdm.tqdm(total=len(fps)) as progress:
118
+ for fp in fps:
119
+ future = pool.submit(process_sentence, args, fp,
120
+ train_dump_dir, dev_dump_dir,
121
+ test_dump_dir, VAD_dict)
122
+ future.add_done_callback(lambda p: progress.update())
123
+ futures.append(future)
124
+
125
+ results = []
126
+ for ft in futures:
127
+ record = ft.result()
128
+ if record:
129
+ results.append(record)
130
+
131
+ # collect per-subset dicts and save each split as a .npy file
132
+ txt_dict = dict()
133
+ txt_dict['train'] = {}
134
+ txt_dict['dev'] = {}
135
+ txt_dict['test'] = {}
136
+ # record is a List of Dict: one record per big wav, one sub_record per small wav
137
+ print(f"start to save {args.rank}_{args.nshard}.npy ...")
138
+ save_start_time = time.time()
139
+ for record in tqdm.tqdm(results, total=len(results), colour='green'):
140
+ for sub_record in record:
141
+ # wrap in try because the txt file may be corrupted
142
+ try:
143
+ utt_id = sub_record["utt_id"]
144
+ subset = sub_record["subset"]
145
+ asr_result = check_txt_file(sub_record["txt_path"])
146
+ if asr_result is not False:
147
+ txt_dict[subset][utt_id] = asr_result
148
+ else:
149
+ print(f'asr result of {utt_id} is False')
150
+ except Exception:
151
+ print(f"{utt_id} occur Exception")
152
+ traceback.print_exc()
153
+ continue
154
+
155
+ train_filename = train_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
156
+ dev_filename = dev_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
157
+ test_filename = test_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
158
+ np.save(train_filename, txt_dict['train'])
159
+ print(f"npy file '{train_filename}' write down")
160
+
161
+ np.save(dev_filename, txt_dict['dev'])
162
+ print(f"npy file '{dev_filename}' write down")
163
+
164
+ np.save(test_filename, txt_dict['test'])
165
+ print(f"npy file '{test_filename}' write down")
166
+ print('time of save stage:', time.time() - save_start_time)
167
+
168
+
169
+ def main():
170
+ # parse config and args
171
+ parser = argparse.ArgumentParser(
172
+ description="Preprocess audio and then extract features for LibriLight.")
173
+
174
+ parser.add_argument(
175
+ "--data_dir", default=None, type=str, help="directory to dataset.")
176
+
177
+ parser.add_argument(
178
+ "--dump_dir",
179
+ type=str,
180
+ required=True,
181
+ help="directory to dump feature files.")
182
+
183
+ parser.add_argument(
184
+ "--num-cpu", type=int, default=1, help="number of process.")
185
+
186
+ parser.add_argument(
187
+ '--sr', type=int, default=16000, help='sample rate of model')
188
+
189
+ # For LibriLight dataset
190
+ parser.add_argument(
191
+ "--sub_dataset",
192
+ default="small",
193
+ type=str,
194
+ help="name of sub dataset of LibriLight",
195
+ choices=['small', 'medium', 'large', 'duplicate'], )
196
+ parser.add_argument(
197
+ "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
198
+ parser.add_argument("--nshard", type=int, default=3)
199
+ parser.add_argument("--rank", type=int, default=0)
200
+
201
+ args = parser.parse_args()
202
+
203
+ data_dir = Path(args.data_dir).expanduser()
204
+ dump_dir = Path(args.dump_dir).expanduser()
205
+ # use absolute path
206
+ dump_dir = dump_dir.resolve()
207
+ dump_dir.mkdir(parents=True, exist_ok=True)
208
+
209
+ assert data_dir.is_dir()
210
+
211
+ # sub_dataset here
212
+ sub_dataset_dir = data_dir / args.sub_dataset
213
+ # only spk_id in the list, sorted in lexicographical order
214
+ speaker_list = sorted(os.listdir(sub_dataset_dir))
215
+ start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
216
+ # speaker_list for this rank
217
+ speaker_list = speaker_list[start:end]
218
+
219
+ all_wav_files = []
220
+
221
+ for speaker in speaker_list:
222
+ wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
223
+ # filter out ._*.flac
224
+ wav_files = [
225
+ file for file in wav_files if not file.name.startswith('._')
226
+ ]
227
+ all_wav_files += wav_files
228
+
229
+ print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
230
+ # get VAD info
231
+ VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
232
+
233
+ sub_dataset_dump_dir = dump_dir / args.sub_dataset
234
+ sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
235
+ train_dump_dir = sub_dataset_dump_dir / "train"
236
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
237
+ dev_dump_dir = sub_dataset_dump_dir / "dev"
238
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
239
+ test_dump_dir = sub_dataset_dump_dir / "test"
240
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
241
+
242
+ # each big wav contributes one dev and one test clip, so the split ratio is roughly 96:2:2
243
+ if all_wav_files:
244
+ process_sentences(
245
+ args=args,
246
+ fps=all_wav_files,
247
+ train_dump_dir=train_dump_dir,
248
+ dev_dump_dir=dev_dump_dir,
249
+ test_dump_dir=test_dump_dir,
250
+ VAD_dict=VAD_dict,
251
+ nprocs=args.num_cpu)
252
+
253
+
254
+ if __name__ == "__main__":
255
+ main()
AR/exps/split_train_val.py ADDED
@@ -0,0 +1,35 @@
1
+ import numpy
2
+ import pandas
3
+
4
+ semantic_path = 'dump/semantic.tsv'
5
+ phoneme_path = 'dump/phoneme.npy'
6
+ train_semantic_path = 'dump/semantic_train.tsv'
7
+ train_phoneme_path = 'dump/phoneme_train.npy'
8
+ dev_semantic_path = 'dump/semantic_dev.tsv'
9
+ dev_phoneme_path = 'dump/phoneme_dev.npy'
10
+
11
+ # read dump/semantic.tsv
12
+ semantic_df = pandas.read_csv(semantic_path, sep='\t')
13
+ # pd.DataFrame(columns=["item_name", "semantic_audio"])
14
+ # read dump/phoneme.npy
15
+ phoneme_dict = numpy.load(phoneme_path, allow_pickle=True).item()
16
+
17
+ dev_num = 20
18
+ # randomly sample dev_num rows from semantic_df
19
+ dev_df = semantic_df.sample(n=dev_num)
20
+ # the rest is the train set
21
+ train_df = semantic_df.drop(dev_df.index)
22
+ # save
23
+ dev_df.to_csv(dev_semantic_path, sep='\t', index=False)
24
+ train_df.to_csv(train_semantic_path, sep='\t', index=False)
25
+
26
+ # take the item_name values from dev_df as the keys of dev_phoneme_dict
27
+ dev_item_names = dev_df['item_name'].tolist()
28
+ dev_phoneme_dict = {k: phoneme_dict[k] for k in dev_item_names if k in phoneme_dict}
29
+ train_phoneme_dict = {k: phoneme_dict[k] for k in phoneme_dict.keys() if k not in dev_item_names}
30
+
31
+ numpy.save(dev_phoneme_path, dev_phoneme_dict)
32
+ numpy.save(train_phoneme_path, train_phoneme_dict)
33
+
34
+
35
+
AR/exps/t2s.py ADDED
@@ -0,0 +1,197 @@
1
+ # text to semantic
2
+ import argparse
3
+ import os
4
+ import re
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import librosa
9
+ import numpy as np
10
+ import torch
11
+ import whisper
12
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
13
+ from AR.text_processing.phonemizer import GruutPhonemizer
14
+ from AR.utils.io import load_yaml_config
+ # assumed import path for the SemanticTokenizer used in main() below
+ from soundstorm.s2.models.hubert.semantic_tokenizer import SemanticTokenizer
15
+
16
+
17
+ def get_batch(text, phonemizer):
18
+ # phoneme_ids and phoneme_ids_len are what we need
19
+ phoneme = phonemizer.phonemize(text, espeak=False)
20
+ phoneme_ids = phonemizer.transform(phoneme)
21
+ phoneme_ids_len = len(phoneme_ids)
22
+ phoneme_ids = np.array(phoneme_ids)
23
+ # add batch axis here
24
+ phoneme_ids = torch.tensor(phoneme_ids).unsqueeze(0)
25
+ phoneme_ids_len = torch.tensor([phoneme_ids_len])
26
+ print("phoneme:", phoneme)
27
+ batch = {
28
+ # torch.Tensor (B, max_phoneme_length)
29
+ "phoneme_ids": phoneme_ids,
30
+ # torch.Tensor (B)
31
+ "phoneme_ids_len": phoneme_ids_len
32
+ }
33
+ return batch
34
+
35
+
36
+ def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer):
37
+ sample_rate = 16000
38
+ # to get prompt
39
+ prompt_name = os.path.basename(prompt_wav_path).split('.')[0]
40
+ wav, _ = librosa.load(prompt_wav_path, sr=sample_rate)
41
+ # take the last 3 s, excluding the final 0.1 s, to prevent AR S1 inference from stopping early
42
+ wav = wav[-sample_rate * 3:-int(sample_rate * 0.1)]
43
+ # trailing silence must be trimmed from the wav, otherwise inference may also stop early
44
+ prompt_text = asr_model.transcribe(wav)["text"]
45
+ # remove the trailing period to prevent AR S1 inference from stopping early; a period may also introduce a pause
46
+ prompt_text = prompt_text.replace(".", "")
47
+ prompt_phoneme = phonemizer.phonemize(prompt_text, espeak=False)
48
+ prompt_phoneme_ids = phonemizer.transform(prompt_phoneme)
49
+ prompt_phoneme_ids_len = len(prompt_phoneme_ids)
50
+ # get prompt_semantic
51
+ # (T) -> (1, T)
52
+ wav = torch.tensor(wav).unsqueeze(0)
53
+ wav = wav.cuda()
54
+ # (1, T)
55
+ prompt_semantic_tokens = semantic_tokenizer.tokenize(wav).to(torch.int32)
56
+ prompt_phoneme_ids = torch.tensor(prompt_phoneme_ids).unsqueeze(0)
57
+ prompt_phoneme_ids_len = torch.tensor([prompt_phoneme_ids_len])
58
+
59
+ result = {
60
+ 'prompt_name': prompt_name,
61
+ 'prompt_phoneme_ids': prompt_phoneme_ids,
62
+ 'prompt_semantic_tokens': prompt_semantic_tokens,
63
+ 'prompt_phoneme_ids_len': prompt_phoneme_ids_len
64
+ }
65
+
66
+ return result
67
+
68
+
69
+ def parse_args():
70
+ # parse args and config
71
+ parser = argparse.ArgumentParser(
72
+ description="Run SoundStorm AR S1 model for input text file")
73
+
74
+ parser.add_argument(
75
+ '--config_file',
76
+ type=str,
77
+ default='conf/default.yaml',
78
+ help='path of config file')
79
+
80
+ parser.add_argument(
81
+ "--text_file",
82
+ type=str,
83
+ help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line."
84
+ )
85
+
86
+ parser.add_argument(
87
+ '--ckpt_path',
88
+ type=str,
89
+ default='exp/default/ckpt/epoch=99-step=49000.ckpt',
90
+ help='Checkpoint file of SoundStorm AR S1 model.')
91
+
92
+ parser.add_argument(
93
+ '--prompt_wav_path',
94
+ type=str,
95
+ default=None,
96
+ help='extract prompt semantic and prompt phonemes from prompt wav')
97
+
98
+ # to get semantic tokens from prompt_wav
99
+ parser.add_argument("--hubert_path", type=str, default=None)
100
+ parser.add_argument("--quantizer_path", type=str, default=None)
101
+
102
+ parser.add_argument("--output_dir", type=str, help="output dir.")
103
+
104
+ args = parser.parse_args()
105
+ return args
106
+
107
+
108
+ def main():
109
+ args = parse_args()
110
+ config = load_yaml_config(args.config_file)
111
+
112
+ output_dir = Path(args.output_dir)
113
+ output_dir.mkdir(parents=True, exist_ok=True)
114
+
115
+ hz = 50
116
+ max_sec = config['data']['max_sec']
117
+
118
+ # get models
119
+ t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
120
+ checkpoint_path=args.ckpt_path, config=config)
121
+ t2s_model.cuda()
122
+ t2s_model.eval()
123
+
124
+ phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us')
125
+
126
+ # models for prompt
127
+ asr_model = whisper.load_model("tiny.en")
128
+
129
+ semantic_tokenizer = SemanticTokenizer(
130
+ hubert_path=args.hubert_path,
131
+ quantizer_path=args.quantizer_path,
132
+ duplicate=True)
133
+
134
+ prompt_result = get_prompt(
135
+ prompt_wav_path=args.prompt_wav_path,
136
+ asr_model=asr_model,
137
+ phonemizer=phonemizer,
138
+ semantic_tokenizer=semantic_tokenizer)
139
+
140
+ # zero prompt => the generated semantic content is correct, but the timbre is scrambled
141
+ # (B, 1)
142
+ # prompt = torch.ones(
143
+ # batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0
144
+
145
+ prompt = prompt_result['prompt_semantic_tokens']
146
+ prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len']
147
+ prompt_phoneme_ids = prompt_result['prompt_phoneme_ids']
148
+
149
+ sentences = []
150
+ with open(args.text_file, 'rt', encoding='utf-8') as f:
151
+ for line in f:
152
+ if line.strip() != "":
153
+ items = re.split(r"\s+", line.strip(), 1)
154
+ utt_id = items[0]
155
+ sentence = " ".join(items[1:])
156
+ sentences.append((utt_id, sentence))
157
+ semantic_data = [['item_name', 'semantic_audio']]
158
+ for utt_id, sentence in sentences[1:]:
159
+ # build a pseudo batch by hand to feed the model
160
+ batch = get_batch(sentence, phonemizer)
161
+ # concatenate the prompt with the real input
162
+ all_phoneme_ids = torch.cat(
163
+ [prompt_phoneme_ids, batch['phoneme_ids']], dim=1)
164
+ # alternatively, just take shape[-1] of all_phoneme_ids
165
+ all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len']
166
+ st = time.time()
167
+ with torch.no_grad():
168
+ pred_semantic = t2s_model.model.infer(
169
+ all_phoneme_ids.cuda(),
170
+ all_phoneme_len.cuda(),
171
+ prompt.cuda(),
172
+ top_k=config['inference']['top_k'],
173
+ early_stop_num=hz * max_sec)
174
+ print(f'{time.time() - st} sec used in T2S')
175
+
176
+ # drop the part corresponding to the prompt
177
+ prompt_len = prompt.shape[-1]
178
+ pred_semantic = pred_semantic[:, prompt_len:]
179
+
180
+ # bs = 1
181
+ pred_semantic = pred_semantic[0]
182
+ semantic_token = pred_semantic.detach().cpu().numpy().tolist()
183
+ semantic_token_str = ' '.join(str(x) for x in semantic_token)
184
+ semantic_data.append([utt_id, semantic_token_str])
185
+
186
+ delimiter = '\t'
187
+ filename = output_dir / f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv'
188
+ with open(filename, 'w', encoding='utf-8') as writer:
189
+ for row in semantic_data:
190
+ line = delimiter.join(row)
191
+ writer.write(line + '\n')
192
+ # reset semantic_data for the next sentence
193
+ semantic_data = [['item_name', 'semantic_audio']]
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
AR/exps/test.py ADDED
@@ -0,0 +1,139 @@
1
+ # test from dump file
2
+ import argparse
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ from AR.data.dataset import Text2SemanticDataset
9
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
10
+ from AR.utils.io import load_yaml_config
11
+ from torch.utils.data import DataLoader
12
+
13
+
14
+ def parse_args():
15
+ # parse args and config
16
+ parser = argparse.ArgumentParser(
17
+ description="Run SoundStorm AR S1 model for test set.")
18
+
19
+ parser.add_argument(
20
+ '--config_file',
21
+ type=str,
22
+ default='conf/default.yaml',
23
+ help='path of config file')
24
+
25
+ # args for dataset
26
+ parser.add_argument(
27
+ '--test_semantic_path',
28
+ type=str,
29
+ default='dump/test/semantic_token.tsv')
30
+ parser.add_argument(
31
+ '--test_phoneme_path', type=str, default='dump/test/phonemes.npy')
32
+
33
+ parser.add_argument(
34
+ '--ckpt_path',
35
+ type=str,
36
+ default='exp/default/ckpt/epoch=99-step=49000.ckpt',
37
+ help='Checkpoint file of SoundStorm AR S1 model.')
38
+
39
+ parser.add_argument("--output_dir", type=str, help="output dir.")
40
+
41
+ args = parser.parse_args()
42
+ return args
43
+
44
+
45
+ def main():
46
+ args = parse_args()
47
+
48
+ config = load_yaml_config(args.config_file)
49
+
50
+ output_dir = Path(args.output_dir)
51
+ output_dir.mkdir(parents=True, exist_ok=True)
52
+
53
+ batch_size = 1
54
+ hz = 50
55
+ max_sec = config['data']['max_sec']
56
+
57
+ # get dataset
58
+ test_dataset = Text2SemanticDataset(
59
+ phoneme_path=args.test_phoneme_path,
60
+ semantic_path=args.test_semantic_path,
61
+ # max_sec should match the training setting, otherwise quality may degrade (repetitions, dropped words, etc.)
62
+ # but setting it too small filters out overly long samples; to keep them, truncate at inference time instead
63
+ max_sec=100,
64
+ max_sample=8,
65
+ pad_val=config['data']['pad_val'])
66
+ # get model
67
+ t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
68
+ checkpoint_path=args.ckpt_path, config=config)
69
+ t2s_model.cuda()
70
+ t2s_model.eval()
71
+
72
+ # fetch batch_size items at a time
73
+ # create the DataLoader with the custom collate_fn
74
+ dataloader = DataLoader(
75
+ test_dataset,
76
+ batch_size=batch_size,
77
+ shuffle=False,
78
+ collate_fn=test_dataset.collate)
79
+
80
+ item_names = test_dataset.__get_item_names__()
81
+
82
+ # read the data batch by batch; with bs=1 and shuffle=False the batch index lines up with __get_item_names__()
83
+ semantic_data = [['item_name', 'semantic_audio']]
84
+ for i, batch in enumerate(dataloader):
85
+ # bs must be 1
86
+ utt_id = item_names[i]
87
+ if i == 0:
88
+ print("utt_id:", utt_id)
89
+ # zero padding is applied when bs > 1
90
+ # keep consistent with validation_step()
91
+ semantic_len = batch['semantic_ids'].size(1)
92
+ # use the first 150 tokens of batch['semantic_ids'] as the prompt
93
+ # across repeated syntheses the first prompt_len tokens stay identical and equal to the prompt
94
+ prompt_len = min(int(semantic_len * 0.5), 150)
95
+ # what should the prompt be for plain-text input? => see t2s.py
96
+ prompt = batch['semantic_ids'][:, :prompt_len]
97
+ # # zero prompt => still yields semantic tokens with the correct text content, but the timbre is scrambled,
98
+ # which shows that semantic tokens still carry timbre information
99
+ # prompt = torch.ones(
100
+ # batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
101
+ # print("prompt:", prompt)
102
+ # print("prompt.shape:", prompt.shape)
103
+ np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())
104
+
105
+ st = time.time()
106
+ with torch.no_grad():
107
+ # calculate acc for test
108
+ loss, acc = t2s_model.model.forward(
109
+ batch['phoneme_ids'].cuda(),
110
+ batch['phoneme_ids_len'].cuda(),
111
+ batch['semantic_ids'].cuda(),
112
+ batch['semantic_ids_len'].cuda())
113
+ print("top_3_acc of this batch:", acc)
114
+ pred_semantic = t2s_model.model.infer(
115
+ batch['phoneme_ids'].cuda(),
116
+ batch['phoneme_ids_len'].cuda(),
117
+ prompt.cuda(),
118
+ top_k=config['inference']['top_k'],
119
+ # hz * max_sec in train dataloader
120
+ # the generated length is 1002, which probably includes some padding
121
+ early_stop_num=hz * max_sec)
122
+ # bs = 1
123
+ pred_semantic = pred_semantic[0]
124
+ print(f'{time.time() - st} sec used in T2S')
125
+ semantic_token = pred_semantic.detach().cpu().numpy().tolist()
126
+ semantic_token_str = ' '.join(str(x) for x in semantic_token)
127
+ semantic_data.append([utt_id, semantic_token_str])
128
+ else:
129
+ break
130
+ delimiter = '\t'
131
+ filename = output_dir / "semantic_token.tsv"
132
+ with open(filename, 'w', encoding='utf-8') as writer:
133
+ for row in semantic_data:
134
+ line = delimiter.join(row)
135
+ writer.write(line + '\n')
136
+
137
+
138
+ if __name__ == "__main__":
139
+ main()
AR/exps/text.txt ADDED
@@ -0,0 +1,10 @@
1
+ 001 Life was like a box of chocolates, you never know what you're gonna get.
2
+ 002 With great power there must come great responsibility.
3
+ 003 To be or not to be, that’s a question.
4
+ 004 A man can be destroyed but not defeated
5
+ 005 Do not, for one repulse, give up the purpose that you resolved to effort.
6
+ 006 Death is just a part of life, something we're all destined to do.
7
+ 007 I think it's hard winning a war with words.
8
+ 008 Don’t argue with the people of strong determination, because they may change the fact!
9
+ 009 Love you three thousand times.
10
+ 010 tidy tiger tied a tie tighter to tidy her tiny tall.
AR/exps/train.py ADDED
@@ -0,0 +1,103 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
2
+ import argparse
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from pytorch_lightning import seed_everything
9
+ from pytorch_lightning import Trainer
10
+ from pytorch_lightning.callbacks import ModelCheckpoint
11
+ from pytorch_lightning.loggers import WandbLogger
12
+ from pytorch_lightning.strategies import DDPStrategy
13
+ from AR.data.data_module import Text2SemanticDataModule
14
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
15
+ from soundstorm.utils.io import load_yaml_config
16
+ logging.getLogger('numba').setLevel(logging.WARNING)
17
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
18
+ torch.set_float32_matmul_precision('high')
19
+ from soundstorm.utils import get_newest_ckpt
20
+
21
+
22
+ def main(args):
23
+ output_dir = Path(args.output_dir)
24
+ output_dir.mkdir(parents=True, exist_ok=True)
25
+
26
+ ckpt_dir = output_dir / 'ckpt'
27
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
28
+
29
+ config = load_yaml_config(args.config_file)
30
+
31
+ seed_everything(config["train"]["seed"], workers=True)
32
+ ckpt_callback: ModelCheckpoint = ModelCheckpoint(
33
+ save_top_k=-1,
34
+ save_on_train_epoch_end=False,
35
+ every_n_epochs=config["train"]["save_every_n_epoch"],
36
+ dirpath=ckpt_dir)
37
+ logger = WandbLogger(
38
+ project="AR_S1",
39
+ name=output_dir.stem,
40
+ save_dir=output_dir,
41
+ # resume the loss curve
42
+ resume=True,
43
+ # id='k19kvsq8'
44
+ )
45
+ trainer: Trainer = Trainer(
46
+ max_epochs=config["train"]["epochs"],
47
+ accelerator='gpu',
48
+ devices=-1,
49
+ benchmark=False,
50
+ fast_dev_run=False,
51
+ strategy=DDPStrategy(find_unused_parameters=True),
52
+ precision=config["train"]["precision"],
53
+ logger=logger,
54
+ callbacks=[ckpt_callback])
55
+
56
+ model: Text2SemanticLightningModule = Text2SemanticLightningModule(
57
+ config, output_dir)
58
+
59
+ data_module: Text2SemanticDataModule = Text2SemanticDataModule(
60
+ config,
61
+ train_semantic_path=args.train_semantic_path,
62
+ train_phoneme_path=args.train_phoneme_path,
63
+ dev_semantic_path=args.dev_semantic_path,
64
+ dev_phoneme_path=args.dev_phoneme_path)
65
+
66
+ try:
67
+ # match the numeric part of the checkpoint filenames with a regex and sort them numerically
68
+ newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
69
+ ckpt_path = ckpt_dir / newest_ckpt_name
70
+ except Exception:
71
+ ckpt_path = None
72
+ print("ckpt_path:", ckpt_path)
73
+ trainer.fit(model, data_module, ckpt_path=ckpt_path)
74
+
75
+
76
+ # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
77
+ if __name__ == '__main__':
78
+ parser = argparse.ArgumentParser()
79
+ parser.add_argument(
80
+ '--config_file',
81
+ type=str,
82
+ default='conf/default.yaml',
83
+ help='path of config file')
84
+ # args for dataset
85
+ parser.add_argument(
86
+ '--train_semantic_path',
87
+ type=str,
88
+ default='dump/train/semantic_token.tsv')
89
+ parser.add_argument(
90
+ '--train_phoneme_path', type=str, default='dump/train/phonemes.npy')
91
+ parser.add_argument(
92
+ '--dev_semantic_path', type=str, default='dump/dev/semantic_token.tsv')
93
+ parser.add_argument(
94
+ '--dev_phoneme_path', type=str, default='dump/dev/phonemes.npy')
95
+ parser.add_argument(
96
+ '--output_dir',
97
+ type=str,
98
+ default='exp/default',
99
+ help='directory to save the results')
100
+
101
+ args = parser.parse_args()
102
+ logging.info(str(args))
103
+ main(args)
AR/exps/train_librilight_6k.py ADDED
@@ -0,0 +1,170 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
2
+ import argparse
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from pytorch_lightning import seed_everything
9
+ from pytorch_lightning import Trainer
10
+ from pytorch_lightning.callbacks import ModelCheckpoint
11
+ from pytorch_lightning.loggers import WandbLogger
12
+ from pytorch_lightning.strategies import DDPStrategy
13
+ from AR.data.data_module_librilight_6k import Text2SemanticDataModule
14
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
15
+ from soundstorm.utils import get_newest_ckpt
16
+ from soundstorm.utils.io import load_yaml_config
17
+
18
+ logging.getLogger('numba').setLevel(logging.WARNING)
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ torch.set_float32_matmul_precision('high')
21
+
22
+
23
+ def main(args):
24
+ output_dir = Path(args.output_dir)
25
+ output_dir.mkdir(parents=True, exist_ok=True)
26
+
27
+ ckpt_dir = output_dir / 'ckpt'
28
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
29
+
30
+ config = load_yaml_config(args.config_file)
31
+
32
+ seed_everything(config["train"]["seed"], workers=True)
33
+
34
+ ckpt_callback: ModelCheckpoint = ModelCheckpoint(
35
+ save_top_k=-1,
36
+ save_on_train_epoch_end=False,
37
+ every_n_train_steps=config["train"]["every_n_train_steps"],
38
+ dirpath=ckpt_dir)
39
+ logger = WandbLogger(
40
+ project="AR_S1_LibriLight",
41
+ name=output_dir.stem,
42
+ save_dir=output_dir,
43
+ # resume the loss curve
44
+ resume=True,
45
+ # id='k19kvsq8'
46
+ )
47
+ trainer: Trainer = Trainer(
48
+ max_epochs=config["train"]["epochs"],
49
+ accelerator='gpu',
50
+ devices=-1,
51
+ benchmark=False,
52
+ fast_dev_run=False,
53
+ strategy=DDPStrategy(find_unused_parameters=True),
54
+ precision=config["train"]["precision"],
55
+ logger=logger,
56
+ callbacks=[ckpt_callback])
57
+
58
+ model: Text2SemanticLightningModule = Text2SemanticLightningModule(
59
+ config, output_dir)
60
+
61
+ data_module: Text2SemanticDataModule = Text2SemanticDataModule(
62
+ config,
63
+ train_semantic_dirs=args.train_semantic_dirs,
64
+ train_phoneme_dirs=args.train_phoneme_dirs,
65
+ dev_semantic_dirs=args.dev_semantic_dirs,
66
+ dev_phoneme_dirs=args.dev_phoneme_dirs,
67
+ train_non_speech_dirs=args.train_non_speech_dirs,
68
+ dev_non_speech_dirs=args.dev_non_speech_dirs)
69
+ try:
70
+ newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
71
+ ckpt_path = ckpt_dir / newest_ckpt_name
72
+ except Exception:
73
+ ckpt_path = None
74
+
75
+ print("ckpt_path:", ckpt_path)
76
+ trainer.fit(model, data_module, ckpt_path=ckpt_path)
77
+
78
+
79
+ # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
80
+ if __name__ == '__main__':
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument(
83
+ '--config_file',
84
+ type=str,
85
+ default='conf/default.yaml',
86
+ help='path of config file')
87
+ # args for dataset
88
+ parser.add_argument(
89
+ '--train_semantic_dirs',
90
+ type=list,
91
+ nargs='+',
92
+ default=["dump/small/train/"],
93
+ help='dirs of train semantic')
94
+ parser.add_argument(
95
+ '--train_phoneme_dirs',
96
+ type=list,
97
+ nargs='+',
98
+ default=["dump/small/train/"],
99
+ help='dirs of train phoneme')
100
+ parser.add_argument(
101
+ '--dev_semantic_dirs',
102
+ type=list,
103
+ nargs='+',
104
+ default=["dump/small/dev/"],
105
+ help='dirs of dev semantic')
106
+ parser.add_argument(
107
+ '--dev_phoneme_dirs',
108
+ type=list,
109
+ nargs='+',
110
+ default=["dump/small/dev/"],
111
+ help='dirs of dev phoneme')
112
+ parser.add_argument(
113
+ '--output_dir',
114
+ type=str,
115
+ default='exp/default',
116
+ help='directory to save the results')
117
+
118
+ parser.add_argument(
119
+ '--train_non_speech_dirs',
120
+ type=list,
121
+ nargs='+',
122
+ default=None,
123
+ help='dirs of train non_speech data')
124
+
125
+ parser.add_argument(
126
+ '--dev_non_speech_dirs',
127
+ type=list,
128
+ nargs='+',
129
+ default=None,
130
+ help='dirs of dev non_speech data')
131
+
132
+ args = parser.parse_args()
133
+
134
+ new_train_semantic_dirs = []
135
+ new_train_phoneme_dirs = []
136
+ new_dev_semantic_dirs = []
137
+ new_dev_phoneme_dirs = []
138
+
139
+ new_train_non_speech_dirs = []
140
+ new_dev_non_speech_dirs = []
141
+
142
+ # format dataset dirs
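+ # argparse with type=list turns each CLI argument into a list of single characters, so ''.join() restores the original path strings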
143
+ for item in args.train_semantic_dirs:
144
+ new_train_semantic_dirs.append(''.join(item))
145
+ args.train_semantic_dirs = new_train_semantic_dirs
146
+
147
+ for item in args.train_phoneme_dirs:
148
+ new_train_phoneme_dirs.append(''.join(item))
149
+ args.train_phoneme_dirs = new_train_phoneme_dirs
150
+
151
+ for item in args.dev_semantic_dirs:
152
+ new_dev_semantic_dirs.append(''.join(item))
153
+ args.dev_semantic_dirs = new_dev_semantic_dirs
154
+
155
+ for item in args.dev_phoneme_dirs:
156
+ new_dev_phoneme_dirs.append(''.join(item))
157
+ args.dev_phoneme_dirs = new_dev_phoneme_dirs
158
+
159
+ if args.train_non_speech_dirs is not None:
160
+ for item in args.train_non_speech_dirs:
161
+ new_train_non_speech_dirs.append(''.join(item))
162
+ args.train_non_speech_dirs = new_train_non_speech_dirs
163
+
164
+ if args.dev_non_speech_dirs is not None:
165
+ for item in args.dev_non_speech_dirs:
166
+ new_dev_non_speech_dirs.append(''.join(item))
167
+ args.dev_non_speech_dirs = new_dev_non_speech_dirs
168
+
169
+ logging.info(str(args))
170
+ main(args)
AR/models/__init__.py ADDED
File without changes
AR/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (142 Bytes). View file
 
AR/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes). View file
 
AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc ADDED
Binary file (3.14 kB). View file
 
AR/models/__pycache__/t2s_lightning_module.cpython-39.pyc ADDED
Binary file (3.15 kB). View file
 
AR/models/__pycache__/t2s_model.cpython-310.pyc ADDED
Binary file (6.84 kB). View file
 
AR/models/__pycache__/t2s_model.cpython-39.pyc ADDED
Binary file (6.84 kB). View file
 
AR/models/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.5 kB). View file
 
AR/models/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.48 kB). View file
 
AR/models/t2s_lightning_module.py ADDED
@@ -0,0 +1,128 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2
+ import os,sys
3
+ now_dir = os.getcwd()
4
+ sys.path.append(now_dir)
5
+ from typing import Dict
6
+
7
+ import torch
8
+ from pytorch_lightning import LightningModule
9
+ from AR.models.t2s_model import Text2SemanticDecoder
10
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
11
+ from AR.modules.optim import ScaledAdam
12
+
13
+
14
+ class Text2SemanticLightningModule(LightningModule):
15
+ def __init__(self, config, output_dir,is_train=True):
16
+ super().__init__()
17
+ self.config = config
18
+ self.top_k = 3
19
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
20
+ pretrained_s1=config.get("pretrained_s1")
21
+ if(pretrained_s1 and is_train):
22
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
23
+ print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
24
+ if is_train:
25
+ self.automatic_optimization = False
26
+ self.save_hyperparameters()
27
+ self.eval_dir = output_dir / 'eval'
28
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
29
+
30
+ def training_step(self, batch: Dict, batch_idx: int):
31
+
32
+ opt = self.optimizers()
33
+ scheduler = self.lr_schedulers()
34
+ loss, acc = self.model.forward(
35
+ batch['phoneme_ids'], batch['phoneme_ids_len'],
36
+ batch['semantic_ids'], batch['semantic_ids_len'],
37
+ batch['bert_feature'])
38
+ self.manual_backward(loss)
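+ # manual optimization: gradients accumulate over 4 batches before each optimizer and scheduler step below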
39
+ if batch_idx > 0 and batch_idx % 4 == 0:
40
+ opt.step()
41
+ opt.zero_grad()
42
+ scheduler.step()
43
+
44
+ self.log(
45
+ "total_loss",
46
+ loss,
47
+ on_step=True,
48
+ on_epoch=True,
49
+ prog_bar=True,
50
+ sync_dist=True)
51
+ self.log(
52
+ "lr",
53
+ scheduler.get_last_lr()[0],
54
+ on_epoch=True,
55
+ prog_bar=True,
56
+ sync_dist=True)
57
+ self.log(
58
+ f"top_{self.top_k}_acc",
59
+ acc,
60
+ on_step=True,
61
+ on_epoch=True,
62
+ prog_bar=True,
63
+ sync_dist=True)
64
+
65
+ def validation_step(self, batch: Dict, batch_idx: int):return
66
+ # # get loss
67
+ # loss, acc = self.model.forward(
68
+ # batch['phoneme_ids'], batch['phoneme_ids_len'],
69
+ # batch['semantic_ids'], batch['semantic_ids_len'],
70
+ # batch['bert_feature']
71
+ # )
72
+ #
73
+ # self.log(
74
+ # "val_total_loss",
75
+ # loss,
76
+ # on_step=True,
77
+ # on_epoch=True,
78
+ # prog_bar=True,
79
+ # sync_dist=True)
80
+ # self.log(
81
+ # f"val_top_{self.top_k}_acc",
82
+ # acc,
83
+ # on_step=True,
84
+ # on_epoch=True,
85
+ # prog_bar=True,
86
+ # sync_dist=True)
87
+ #
88
+ # # get infer output
89
+ # semantic_len = batch['semantic_ids'].size(1)
90
+ # prompt_len = min(int(semantic_len * 0.5), 150)
91
+ # prompt = batch['semantic_ids'][:, :prompt_len]
92
+ # pred_semantic = self.model.infer(batch['phoneme_ids'],
93
+ # batch['phoneme_ids_len'], prompt,
94
+ # batch['bert_feature']
95
+ # )
96
+ # save_name = f'semantic_toks_{batch_idx}.pt'
97
+ # save_path = os.path.join(self.eval_dir, save_name)
98
+ # torch.save(pred_semantic.detach().cpu(), save_path)
99
+
100
+ def configure_optimizers(self):
101
+ model_parameters = self.model.parameters()
102
+ parameters_names = []
103
+ parameters_names.append([
104
+ name_param_pair[0]
105
+ for name_param_pair in self.model.named_parameters()
106
+ ])
107
+ lm_opt = ScaledAdam(
108
+ model_parameters,
109
+ lr=0.01,
110
+ betas=(0.9, 0.95),
111
+ clipping_scale=2.0,
112
+ parameters_names=parameters_names,
113
+ show_dominant_parameters=False,
114
+ clipping_update_period=1000, )
115
+
116
+ return {
117
+ "optimizer": lm_opt,
118
+ "lr_scheduler": {
119
+ "scheduler":
120
+ WarmupCosineLRSchedule(
121
+ lm_opt,
122
+ init_lr=self.config['optimizer']['lr_init'],
123
+ peak_lr=self.config['optimizer']['lr'],
124
+ end_lr=self.config['optimizer']['lr_end'],
125
+ warmup_steps=self.config['optimizer']['warmup_steps'],
126
+ total_steps=self.config['optimizer']['decay_steps'])
127
+ }
128
+ }
AR/models/t2s_model.py ADDED
@@ -0,0 +1,298 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.models.utils import make_pad_mask
6
+ from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
7
+ from AR.modules.embedding import SinePositionalEmbedding
8
+ from AR.modules.embedding import TokenEmbedding
9
+ from AR.modules.transformer import LayerNorm
10
+ from AR.modules.transformer import TransformerEncoder
11
+ from AR.modules.transformer import TransformerEncoderLayer
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from torchmetrics.classification import MulticlassAccuracy
15
+
16
+ default_config = {
17
+ "embedding_dim": 512,
18
+ "hidden_dim": 512,
19
+ "num_head": 8,
20
+ "num_layers": 12,
21
+ "num_codebook": 8,
22
+ "p_dropout": 0.0,
23
+ "vocab_size": 1024 + 1,
24
+ "phoneme_vocab_size": 512,
25
+ "EOS": 1024
26
+ }
27
+
28
+
29
+ class Text2SemanticDecoder(nn.Module):
30
+ def __init__(self, config, norm_first=False, top_k=3):
31
+ super(Text2SemanticDecoder, self).__init__()
32
+ self.model_dim = config['model']["hidden_dim"]
33
+ self.embedding_dim = config['model']["embedding_dim"]
34
+ self.num_head = config['model']["head"]
35
+ self.num_layers = config['model']["n_layer"]
36
+ self.norm_first = norm_first
37
+ self.vocab_size = config['model']["vocab_size"]
38
+ self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
39
+ self.p_dropout = config['model']["dropout"]
40
+ self.EOS = config['model']["EOS"]
41
+ self.norm_first = norm_first
42
+ assert self.EOS == self.vocab_size - 1
43
+ # should be same as num of kmeans bin
44
+ # assert self.EOS == 1024
45
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
46
+ self.ar_text_embedding = TokenEmbedding(
47
+ self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
48
+ self.ar_text_position = SinePositionalEmbedding(
49
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True)
50
+ self.ar_audio_embedding = TokenEmbedding(
51
+ self.embedding_dim, self.vocab_size, self.p_dropout)
52
+ self.ar_audio_position = SinePositionalEmbedding(
53
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True)
54
+
55
+ self.h = TransformerEncoder(
56
+ TransformerEncoderLayer(
57
+ d_model=self.model_dim,
58
+ nhead=self.num_head,
59
+ dim_feedforward=self.model_dim * 4,
60
+ dropout=0.1,
61
+ batch_first=True,
62
+ norm_first=norm_first, ),
63
+ num_layers=self.num_layers,
64
+ norm=LayerNorm(self.model_dim) if norm_first else None, )
65
+
66
+ self.ar_predict_layer = nn.Linear(
67
+ self.model_dim, self.vocab_size, bias=False)
68
+ self.loss_fct = nn.CrossEntropyLoss(reduction='sum')
69
+
70
+ self.ar_accuracy_metric = MulticlassAccuracy(
71
+ self.vocab_size,
72
+ top_k=top_k,
73
+ average="micro",
74
+ multidim_average="global",
75
+ ignore_index=self.EOS, )
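+ # top-k accuracy over semantic tokens; positions whose target is EOS (which includes padding) are ignored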
76
+
77
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
78
+ '''
79
+ x: phoneme_ids
80
+ y: semantic_ids
81
+ '''
82
+ x = self.ar_text_embedding(x)
83
+ x = x + self.bert_proj(bert_feature.transpose(1,2))
84
+ x = self.ar_text_position(x)
85
+ x_mask = make_pad_mask(x_lens)
86
+
87
+ y_mask = make_pad_mask(y_lens)
88
+ y_mask_int = y_mask.type(torch.int64)
89
+ codes = y.type(torch.int64) * (1 - y_mask_int)
90
+
91
+ # Training
92
+ # AR Decoder
93
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
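+ # y is the EOS-padded decoder input, targets is the same sequence shifted left by one (next-token prediction)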
94
+ x_len = x_lens.max()
95
+ y_len = y_lens.max()
96
+ y_emb = self.ar_audio_embedding(y)
97
+ y_pos = self.ar_audio_position(y_emb)
98
+
99
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
100
+ ar_xy_padding_mask = xy_padding_mask
101
+
102
+ x_attn_mask = F.pad(
103
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
104
+ (0, y_len),
105
+ value=True, )
106
+ y_attn_mask = F.pad(
107
+ torch.triu(
108
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
109
+ diagonal=1, ),
110
+ (x_len, 0),
111
+ value=False, )
112
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
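+ # combined mask (True = blocked): text rows attend only to text, audio rows attend to all text and causally to audio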
113
+ bsz, src_len = x.shape[0], x_len + y_len
114
+ _xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
115
+ .expand(-1, self.num_head, -1, -1)
116
+ .reshape(bsz * self.num_head, 1, src_len))
117
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
118
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
119
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
120
+ xy_attn_mask = new_attn_mask
121
+ # x and the full y are fed to the model in a single pass
122
+ xy_pos = torch.concat([x, y_pos], dim=1)
123
+ xy_dec, _ = self.h(
124
+ (xy_pos, None),
125
+ mask=xy_attn_mask, )
126
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
127
+ # loss
128
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction='sum'
129
+ loss = F.cross_entropy(logits, targets, reduction='sum')
130
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
131
+ return loss, acc
132
+
133
+ # need to check how this differs from forward() and what to feed as prompts when no semantic tokens are available
134
+ def infer(self,
135
+ x,
136
+ x_lens,
137
+ prompts,
138
+ bert_feature,
139
+ top_k: int=-100,
140
+ early_stop_num: int=-1,
141
+ temperature: float=1.0):
142
+
143
+ x = self.ar_text_embedding(x)
144
+ x = x + self.bert_proj(bert_feature.transpose(1,2))
145
+ x = self.ar_text_position(x)
146
+
147
+ # AR Decoder
148
+ y = prompts
149
+ prefix_len = y.shape[1]
150
+ x_len = x.shape[1]
151
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
152
+ stop = False
153
+ for _ in tqdm(range(1500)):
154
+ y_emb = self.ar_audio_embedding(y)
155
+ y_pos = self.ar_audio_position(y_emb)
156
+ # x and the progressively growing y are fed to the model together
157
+ xy_pos = torch.concat([x, y_pos], dim=1)
158
+ y_len = y.shape[1]
159
+ x_attn_mask_pad = F.pad(
160
+ x_attn_mask,
161
+ (0, y_len),
162
+ value=True, )
163
+ y_attn_mask = F.pad(
164
+ torch.triu(
165
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
166
+ (x_len, 0),
167
+ value=False, )
168
+ xy_attn_mask = torch.concat(
169
+ [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
170
+
171
+ xy_dec, _ = self.h(
172
+ (xy_pos, None),
173
+ mask=xy_attn_mask, )
174
+ logits = self.ar_predict_layer(xy_dec[:, -1])
175
+ samples = topk_sampling(
176
+ logits, top_k=top_k, top_p=1.0, temperature=temperature)
177
+
178
+ if early_stop_num != -1 and (y.shape[1] - prefix_len
179
+ ) > early_stop_num:
180
+ print("use early stop num:", early_stop_num)
181
+ stop = True
182
+
183
+ if torch.argmax(
184
+ logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
185
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
186
+ stop = True
187
+ if stop:
188
+ if prompts.shape[1] == y.shape[1]:
189
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
190
+ print('bad zero prediction')
191
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
192
+ break
193
+ # the newly generated semantic_ids are appended to the previous y to form the new y
194
+ # print(samples.shape)  # [1, 1]; the first 1 is the batch size
195
+ # import os
196
+ # os._exit(2333)
197
+ y = torch.concat([y, samples], dim=1)
198
+ return y
199
+
200
+ def pad_y_eos(self, y, y_mask_int, eos_id):
201
+ targets = F.pad(
202
+ y, (0, 1), value=0) + eos_id * F.pad(
203
+ y_mask_int, (0, 1), value=1)
204
+ # shifted by one: the first return is the decoder input, the second is the next-token target
205
+ return targets[:, :-1], targets[:, 1:]
206
+
207
+ def infer_panel(self,
208
+ x,  ##### all text tokens
209
+ x_lens,
210
+ prompts,  #### reference audio tokens
211
+ bert_feature,
212
+ top_k: int=-100,
213
+ early_stop_num: int=-1,
214
+ temperature: float=1.0):
215
+
216
+ x = self.ar_text_embedding(x)
217
+ x = x + self.bert_proj(bert_feature.transpose(1,2))
218
+ x = self.ar_text_position(x)
219
+
220
+ # AR Decoder
221
+ y = prompts
222
+ prefix_len = y.shape[1]
223
+ x_len = x.shape[1]
224
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
225
+ stop = False
226
+ # print(1111111,self.num_layers)
227
+ cache={
228
+ "all_stage":self.num_layers,
229
+ "k":[None]*self.num_layers,###根据配置自己手写
230
+ "v":[None]*self.num_layers,
231
+ # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
232
+ "y_emb":None,##只需要对最新的samples求emb,再拼历史的就行
233
+ # "logits":None,###原版就已经只对结尾求再拼接了,不用管
234
+ # "xy_dec":None,###不需要,本来只需要最后一个做logits
235
+ "first_infer":1,
236
+ "stage":0
237
+ }
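+ # on the first pass the full (x, y) sequence is processed; afterwards only the newest token is embedded and the per-layer k/v caches are reused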
238
+ for idx in tqdm(range(1500)):
239
+ if(cache["first_infer"]==1):
240
+ y_emb = self.ar_audio_embedding(y)
241
+ else:
242
+ y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
243
+ cache["y_emb"]=y_emb
244
+ y_pos = self.ar_audio_position(y_emb)
245
+ # x and the progressively growing y are fed to the model together
246
+ if(cache["first_infer"]==1):
247
+ xy_pos = torch.concat([x, y_pos], dim=1)
248
+ else:
249
+ xy_pos=y_pos[:,-1:]
250
+ y_len = y_pos.shape[1]
251
+ ### the following 3 are not cached
252
+ if (cache["first_infer"] == 1):
253
+ x_attn_mask_pad = F.pad(
254
+ x_attn_mask,
255
+ (0, y_len),  ### extend the all-zero (x, x) block with an all-one (x, y) block -> shape (x, x+y)
256
+ value=True, )
257
+ y_attn_mask = F.pad(  ### pad the upper-triangular-1 (y, y) block with an all-zero (y, x) block on the left -> shape (y, x+y)
258
+ torch.triu(
259
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
260
+ (x_len, 0),
261
+ value=False, )
262
+ xy_attn_mask = torch.concat(
263
+ [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
264
+ else:
265
+ ### rightmost column only (this is wrong)
266
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
267
+ # xy_attn_mask[:,-1]=False
268
+ ### bottom row only (this is correct)
269
+ xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
270
+ # pdb.set_trace()
271
+ ### the key part: attention with the kv cache
272
+ # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
273
+ xy_dec, _ = self.h(
274
+ (xy_pos, None),
275
+ mask=xy_attn_mask,cache=cache )
276
+ logits = self.ar_predict_layer(xy_dec[:, -1])  ## no change needed: with the cache enabled there is only one frame anyway, so taking the last frame is equivalent
277
+ # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
278
+ samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
279
+ if early_stop_num != -1 and (y.shape[1] - prefix_len
280
+ ) > early_stop_num:
281
+ print("use early stop num:", early_stop_num)
282
+ stop = True
283
+
284
+ if torch.argmax(
285
+ logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
286
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
287
+ stop = True
288
+ if stop:
289
+ if prompts.shape[1] == y.shape[1]:
290
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
291
+ print('bad zero prediction')
292
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
293
+ break
294
+ # the newly generated semantic_ids are appended to the previous y to form the new y
295
+ # print(samples.shape)  # [1, 1]; the first 1 is the batch size
296
+ y = torch.concat([y, samples], dim=1)
297
+ cache["first_infer"]=0
298
+ return y,idx
AR/models/utils.py ADDED
@@ -0,0 +1,164 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+
6
+
7
+ def sequence_mask(length, max_length=None):
8
+ if max_length is None:
9
+ max_length = length.max()
10
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
11
+ return x.unsqueeze(0) < length.unsqueeze(1)
12
+
13
+
14
+ def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
15
+ """
16
+ Args:
17
+ lengths:
18
+ A 1-D tensor containing sentence lengths.
19
+ max_len:
20
+ The length of masks.
21
+ Returns:
22
+ Return a 2-D bool tensor, where masked positions
23
+ are filled with `True` and non-masked positions are
24
+ filled with `False`.
25
+
26
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
27
+ #>>> make_pad_mask(lengths)
28
+ tensor([[False, True, True, True, True],
29
+ [False, False, False, True, True],
30
+ [False, False, True, True, True],
31
+ [False, False, False, False, False]])
32
+ """
33
+ assert lengths.ndim == 1, lengths.ndim
34
+ max_len = max(max_len, lengths.max())
35
+ n = lengths.size(0)
36
+ seq_range = torch.arange(0, max_len, device=lengths.device)
37
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
38
+
39
+ return expanded_lengths >= lengths.unsqueeze(-1)
40
+
41
+
42
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
43
+ def top_k_top_p_filtering(logits,
44
+ top_k=0,
45
+ top_p=1.0,
46
+ filter_value=-float("Inf"),
47
+ min_tokens_to_keep=1):
48
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
49
+ Args:
50
+ logits: logits distribution shape (batch size, vocabulary size)
51
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
52
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
53
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
54
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
55
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
56
+ """
57
+ if top_k > 0:
58
+ top_k = min(max(top_k, min_tokens_to_keep),
59
+ logits.size(-1)) # Safety check
60
+ # Remove all tokens with a probability less than the last token of the top-k
61
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
62
+ logits[indices_to_remove] = filter_value
63
+
64
+ if top_p < 1.0:
65
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
66
+ cumulative_probs = torch.cumsum(
67
+ F.softmax(sorted_logits, dim=-1), dim=-1)
68
+
69
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
70
+ sorted_indices_to_remove = cumulative_probs > top_p
71
+ if min_tokens_to_keep > 1:
72
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
73
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
74
+ # Shift the indices to the right to keep also the first token above the threshold
75
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
76
+ ..., :-1].clone()
77
+ sorted_indices_to_remove[..., 0] = 0
78
+
79
+ # scatter sorted tensors to original indexing
80
+ indices_to_remove = sorted_indices_to_remove.scatter(
81
+ 1, sorted_indices, sorted_indices_to_remove)
82
+ logits[indices_to_remove] = filter_value
83
+ return logits
84
+
85
+
86
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
87
+ # temperature: (`optional`) float
88
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
89
+ # top_k: (`optional`) int
90
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
91
+ # top_p: (`optional`) float
92
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
93
+
94
+ # Temperature (higher temperature => more likely to sample low probability tokens)
95
+ if temperature != 1.0:
96
+ logits = logits / temperature
97
+ # Top-p/top-k filtering
98
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
99
+ # Sample
100
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
101
+ return token
102
+
103
+
104
+ from typing import Optional, Tuple
105
+ def multinomial_sample_one_no_sync(
106
+ probs_sort,
107
+ ): # Does multinomial sampling without a cuda synchronization
108
+ q = torch.empty_like(probs_sort).exponential_(1)
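+ # argmax of probs / Exponential(1) noise draws from the categorical distribution without forcing a CUDA sync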
109
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
110
+
111
+
112
+ def logits_to_probs(
113
+ logits,
114
+ previous_tokens: Optional[torch.Tensor] = None,
115
+ temperature: float = 1.0,
116
+ top_k: Optional[int] = None,
117
+ top_p: Optional[int] = None,
118
+ repetition_penalty: float = 1.0,
119
+ ):
120
+ previous_tokens=previous_tokens.squeeze()
121
+ # print(logits.shape,previous_tokens.shape)
122
+ # pdb.set_trace()
123
+ if previous_tokens is not None and repetition_penalty != 1.0:
124
+ previous_tokens = previous_tokens.long()
125
+ score = torch.gather(logits, dim=0, index=previous_tokens)
126
+ score = torch.where(
127
+ score < 0, score * repetition_penalty, score / repetition_penalty
128
+ )
129
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
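+ # penalised scores are written back: positive logits are divided by the penalty and negative ones multiplied, discouraging repeats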
130
+
131
+ if top_p is not None and top_p < 1.0:
132
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
133
+ cum_probs = torch.cumsum(
134
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
135
+ )
136
+ sorted_indices_to_remove = cum_probs > top_p
137
+ sorted_indices_to_remove[0] = False # keep at least one option
138
+ indices_to_remove = sorted_indices_to_remove.scatter(
139
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
140
+ )
141
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
142
+
143
+ logits = logits / max(temperature, 1e-5)
144
+
145
+ if top_k is not None:
146
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
147
+ pivot = v.select(-1, -1).unsqueeze(-1)
148
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
149
+
150
+ probs = torch.nn.functional.softmax(logits, dim=-1)
151
+ return probs
152
+
153
+
154
+ def sample(
155
+ logits,
156
+ previous_tokens: Optional[torch.Tensor] = None,
157
+ **sampling_kwargs,
158
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
159
+ probs = logits_to_probs(
160
+ logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
161
+ )
162
+ idx_next = multinomial_sample_one_no_sync(probs)
163
+ return idx_next, probs
164
+
AR/modules/__init__.py ADDED
File without changes
AR/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
AR/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes). View file