Spaces:
Runtime error
Runtime error
Ricecake123
commited on
Commit
•
e79b770
1
Parent(s):
298e47d
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- AR/__init__.py +0 -0
- AR/__pycache__/__init__.cpython-310.pyc +0 -0
- AR/__pycache__/__init__.cpython-39.pyc +0 -0
- AR/data/__init__.py +0 -0
- AR/data/__pycache__/__init__.cpython-310.pyc +0 -0
- AR/data/__pycache__/__init__.cpython-39.pyc +0 -0
- AR/data/__pycache__/bucket_sampler.cpython-310.pyc +0 -0
- AR/data/__pycache__/bucket_sampler.cpython-39.pyc +0 -0
- AR/data/__pycache__/data_module.cpython-310.pyc +0 -0
- AR/data/__pycache__/data_module.cpython-39.pyc +0 -0
- AR/data/__pycache__/dataset.cpython-310.pyc +0 -0
- AR/data/__pycache__/dataset.cpython-39.pyc +0 -0
- AR/data/bucket_sampler.py +157 -0
- AR/data/data_module.py +66 -0
- AR/data/dataset.py +302 -0
- AR/exps/__init__.py +0 -0
- AR/exps/beats/BEATs.py +179 -0
- AR/exps/beats/README.md +127 -0
- AR/exps/beats/Tokenizers.py +172 -0
- AR/exps/beats/__init__.py +2 -0
- AR/exps/beats/backbone.py +791 -0
- AR/exps/beats/config.py +19 -0
- AR/exps/beats/modules.py +220 -0
- AR/exps/beats/ontology.json +0 -0
- AR/exps/beats/quantizer.py +235 -0
- AR/exps/get_beats_librilight.py +321 -0
- AR/exps/get_phones.py +232 -0
- AR/exps/get_phones_librilight.py +198 -0
- AR/exps/get_txt_librilight.py +255 -0
- AR/exps/split_train_val.py +35 -0
- AR/exps/t2s.py +197 -0
- AR/exps/test.py +139 -0
- AR/exps/text.txt +10 -0
- AR/exps/train.py +103 -0
- AR/exps/train_librilight_6k.py +170 -0
- AR/models/__init__.py +0 -0
- AR/models/__pycache__/__init__.cpython-310.pyc +0 -0
- AR/models/__pycache__/__init__.cpython-39.pyc +0 -0
- AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc +0 -0
- AR/models/__pycache__/t2s_lightning_module.cpython-39.pyc +0 -0
- AR/models/__pycache__/t2s_model.cpython-310.pyc +0 -0
- AR/models/__pycache__/t2s_model.cpython-39.pyc +0 -0
- AR/models/__pycache__/utils.cpython-310.pyc +0 -0
- AR/models/__pycache__/utils.cpython-39.pyc +0 -0
- AR/models/t2s_lightning_module.py +128 -0
- AR/models/t2s_model.py +298 -0
- AR/models/utils.py +164 -0
- AR/modules/__init__.py +0 -0
- AR/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- AR/modules/__pycache__/__init__.cpython-39.pyc +0 -0
AR/__init__.py
ADDED
File without changes
|
AR/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (135 Bytes). View file
|
|
AR/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (138 Bytes). View file
|
|
AR/data/__init__.py
ADDED
File without changes
|
AR/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (140 Bytes). View file
|
|
AR/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (143 Bytes). View file
|
|
AR/data/__pycache__/bucket_sampler.cpython-310.pyc
ADDED
Binary file (4.42 kB). View file
|
|
AR/data/__pycache__/bucket_sampler.cpython-39.pyc
ADDED
Binary file (4.39 kB). View file
|
|
AR/data/__pycache__/data_module.cpython-310.pyc
ADDED
Binary file (2.27 kB). View file
|
|
AR/data/__pycache__/data_module.cpython-39.pyc
ADDED
Binary file (2.29 kB). View file
|
|
AR/data/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (6.58 kB). View file
|
|
AR/data/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (6.57 kB). View file
|
|
AR/data/bucket_sampler.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
|
2 |
+
import itertools
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from random import shuffle
|
6 |
+
from typing import Iterator
|
7 |
+
from typing import Optional
|
8 |
+
from typing import TypeVar
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.distributed as dist
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
from torch.utils.data import Sampler
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
"DistributedBucketSampler",
|
17 |
+
]
|
18 |
+
|
19 |
+
T_co = TypeVar('T_co', covariant=True)
|
20 |
+
|
21 |
+
|
22 |
+
class DistributedBucketSampler(Sampler[T_co]):
|
23 |
+
r"""
|
24 |
+
sort the dataset wrt. input length
|
25 |
+
divide samples into buckets
|
26 |
+
sort within buckets
|
27 |
+
divide buckets into batches
|
28 |
+
sort batches
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
dataset: Dataset,
|
33 |
+
num_replicas: Optional[int]=None,
|
34 |
+
rank: Optional[int]=None,
|
35 |
+
shuffle: bool=True,
|
36 |
+
seed: int=0,
|
37 |
+
drop_last: bool=False,
|
38 |
+
batch_size: int=32) -> None:
|
39 |
+
if num_replicas is None:
|
40 |
+
if not dist.is_available():
|
41 |
+
raise RuntimeError(
|
42 |
+
"Requires distributed package to be available")
|
43 |
+
num_replicas = dist.get_world_size()
|
44 |
+
if rank is None:
|
45 |
+
if not dist.is_available():
|
46 |
+
raise RuntimeError(
|
47 |
+
"Requires distributed package to be available")
|
48 |
+
rank = dist.get_rank()
|
49 |
+
torch.cuda.set_device(rank)
|
50 |
+
if rank >= num_replicas or rank < 0:
|
51 |
+
raise ValueError("Invalid rank {}, rank should be in the interval"
|
52 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
53 |
+
self.dataset = dataset
|
54 |
+
self.num_replicas = num_replicas
|
55 |
+
self.rank = rank
|
56 |
+
self.epoch = 0
|
57 |
+
self.drop_last = drop_last
|
58 |
+
# If the dataset length is evenly divisible by # of replicas, then there
|
59 |
+
# is no need to drop any data, since the dataset will be split equally.
|
60 |
+
if self.drop_last and len(
|
61 |
+
self.
|
62 |
+
dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
63 |
+
# Split to nearest available length that is evenly divisible.
|
64 |
+
# This is to ensure each rank receives the same amount of data when
|
65 |
+
# using this Sampler.
|
66 |
+
self.num_samples = math.ceil(
|
67 |
+
(len(self.dataset) - self.num_replicas) /
|
68 |
+
self.num_replicas # type: ignore[arg-type]
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
self.num_samples = math.ceil(
|
72 |
+
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
73 |
+
self.total_size = self.num_samples * self.num_replicas
|
74 |
+
self.shuffle = shuffle
|
75 |
+
self.seed = seed
|
76 |
+
self.batch_size = batch_size
|
77 |
+
self.id_with_length = self._get_sample_lengths()
|
78 |
+
self.id_buckets = self.make_buckets(bucket_width=2.0)
|
79 |
+
|
80 |
+
def _get_sample_lengths(self):
|
81 |
+
id_with_lengths = []
|
82 |
+
for i in range(len(self.dataset)):
|
83 |
+
id_with_lengths.append((i, self.dataset.get_sample_length(i)))
|
84 |
+
id_with_lengths.sort(key=lambda x: x[1])
|
85 |
+
return id_with_lengths
|
86 |
+
|
87 |
+
def make_buckets(self, bucket_width: float=2.0):
|
88 |
+
buckets = []
|
89 |
+
cur = []
|
90 |
+
max_sec = bucket_width
|
91 |
+
for id, sec in self.id_with_length:
|
92 |
+
if sec < max_sec:
|
93 |
+
cur.append(id)
|
94 |
+
else:
|
95 |
+
buckets.append(cur)
|
96 |
+
cur = [id]
|
97 |
+
max_sec += bucket_width
|
98 |
+
if len(cur) > 0:
|
99 |
+
buckets.append(cur)
|
100 |
+
return buckets
|
101 |
+
|
102 |
+
def __iter__(self) -> Iterator[T_co]:
|
103 |
+
if self.shuffle:
|
104 |
+
# deterministically shuffle based on epoch and seed
|
105 |
+
g = torch.Generator()
|
106 |
+
g.manual_seed(self.seed + self.epoch)
|
107 |
+
random.seed(self.epoch + self.seed)
|
108 |
+
shuffled_bucket = []
|
109 |
+
for buc in self.id_buckets:
|
110 |
+
buc_copy = buc.copy()
|
111 |
+
shuffle(buc_copy)
|
112 |
+
shuffled_bucket.append(buc_copy)
|
113 |
+
grouped_batch_size = self.batch_size * self.num_replicas
|
114 |
+
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
115 |
+
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
116 |
+
batches = [
|
117 |
+
shuffled_bucket[b * grouped_batch_size:(b + 1) *
|
118 |
+
grouped_batch_size] for b in range(n_batch)
|
119 |
+
]
|
120 |
+
shuffle(batches)
|
121 |
+
indices = list(itertools.chain(*batches))
|
122 |
+
else:
|
123 |
+
# type: ignore[arg-type]
|
124 |
+
indices = list(range(len(self.dataset)))
|
125 |
+
|
126 |
+
if not self.drop_last:
|
127 |
+
# add extra samples to make it evenly divisible
|
128 |
+
padding_size = self.total_size - len(indices)
|
129 |
+
if padding_size <= len(indices):
|
130 |
+
indices += indices[:padding_size]
|
131 |
+
else:
|
132 |
+
indices += (indices * math.ceil(padding_size /
|
133 |
+
len(indices)))[:padding_size]
|
134 |
+
else:
|
135 |
+
# remove tail of data to make it evenly divisible.
|
136 |
+
indices = indices[:self.total_size]
|
137 |
+
assert len(indices) == self.total_size
|
138 |
+
|
139 |
+
# subsample
|
140 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
141 |
+
assert len(indices) == self.num_samples
|
142 |
+
|
143 |
+
return iter(indices)
|
144 |
+
|
145 |
+
def __len__(self) -> int:
|
146 |
+
return self.num_samples
|
147 |
+
|
148 |
+
def set_epoch(self, epoch: int) -> None:
|
149 |
+
r"""
|
150 |
+
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
151 |
+
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
152 |
+
sampler will yield the same ordering.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
epoch (int): Epoch number.
|
156 |
+
"""
|
157 |
+
self.epoch = epoch
|
AR/data/data_module.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
|
2 |
+
from pytorch_lightning import LightningDataModule
|
3 |
+
from AR.data.bucket_sampler import DistributedBucketSampler
|
4 |
+
from AR.data.dataset import Text2SemanticDataset
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
|
8 |
+
class Text2SemanticDataModule(LightningDataModule):
|
9 |
+
def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
|
10 |
+
super().__init__()
|
11 |
+
self.config = config
|
12 |
+
self.train_semantic_path = train_semantic_path
|
13 |
+
self.train_phoneme_path = train_phoneme_path
|
14 |
+
self.dev_semantic_path = dev_semantic_path
|
15 |
+
self.dev_phoneme_path = dev_phoneme_path
|
16 |
+
self.num_workers = self.config['data']['num_workers']
|
17 |
+
|
18 |
+
def prepare_data(self):
|
19 |
+
pass
|
20 |
+
|
21 |
+
def setup(self, stage=None, output_logs=False):
|
22 |
+
self._train_dataset = Text2SemanticDataset(
|
23 |
+
phoneme_path=self.train_phoneme_path,
|
24 |
+
semantic_path=self.train_semantic_path,
|
25 |
+
max_sec=self.config['data']['max_sec'],
|
26 |
+
pad_val=self.config['data']['pad_val'])
|
27 |
+
self._dev_dataset = self._train_dataset
|
28 |
+
# self._dev_dataset = Text2SemanticDataset(
|
29 |
+
# phoneme_path=self.dev_phoneme_path,
|
30 |
+
# semantic_path=self.dev_semantic_path,
|
31 |
+
# max_sample=self.config['data']['max_eval_sample'],
|
32 |
+
# max_sec=self.config['data']['max_sec'],
|
33 |
+
# pad_val=self.config['data']['pad_val'])
|
34 |
+
|
35 |
+
def train_dataloader(self):
|
36 |
+
batch_size = self.config['train']['batch_size']
|
37 |
+
sampler = DistributedBucketSampler(
|
38 |
+
self._train_dataset, batch_size=batch_size)
|
39 |
+
return DataLoader(
|
40 |
+
self._train_dataset,
|
41 |
+
batch_size=batch_size,
|
42 |
+
sampler=sampler,
|
43 |
+
collate_fn=self._train_dataset.collate,
|
44 |
+
num_workers=self.num_workers,
|
45 |
+
persistent_workers=True,
|
46 |
+
prefetch_factor=16
|
47 |
+
)
|
48 |
+
|
49 |
+
def val_dataloader(self):
|
50 |
+
return DataLoader(
|
51 |
+
self._dev_dataset,
|
52 |
+
batch_size=1,
|
53 |
+
shuffle=False,
|
54 |
+
collate_fn=self._train_dataset.collate,
|
55 |
+
num_workers=max(self.num_workers,12),
|
56 |
+
persistent_workers=True,
|
57 |
+
prefetch_factor=16
|
58 |
+
)
|
59 |
+
|
60 |
+
# 这个会使用到嘛?
|
61 |
+
def test_dataloader(self):
|
62 |
+
return DataLoader(
|
63 |
+
self._dev_dataset,
|
64 |
+
batch_size=1,
|
65 |
+
shuffle=False,
|
66 |
+
collate_fn=self._train_dataset.collate)
|
AR/data/dataset.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
|
2 |
+
import pdb
|
3 |
+
import sys
|
4 |
+
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
5 |
+
import traceback,os
|
6 |
+
from typing import Dict
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import torch,json
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
|
16 |
+
from text import cleaned_text_to_sequence
|
17 |
+
# from config import exp_dir
|
18 |
+
|
19 |
+
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
|
20 |
+
seq = sequences[0]
|
21 |
+
ndim = seq.ndim
|
22 |
+
if axis < 0:
|
23 |
+
axis += ndim
|
24 |
+
dtype = seq.dtype
|
25 |
+
pad_value = dtype.type(pad_value)
|
26 |
+
seq_lengths = [seq.shape[axis] for seq in sequences]
|
27 |
+
max_length = np.max(seq_lengths)
|
28 |
+
|
29 |
+
padded_sequences = []
|
30 |
+
for seq, length in zip(sequences, seq_lengths):
|
31 |
+
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
|
32 |
+
ndim - axis - 1)
|
33 |
+
padded_seq = np.pad(
|
34 |
+
seq, padding, mode='constant', constant_values=pad_value)
|
35 |
+
padded_sequences.append(padded_seq)
|
36 |
+
batch = np.stack(padded_sequences)
|
37 |
+
return batch
|
38 |
+
|
39 |
+
class Text2SemanticDataset(Dataset):
|
40 |
+
"""dataset class for text tokens to semantic model training."""
|
41 |
+
|
42 |
+
def __init__(self,
|
43 |
+
phoneme_path: str,
|
44 |
+
semantic_path: str,
|
45 |
+
max_sample: int = None,
|
46 |
+
max_sec: int = 100,
|
47 |
+
pad_val: int = 1024,
|
48 |
+
# min value of phoneme/sec
|
49 |
+
min_ps_ratio: int = 3,
|
50 |
+
# max value of phoneme/sec
|
51 |
+
max_ps_ratio: int = 25) -> None:
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
|
55 |
+
# get dict
|
56 |
+
self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
|
57 |
+
self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
|
58 |
+
self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
|
59 |
+
assert os.path.exists(self.path2)
|
60 |
+
assert os.path.exists(self.path6)
|
61 |
+
self.phoneme_data={}
|
62 |
+
with open(self.path2,"r",encoding="utf8")as f:
|
63 |
+
lines=f.read().strip("\n").split("\n")
|
64 |
+
|
65 |
+
for line in lines:
|
66 |
+
tmp=line.split("\t")
|
67 |
+
if(len(tmp)!=4):continue
|
68 |
+
self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
|
69 |
+
|
70 |
+
# self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
|
71 |
+
# pad for semantic tokens
|
72 |
+
self.PAD: int = pad_val
|
73 |
+
# self.hz = 25
|
74 |
+
# with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
|
75 |
+
# data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
|
76 |
+
# self.hz=int(data[:-2])#
|
77 |
+
self.hz=int(os.environ.get("hz","25hz")[:-2])
|
78 |
+
|
79 |
+
# max seconds of semantic token
|
80 |
+
self.max_sec = max_sec
|
81 |
+
self.min_ps_ratio = min_ps_ratio
|
82 |
+
self.max_ps_ratio = max_ps_ratio
|
83 |
+
|
84 |
+
if max_sample is not None:
|
85 |
+
self.semantic_data = self.semantic_data[:max_sample]
|
86 |
+
|
87 |
+
# {idx: (semantic, phoneme)}
|
88 |
+
# semantic list, phoneme list
|
89 |
+
self.semantic_phoneme = []
|
90 |
+
self.item_names = []
|
91 |
+
|
92 |
+
self.inited = False
|
93 |
+
|
94 |
+
if not self.inited:
|
95 |
+
# 调用初始化函数
|
96 |
+
self.init_batch()
|
97 |
+
self.inited = True
|
98 |
+
del self.semantic_data
|
99 |
+
del self.phoneme_data
|
100 |
+
# self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
|
101 |
+
# self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
|
102 |
+
|
103 |
+
|
104 |
+
def init_batch(self):
|
105 |
+
semantic_data_len = len(self.semantic_data)
|
106 |
+
phoneme_data_len = len(self.phoneme_data.keys())
|
107 |
+
print("semantic_data_len:", semantic_data_len)
|
108 |
+
print("phoneme_data_len:", phoneme_data_len)
|
109 |
+
idx = 0
|
110 |
+
num_not_in = 0
|
111 |
+
num_deleted_bigger = 0
|
112 |
+
num_deleted_ps = 0
|
113 |
+
for i in range(semantic_data_len):
|
114 |
+
# 先依次遍历
|
115 |
+
# get str
|
116 |
+
item_name = self.semantic_data['item_name'][i]
|
117 |
+
# print(self.phoneme_data)
|
118 |
+
try:
|
119 |
+
phoneme, word2ph, text = self.phoneme_data[item_name]
|
120 |
+
except Exception:
|
121 |
+
traceback.print_exc()
|
122 |
+
# print(f"{item_name} not in self.phoneme_data !")
|
123 |
+
num_not_in += 1
|
124 |
+
continue
|
125 |
+
|
126 |
+
semantic_str = self.semantic_data['semantic_audio'][i]
|
127 |
+
# get token list
|
128 |
+
semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
|
129 |
+
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
|
130 |
+
# 过滤掉太长的样本
|
131 |
+
if len(semantic_ids) > self.max_sec * self.hz:#########1###根据token���数推测总时长过滤时长60s(config里)#40*25=1k
|
132 |
+
num_deleted_bigger += 1
|
133 |
+
continue
|
134 |
+
# (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
|
135 |
+
phoneme = phoneme.split(' ')
|
136 |
+
|
137 |
+
try:
|
138 |
+
phoneme_ids = cleaned_text_to_sequence(phoneme)
|
139 |
+
except:
|
140 |
+
traceback.print_exc()
|
141 |
+
# print(f"{item_name} not in self.phoneme_data !")
|
142 |
+
num_not_in += 1
|
143 |
+
continue
|
144 |
+
# if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
|
145 |
+
if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2:改为恒定限制为semantic/2.5就行
|
146 |
+
num_deleted_ps += 1
|
147 |
+
continue
|
148 |
+
# if len(semantic_ids) > 1000:###########3
|
149 |
+
# num_deleted_bigger += 1
|
150 |
+
# continue
|
151 |
+
|
152 |
+
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
|
153 |
+
|
154 |
+
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4#3~25#每秒多少个phone
|
155 |
+
num_deleted_ps += 1
|
156 |
+
# print(item_name)
|
157 |
+
continue
|
158 |
+
|
159 |
+
self.semantic_phoneme.append((semantic_ids, phoneme_ids))
|
160 |
+
idx += 1
|
161 |
+
self.item_names.append(item_name)
|
162 |
+
|
163 |
+
min_num=100#20直接不补#30补了也不存ckpt
|
164 |
+
leng =len(self.semantic_phoneme)
|
165 |
+
if(leng<min_num):
|
166 |
+
tmp1=self.semantic_phoneme
|
167 |
+
tmp2=self.item_names
|
168 |
+
self.semantic_phoneme=[]
|
169 |
+
self.item_names=[]
|
170 |
+
for _ in range(max(2,int(min_num/leng))):
|
171 |
+
self.semantic_phoneme+=tmp1
|
172 |
+
self.item_names+=tmp2
|
173 |
+
if num_not_in > 0:
|
174 |
+
print(f"there are {num_not_in} semantic datas not in phoneme datas")
|
175 |
+
if num_deleted_bigger > 0:
|
176 |
+
print(
|
177 |
+
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
|
178 |
+
)
|
179 |
+
if num_deleted_ps > 0:
|
180 |
+
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
|
181 |
+
print(
|
182 |
+
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
|
183 |
+
)
|
184 |
+
'''
|
185 |
+
there are 31 semantic datas not in phoneme datas
|
186 |
+
deleted 34 audios who's duration are bigger than 54 seconds
|
187 |
+
deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
|
188 |
+
dataset.__len__(): 366463
|
189 |
+
|
190 |
+
'''
|
191 |
+
# 345410 for LibriTTS
|
192 |
+
print("dataset.__len__():", self.__len__())
|
193 |
+
|
194 |
+
def __get_item_names__(self) -> List[str]:
|
195 |
+
return self.item_names
|
196 |
+
|
197 |
+
def __len__(self) -> int:
|
198 |
+
return len(self.semantic_phoneme)
|
199 |
+
|
200 |
+
def __getitem__(self, idx: int) -> Dict:
|
201 |
+
semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
|
202 |
+
item_name = self.item_names[idx]
|
203 |
+
phoneme_ids_len = len(phoneme_ids)
|
204 |
+
# semantic tokens target
|
205 |
+
semantic_ids_len = len(semantic_ids)
|
206 |
+
|
207 |
+
flag=0
|
208 |
+
path_bert = "%s/%s.pt" % (self.path3, item_name)
|
209 |
+
if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
|
210 |
+
else:flag=1
|
211 |
+
if(flag==1):
|
212 |
+
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
|
213 |
+
bert_feature=None
|
214 |
+
else:
|
215 |
+
assert bert_feature.shape[-1] == len(phoneme_ids)
|
216 |
+
return {
|
217 |
+
'idx': idx,
|
218 |
+
'phoneme_ids': phoneme_ids,
|
219 |
+
'phoneme_ids_len': phoneme_ids_len,
|
220 |
+
'semantic_ids': semantic_ids,
|
221 |
+
'semantic_ids_len': semantic_ids_len,
|
222 |
+
'bert_feature': bert_feature,
|
223 |
+
}
|
224 |
+
|
225 |
+
def get_sample_length(self, idx: int):
|
226 |
+
semantic_ids = self.semantic_phoneme[idx][0]
|
227 |
+
sec = 1.0 * len(semantic_ids) / self.hz
|
228 |
+
return sec
|
229 |
+
|
230 |
+
def collate(self, examples: List[Dict]) -> Dict:
|
231 |
+
sample_index: List[int] = []
|
232 |
+
phoneme_ids: List[torch.Tensor] = []
|
233 |
+
phoneme_ids_lens: List[int] = []
|
234 |
+
semantic_ids: List[torch.Tensor] = []
|
235 |
+
semantic_ids_lens: List[int] = []
|
236 |
+
# return
|
237 |
+
|
238 |
+
|
239 |
+
for item in examples:
|
240 |
+
sample_index.append(item["idx"])
|
241 |
+
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
|
242 |
+
semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
|
243 |
+
phoneme_ids_lens.append(item["phoneme_ids_len"])
|
244 |
+
semantic_ids_lens.append(item["semantic_ids_len"])
|
245 |
+
|
246 |
+
# pad 0
|
247 |
+
phoneme_ids = batch_sequences(phoneme_ids)
|
248 |
+
semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
|
249 |
+
|
250 |
+
# # convert each batch to torch.tensor
|
251 |
+
phoneme_ids = torch.tensor(phoneme_ids)
|
252 |
+
semantic_ids = torch.tensor(semantic_ids)
|
253 |
+
phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
|
254 |
+
semantic_ids_lens = torch.tensor(semantic_ids_lens)
|
255 |
+
bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
|
256 |
+
bert_padded.zero_()
|
257 |
+
|
258 |
+
for idx, item in enumerate(examples):
|
259 |
+
bert = item['bert_feature']
|
260 |
+
if(bert!=None):
|
261 |
+
bert_padded[idx, :, :bert.shape[-1]] = bert
|
262 |
+
|
263 |
+
return {
|
264 |
+
# List[int]
|
265 |
+
"ids": sample_index,
|
266 |
+
# torch.Tensor (B, max_phoneme_length)
|
267 |
+
"phoneme_ids": phoneme_ids,
|
268 |
+
# torch.Tensor (B)
|
269 |
+
"phoneme_ids_len": phoneme_ids_lens,
|
270 |
+
# torch.Tensor (B, max_semantic_ids_length)
|
271 |
+
"semantic_ids": semantic_ids,
|
272 |
+
# torch.Tensor (B)
|
273 |
+
"semantic_ids_len": semantic_ids_lens,
|
274 |
+
# torch.Tensor (B, 1024, max_phoneme_length)
|
275 |
+
"bert_feature": bert_padded,
|
276 |
+
}
|
277 |
+
|
278 |
+
|
279 |
+
if __name__ == '__main__':
|
280 |
+
root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
|
281 |
+
dataset = Text2SemanticDataset(
|
282 |
+
phoneme_path=root_dir + 'phoneme_train.npy',
|
283 |
+
semantic_path=root_dir + 'semantic_train.tsv')
|
284 |
+
|
285 |
+
batch_size = 12
|
286 |
+
dataloader = DataLoader(
|
287 |
+
dataset,
|
288 |
+
batch_size=batch_size,
|
289 |
+
collate_fn=dataset.collate,
|
290 |
+
shuffle=False)
|
291 |
+
for i, batch in enumerate(dataloader):
|
292 |
+
if(i%1000==0):print(i)
|
293 |
+
# if i == 0:
|
294 |
+
# print('batch["ids"]:', batch["ids"])
|
295 |
+
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
|
296 |
+
# batch["phoneme_ids"].shape)
|
297 |
+
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
|
298 |
+
# batch["phoneme_ids_len"].shape)
|
299 |
+
# print('batch["semantic_ids"]:', batch["semantic_ids"],
|
300 |
+
# batch["semantic_ids"].shape)
|
301 |
+
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
|
302 |
+
# batch["semantic_ids_len"].shape)
|
AR/exps/__init__.py
ADDED
File without changes
|
AR/exps/beats/BEATs.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import logging
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
15 |
+
from torch.nn import LayerNorm
|
16 |
+
|
17 |
+
from .backbone import TransformerEncoder
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class BEATsConfig:
|
23 |
+
def __init__(self, cfg=None):
|
24 |
+
self.input_patch_size: int = -1 # path size of patch embedding
|
25 |
+
self.embed_dim: int = 512 # patch embedding dimension
|
26 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
27 |
+
|
28 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
29 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
30 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
31 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
32 |
+
self.activation_fn: str = "gelu" # activation function to use
|
33 |
+
|
34 |
+
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
|
35 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
36 |
+
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
37 |
+
|
38 |
+
# dropouts
|
39 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
40 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
41 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
42 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
43 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
44 |
+
|
45 |
+
# positional embeddings
|
46 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
47 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
48 |
+
|
49 |
+
# relative position embedding
|
50 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
51 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
52 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
53 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
54 |
+
|
55 |
+
# label predictor
|
56 |
+
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
|
57 |
+
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
|
58 |
+
self.predictor_class: int = 527 # target class number for the predictor
|
59 |
+
|
60 |
+
if cfg is not None:
|
61 |
+
self.update(cfg)
|
62 |
+
|
63 |
+
def update(self, cfg: dict):
|
64 |
+
self.__dict__.update(cfg)
|
65 |
+
|
66 |
+
|
67 |
+
class BEATs(nn.Module):
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
cfg: BEATsConfig, ) -> None:
|
71 |
+
super().__init__()
|
72 |
+
logger.info(f"BEATs Config: {cfg.__dict__}")
|
73 |
+
|
74 |
+
self.cfg = cfg
|
75 |
+
|
76 |
+
self.embed = cfg.embed_dim
|
77 |
+
self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
|
78 |
+
if self.embed != cfg.encoder_embed_dim else
|
79 |
+
None)
|
80 |
+
|
81 |
+
self.input_patch_size = cfg.input_patch_size
|
82 |
+
self.patch_embedding = nn.Conv2d(
|
83 |
+
1,
|
84 |
+
self.embed,
|
85 |
+
kernel_size=self.input_patch_size,
|
86 |
+
stride=self.input_patch_size,
|
87 |
+
bias=cfg.conv_bias)
|
88 |
+
|
89 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
90 |
+
|
91 |
+
assert not cfg.deep_norm or not cfg.layer_norm_first
|
92 |
+
self.encoder = TransformerEncoder(cfg)
|
93 |
+
self.layer_norm = LayerNorm(self.embed)
|
94 |
+
|
95 |
+
if cfg.finetuned_model:
|
96 |
+
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
|
97 |
+
self.predictor = nn.Linear(cfg.encoder_embed_dim,
|
98 |
+
cfg.predictor_class)
|
99 |
+
else:
|
100 |
+
self.predictor = None
|
101 |
+
|
102 |
+
def forward_padding_mask(
|
103 |
+
self,
|
104 |
+
features: torch.Tensor,
|
105 |
+
padding_mask: torch.Tensor, ) -> torch.Tensor:
|
106 |
+
extra = padding_mask.size(1) % features.size(1)
|
107 |
+
if extra > 0:
|
108 |
+
padding_mask = padding_mask[:, :-extra]
|
109 |
+
padding_mask = padding_mask.view(
|
110 |
+
padding_mask.size(0), features.size(1), -1)
|
111 |
+
padding_mask = padding_mask.all(-1)
|
112 |
+
return padding_mask
|
113 |
+
|
114 |
+
def preprocess(
|
115 |
+
self,
|
116 |
+
source: torch.Tensor,
|
117 |
+
fbank_mean: float=15.41663,
|
118 |
+
fbank_std: float=6.55582, ) -> torch.Tensor:
|
119 |
+
fbanks = []
|
120 |
+
for waveform in source:
|
121 |
+
waveform = waveform.unsqueeze(0) * 2**15
|
122 |
+
fbank = ta_kaldi.fbank(
|
123 |
+
waveform,
|
124 |
+
num_mel_bins=128,
|
125 |
+
sample_frequency=16000,
|
126 |
+
frame_length=25,
|
127 |
+
frame_shift=10)
|
128 |
+
fbanks.append(fbank)
|
129 |
+
fbank = torch.stack(fbanks, dim=0)
|
130 |
+
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
131 |
+
return fbank
|
132 |
+
|
133 |
+
def extract_features(
|
134 |
+
self,
|
135 |
+
source: torch.Tensor,
|
136 |
+
padding_mask: Optional[torch.Tensor]=None,
|
137 |
+
fbank_mean: float=15.41663,
|
138 |
+
fbank_std: float=6.55582, ):
|
139 |
+
fbank = self.preprocess(
|
140 |
+
source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
141 |
+
|
142 |
+
if padding_mask is not None:
|
143 |
+
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
144 |
+
|
145 |
+
fbank = fbank.unsqueeze(1)
|
146 |
+
features = self.patch_embedding(fbank)
|
147 |
+
features = features.reshape(features.shape[0], features.shape[1], -1)
|
148 |
+
features = features.transpose(1, 2)
|
149 |
+
features = self.layer_norm(features)
|
150 |
+
|
151 |
+
if padding_mask is not None:
|
152 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
153 |
+
|
154 |
+
if self.post_extract_proj is not None:
|
155 |
+
features = self.post_extract_proj(features)
|
156 |
+
|
157 |
+
x = self.dropout_input(features)
|
158 |
+
|
159 |
+
x, layer_results = self.encoder(
|
160 |
+
x,
|
161 |
+
padding_mask=padding_mask, )
|
162 |
+
|
163 |
+
if self.predictor is not None:
|
164 |
+
x = self.predictor_dropout(x)
|
165 |
+
logits = self.predictor(x)
|
166 |
+
|
167 |
+
if padding_mask is not None and padding_mask.any():
|
168 |
+
logits[padding_mask] = 0
|
169 |
+
logits = logits.sum(dim=1)
|
170 |
+
logits = logits / (~padding_mask).sum(
|
171 |
+
dim=1).unsqueeze(-1).expand_as(logits)
|
172 |
+
else:
|
173 |
+
logits = logits.mean(dim=1)
|
174 |
+
|
175 |
+
lprobs = torch.sigmoid(logits)
|
176 |
+
|
177 |
+
return lprobs, padding_mask
|
178 |
+
else:
|
179 |
+
return x, padding_mask
|
AR/exps/beats/README.md
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# BEATs
|
3 |
+
|
4 |
+
[**BEATs**](https://arxiv.org/abs/2212.09058): **Audio Pre-Training with Acoustic Tokenizers**
|
5 |
+
|
6 |
+
Official PyTorch implementation and pretrained models of BEATs
|
7 |
+
|
8 |
+
## Pre-Trained and Fine-Tuned Tokenizers and Models
|
9 |
+
Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2
|
10 |
+
|---|---|---|---|---
|
11 |
+
Iter1 | Random Projection | [BEATs_iter1](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
12 |
+
Iter2 | [Tokenizer_iter2](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter2](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
13 |
+
Iter3 | [Tokenizer_iter3](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
14 |
+
Iter3+ | [Tokenizer_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
15 |
+
Iter3+ | [Tokenizer_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
16 |
+
|
17 |
+
|
18 |
+
### Load Tokenizers
|
19 |
+
|
20 |
+
```python
|
21 |
+
import torch
|
22 |
+
from Tokenizers import TokenizersConfig, Tokenizers
|
23 |
+
|
24 |
+
# load the pre-trained checkpoints
|
25 |
+
checkpoint = torch.load('/path/to/tokenizer.pt')
|
26 |
+
|
27 |
+
cfg = TokenizersConfig(checkpoint['cfg'])
|
28 |
+
BEATs_tokenizer = Tokenizers(cfg)
|
29 |
+
BEATs_tokenizer.load_state_dict(checkpoint['model'])
|
30 |
+
BEATs_tokenizer.eval()
|
31 |
+
|
32 |
+
# tokenize the audio and generate the labels
|
33 |
+
audio_input_16khz = torch.randn(1, 10000)
|
34 |
+
padding_mask = torch.zeros(1, 10000).bool()
|
35 |
+
|
36 |
+
labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
|
37 |
+
```
|
38 |
+
|
39 |
+
|
40 |
+
### Load Pre-Trained Models
|
41 |
+
|
42 |
+
```python
|
43 |
+
import torch
|
44 |
+
from BEATs import BEATs, BEATsConfig
|
45 |
+
|
46 |
+
# load the pre-trained checkpoints
|
47 |
+
checkpoint = torch.load('/path/to/model.pt')
|
48 |
+
|
49 |
+
cfg = BEATsConfig(checkpoint['cfg'])
|
50 |
+
BEATs_model = BEATs(cfg)
|
51 |
+
BEATs_model.load_state_dict(checkpoint['model'])
|
52 |
+
BEATs_model.eval()
|
53 |
+
|
54 |
+
# extract the the audio representation
|
55 |
+
audio_input_16khz = torch.randn(1, 10000)
|
56 |
+
padding_mask = torch.zeros(1, 10000).bool()
|
57 |
+
|
58 |
+
representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
|
59 |
+
```
|
60 |
+
|
61 |
+
|
62 |
+
### Load Fine-tuned Models
|
63 |
+
|
64 |
+
```python
|
65 |
+
import torch
|
66 |
+
from BEATs import BEATs, BEATsConfig
|
67 |
+
|
68 |
+
# load the fine-tuned checkpoints
|
69 |
+
checkpoint = torch.load('/path/to/model.pt')
|
70 |
+
|
71 |
+
cfg = BEATsConfig(checkpoint['cfg'])
|
72 |
+
BEATs_model = BEATs(cfg)
|
73 |
+
BEATs_model.load_state_dict(checkpoint['model'])
|
74 |
+
BEATs_model.eval()
|
75 |
+
|
76 |
+
# predict the classification probability of each class
|
77 |
+
audio_input_16khz = torch.randn(3, 10000)
|
78 |
+
padding_mask = torch.zeros(3, 10000).bool()
|
79 |
+
|
80 |
+
probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
|
81 |
+
|
82 |
+
for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
|
83 |
+
top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
|
84 |
+
print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')
|
85 |
+
```
|
86 |
+
|
87 |
+
## Evaluation Results
|
88 |
+
|
89 |
+
### Comparing with the SOTA Single Models
|
90 |
+
![alt text](Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png)
|
91 |
+
|
92 |
+
|
93 |
+
### Comparing with the SOTA Ensemble Models
|
94 |
+
![alt text](Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png)
|
95 |
+
|
96 |
+
|
97 |
+
### Comparing Different BEATS Tokenizers
|
98 |
+
![alt text](Evaluation_Results/Comparing_Different_BEATS_Tokenizers.png)
|
99 |
+
|
100 |
+
|
101 |
+
### Comparing Different Pre-Training Targets
|
102 |
+
![alt text](Evaluation_Results/Comparing_Different_Pre-Training_Targets.png)
|
103 |
+
|
104 |
+
|
105 |
+
## License
|
106 |
+
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
|
107 |
+
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [VQGAN](https://github.com/CompVis/taming-transformers) project.
|
108 |
+
|
109 |
+
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
|
110 |
+
|
111 |
+
|
112 |
+
### Reference
|
113 |
+
If you find our work is useful in your research, please cite the following paper:
|
114 |
+
``` latex
|
115 |
+
@article{Chen2022beats,
|
116 |
+
title = {BEATs: Audio Pre-Training with Acoustic Tokenizers},
|
117 |
+
author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
|
118 |
+
eprint={2212.09058},
|
119 |
+
archivePrefix={arXiv},
|
120 |
+
year={2022}
|
121 |
+
}
|
122 |
+
```
|
123 |
+
### Contact Information
|
124 |
+
|
125 |
+
For help or issues using BEATs models, please submit a GitHub issue.
|
126 |
+
|
127 |
+
For other communications related to BEATs, please contact Yu Wu (`[email protected]`).
|
AR/exps/beats/Tokenizers.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import logging
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
15 |
+
from backbone import (
|
16 |
+
TransformerEncoder, )
|
17 |
+
from quantizer import (
|
18 |
+
NormEMAVectorQuantizer, )
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class TokenizersConfig:
|
25 |
+
def __init__(self, cfg=None):
|
26 |
+
self.input_patch_size: int = -1 # path size of patch embedding
|
27 |
+
self.embed_dim: int = 512 # patch embedding dimension
|
28 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
29 |
+
|
30 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
31 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
32 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
33 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
34 |
+
self.activation_fn: str = "gelu" # activation function to use
|
35 |
+
|
36 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
37 |
+
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
38 |
+
|
39 |
+
# dropouts
|
40 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
41 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
42 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
43 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
44 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
45 |
+
|
46 |
+
# positional embeddings
|
47 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
48 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
49 |
+
|
50 |
+
# relative position embedding
|
51 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
52 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
53 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
54 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
55 |
+
|
56 |
+
# quantizer
|
57 |
+
self.quant_n: int = 1024 # codebook number in quantizer
|
58 |
+
self.quant_dim: int = 256 # codebook dimension in quantizer
|
59 |
+
|
60 |
+
if cfg is not None:
|
61 |
+
self.update(cfg)
|
62 |
+
|
63 |
+
def update(self, cfg: dict):
|
64 |
+
self.__dict__.update(cfg)
|
65 |
+
|
66 |
+
|
67 |
+
class Tokenizers(nn.Module):
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
cfg: TokenizersConfig, ) -> None:
|
71 |
+
super().__init__()
|
72 |
+
logger.info(f"Tokenizers Config: {cfg.__dict__}")
|
73 |
+
|
74 |
+
self.cfg = cfg
|
75 |
+
|
76 |
+
self.embed = cfg.embed_dim
|
77 |
+
self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
|
78 |
+
if self.embed != cfg.encoder_embed_dim else
|
79 |
+
None)
|
80 |
+
|
81 |
+
self.input_patch_size = cfg.input_patch_size
|
82 |
+
self.patch_embedding = nn.Conv2d(
|
83 |
+
1,
|
84 |
+
self.embed,
|
85 |
+
kernel_size=self.input_patch_size,
|
86 |
+
stride=self.input_patch_size,
|
87 |
+
bias=cfg.conv_bias)
|
88 |
+
|
89 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
90 |
+
|
91 |
+
assert not cfg.deep_norm or not cfg.layer_norm_first
|
92 |
+
self.encoder = TransformerEncoder(cfg)
|
93 |
+
self.layer_norm = LayerNorm(self.embed)
|
94 |
+
|
95 |
+
self.quantize = NormEMAVectorQuantizer(
|
96 |
+
n_embed=cfg.quant_n,
|
97 |
+
embedding_dim=cfg.quant_dim,
|
98 |
+
beta=1.0,
|
99 |
+
kmeans_init=True,
|
100 |
+
decay=0.99, )
|
101 |
+
self.quant_n = cfg.quant_n
|
102 |
+
self.quantize_layer = nn.Sequential(
|
103 |
+
nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
|
104 |
+
nn.Tanh(),
|
105 |
+
nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
|
106 |
+
)
|
107 |
+
|
108 |
+
def forward_padding_mask(
|
109 |
+
self,
|
110 |
+
features: torch.Tensor,
|
111 |
+
padding_mask: torch.Tensor, ) -> torch.Tensor:
|
112 |
+
extra = padding_mask.size(1) % features.size(1)
|
113 |
+
if extra > 0:
|
114 |
+
padding_mask = padding_mask[:, :-extra]
|
115 |
+
padding_mask = padding_mask.view(
|
116 |
+
padding_mask.size(0), features.size(1), -1)
|
117 |
+
padding_mask = padding_mask.all(-1)
|
118 |
+
return padding_mask
|
119 |
+
|
120 |
+
def preprocess(
|
121 |
+
self,
|
122 |
+
source: torch.Tensor,
|
123 |
+
fbank_mean: float=15.41663,
|
124 |
+
fbank_std: float=6.55582, ) -> torch.Tensor:
|
125 |
+
fbanks = []
|
126 |
+
for waveform in source:
|
127 |
+
waveform = waveform.unsqueeze(0) * 2**15
|
128 |
+
fbank = ta_kaldi.fbank(
|
129 |
+
waveform,
|
130 |
+
num_mel_bins=128,
|
131 |
+
sample_frequency=16000,
|
132 |
+
frame_length=25,
|
133 |
+
frame_shift=10)
|
134 |
+
fbanks.append(fbank)
|
135 |
+
fbank = torch.stack(fbanks, dim=0)
|
136 |
+
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
137 |
+
return fbank
|
138 |
+
|
139 |
+
def extract_labels(
|
140 |
+
self,
|
141 |
+
source: torch.Tensor,
|
142 |
+
padding_mask: Optional[torch.Tensor]=None,
|
143 |
+
fbank_mean: float=15.41663,
|
144 |
+
fbank_std: float=6.55582, ):
|
145 |
+
fbank = self.preprocess(
|
146 |
+
source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
147 |
+
|
148 |
+
if padding_mask is not None:
|
149 |
+
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
150 |
+
|
151 |
+
fbank = fbank.unsqueeze(1)
|
152 |
+
features = self.patch_embedding(fbank)
|
153 |
+
features = features.reshape(features.shape[0], features.shape[1], -1)
|
154 |
+
features = features.transpose(1, 2)
|
155 |
+
features = self.layer_norm(features)
|
156 |
+
|
157 |
+
if padding_mask is not None:
|
158 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
159 |
+
|
160 |
+
if self.post_extract_proj is not None:
|
161 |
+
features = self.post_extract_proj(features)
|
162 |
+
|
163 |
+
x = self.dropout_input(features)
|
164 |
+
|
165 |
+
x, layer_results = self.encoder(
|
166 |
+
x,
|
167 |
+
padding_mask=padding_mask, )
|
168 |
+
|
169 |
+
quantize_input = self.quantize_layer(x)
|
170 |
+
quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
|
171 |
+
|
172 |
+
return embed_ind
|
AR/exps/beats/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# this folder is modified from https://github.com/microsoft/unilm/tree/master/beats
|
2 |
+
# ontology.json is from https://github.com/audioset/ontology/
|
AR/exps/beats/backbone.py
ADDED
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import math
|
10 |
+
from typing import Dict
|
11 |
+
from typing import Optional
|
12 |
+
from typing import Tuple
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torch import nn
|
18 |
+
from torch import Tensor
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
from torch.nn import Parameter
|
21 |
+
|
22 |
+
from .modules import get_activation_fn
|
23 |
+
from .modules import GLU_Linear
|
24 |
+
from .modules import GradMultiply
|
25 |
+
from .modules import quant_noise
|
26 |
+
from .modules import SamePad
|
27 |
+
|
28 |
+
|
29 |
+
class TransformerEncoder(nn.Module):
|
30 |
+
def __init__(self, args):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
self.dropout = args.dropout
|
34 |
+
self.embedding_dim = args.encoder_embed_dim
|
35 |
+
|
36 |
+
self.pos_conv = nn.Conv1d(
|
37 |
+
self.embedding_dim,
|
38 |
+
self.embedding_dim,
|
39 |
+
kernel_size=args.conv_pos,
|
40 |
+
padding=args.conv_pos // 2,
|
41 |
+
groups=args.conv_pos_groups, )
|
42 |
+
dropout = 0
|
43 |
+
std = math.sqrt(
|
44 |
+
(4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
45 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
46 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
47 |
+
|
48 |
+
self.pos_conv = nn.utils.weight_norm(
|
49 |
+
self.pos_conv, name="weight", dim=2)
|
50 |
+
self.pos_conv = nn.Sequential(self.pos_conv,
|
51 |
+
SamePad(args.conv_pos), nn.GELU())
|
52 |
+
|
53 |
+
if hasattr(args, "relative_position_embedding"):
|
54 |
+
self.relative_position_embedding = args.relative_position_embedding
|
55 |
+
self.num_buckets = args.num_buckets
|
56 |
+
self.max_distance = args.max_distance
|
57 |
+
else:
|
58 |
+
self.relative_position_embedding = False
|
59 |
+
self.num_buckets = 0
|
60 |
+
self.max_distance = 0
|
61 |
+
|
62 |
+
self.layers = nn.ModuleList([
|
63 |
+
TransformerSentenceEncoderLayer(
|
64 |
+
embedding_dim=self.embedding_dim,
|
65 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
66 |
+
num_attention_heads=args.encoder_attention_heads,
|
67 |
+
dropout=self.dropout,
|
68 |
+
attention_dropout=args.attention_dropout,
|
69 |
+
activation_dropout=args.activation_dropout,
|
70 |
+
activation_fn=args.activation_fn,
|
71 |
+
layer_norm_first=args.layer_norm_first,
|
72 |
+
deep_norm=args.deep_norm,
|
73 |
+
has_relative_attention_bias=self.relative_position_embedding,
|
74 |
+
num_buckets=self.num_buckets,
|
75 |
+
max_distance=self.max_distance,
|
76 |
+
gru_rel_pos=args.gru_rel_pos,
|
77 |
+
encoder_layers=args.encoder_layers, )
|
78 |
+
for i in range(args.encoder_layers)
|
79 |
+
])
|
80 |
+
if self.relative_position_embedding:
|
81 |
+
for i in range(1, args.encoder_layers):
|
82 |
+
del self.layers[i].self_attn.relative_attention_bias
|
83 |
+
self.layers[i].self_attn.relative_attention_bias = self.layers[
|
84 |
+
0].self_attn.relative_attention_bias
|
85 |
+
|
86 |
+
self.layer_norm_first = args.layer_norm_first
|
87 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
88 |
+
self.layerdrop = args.encoder_layerdrop
|
89 |
+
|
90 |
+
self.apply(init_bert_params)
|
91 |
+
|
92 |
+
if args.deep_norm:
|
93 |
+
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
|
94 |
+
for i in range(args.encoder_layers):
|
95 |
+
nn.init.xavier_normal_(
|
96 |
+
self.layers[i].self_attn.k_proj.weight, gain=1)
|
97 |
+
nn.init.xavier_normal_(
|
98 |
+
self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
|
99 |
+
nn.init.xavier_normal_(
|
100 |
+
self.layers[i].self_attn.q_proj.weight, gain=1)
|
101 |
+
nn.init.xavier_normal_(
|
102 |
+
self.layers[i].self_attn.out_proj.weight,
|
103 |
+
gain=deep_norm_beta)
|
104 |
+
nn.init.xavier_normal_(
|
105 |
+
self.layers[i].fc1.weight, gain=deep_norm_beta)
|
106 |
+
nn.init.xavier_normal_(
|
107 |
+
self.layers[i].fc2.weight, gain=deep_norm_beta)
|
108 |
+
|
109 |
+
self.layer_wise_gradient_decay_ratio = getattr(
|
110 |
+
args, "layer_wise_gradient_decay_ratio", 1)
|
111 |
+
|
112 |
+
def forward(self, x, padding_mask=None, layer=None):
|
113 |
+
x, layer_results = self.extract_features(x, padding_mask, layer)
|
114 |
+
|
115 |
+
if self.layer_norm_first and layer is None:
|
116 |
+
x = self.layer_norm(x)
|
117 |
+
|
118 |
+
return x, layer_results
|
119 |
+
|
120 |
+
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
121 |
+
|
122 |
+
if padding_mask is not None:
|
123 |
+
x[padding_mask] = 0
|
124 |
+
|
125 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
126 |
+
x_conv = x_conv.transpose(1, 2)
|
127 |
+
x = x + x_conv
|
128 |
+
|
129 |
+
if not self.layer_norm_first:
|
130 |
+
x = self.layer_norm(x)
|
131 |
+
|
132 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
133 |
+
|
134 |
+
# B x T x C -> T x B x C
|
135 |
+
x = x.transpose(0, 1)
|
136 |
+
|
137 |
+
layer_results = []
|
138 |
+
z = None
|
139 |
+
if tgt_layer is not None:
|
140 |
+
layer_results.append((x, z))
|
141 |
+
r = None
|
142 |
+
pos_bias = None
|
143 |
+
for i, layer in enumerate(self.layers):
|
144 |
+
if self.layer_wise_gradient_decay_ratio != 1.0:
|
145 |
+
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
|
146 |
+
dropout_probability = np.random.random()
|
147 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
148 |
+
x, z, pos_bias = layer(
|
149 |
+
x,
|
150 |
+
self_attn_padding_mask=padding_mask,
|
151 |
+
need_weights=False,
|
152 |
+
pos_bias=pos_bias)
|
153 |
+
if tgt_layer is not None:
|
154 |
+
layer_results.append((x, z))
|
155 |
+
if i == tgt_layer:
|
156 |
+
r = x
|
157 |
+
break
|
158 |
+
|
159 |
+
if r is not None:
|
160 |
+
x = r
|
161 |
+
|
162 |
+
# T x B x C -> B x T x C
|
163 |
+
x = x.transpose(0, 1)
|
164 |
+
|
165 |
+
return x, layer_results
|
166 |
+
|
167 |
+
|
168 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
embedding_dim: float=768,
|
172 |
+
ffn_embedding_dim: float=3072,
|
173 |
+
num_attention_heads: float=8,
|
174 |
+
dropout: float=0.1,
|
175 |
+
attention_dropout: float=0.1,
|
176 |
+
activation_dropout: float=0.1,
|
177 |
+
activation_fn: str="relu",
|
178 |
+
layer_norm_first: bool=False,
|
179 |
+
deep_norm: bool=False,
|
180 |
+
has_relative_attention_bias: bool=False,
|
181 |
+
num_buckets: int=0,
|
182 |
+
max_distance: int=0,
|
183 |
+
rescale_init: bool=False,
|
184 |
+
gru_rel_pos: bool=False,
|
185 |
+
encoder_layers: int=0, ) -> None:
|
186 |
+
|
187 |
+
super().__init__()
|
188 |
+
self.embedding_dim = embedding_dim
|
189 |
+
self.dropout = dropout
|
190 |
+
self.activation_dropout = activation_dropout
|
191 |
+
|
192 |
+
self.activation_name = activation_fn
|
193 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
194 |
+
self.self_attn = MultiheadAttention(
|
195 |
+
self.embedding_dim,
|
196 |
+
num_attention_heads,
|
197 |
+
dropout=attention_dropout,
|
198 |
+
self_attention=True,
|
199 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
200 |
+
num_buckets=num_buckets,
|
201 |
+
max_distance=max_distance,
|
202 |
+
rescale_init=rescale_init,
|
203 |
+
gru_rel_pos=gru_rel_pos, )
|
204 |
+
|
205 |
+
self.dropout1 = nn.Dropout(dropout)
|
206 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
207 |
+
self.dropout3 = nn.Dropout(dropout)
|
208 |
+
|
209 |
+
self.layer_norm_first = layer_norm_first
|
210 |
+
|
211 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
212 |
+
|
213 |
+
if self.activation_name == "glu":
|
214 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim,
|
215 |
+
"swish")
|
216 |
+
else:
|
217 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
218 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
219 |
+
|
220 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
221 |
+
|
222 |
+
self.deep_norm = deep_norm
|
223 |
+
if self.deep_norm:
|
224 |
+
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
|
225 |
+
else:
|
226 |
+
self.deep_norm_alpha = 1
|
227 |
+
|
228 |
+
def forward(self,
|
229 |
+
x: torch.Tensor,
|
230 |
+
self_attn_mask: torch.Tensor=None,
|
231 |
+
self_attn_padding_mask: torch.Tensor=None,
|
232 |
+
need_weights: bool=False,
|
233 |
+
pos_bias=None):
|
234 |
+
residual = x
|
235 |
+
|
236 |
+
if self.layer_norm_first:
|
237 |
+
x = self.self_attn_layer_norm(x)
|
238 |
+
x, attn, pos_bias = self.self_attn(
|
239 |
+
query=x,
|
240 |
+
key=x,
|
241 |
+
value=x,
|
242 |
+
key_padding_mask=self_attn_padding_mask,
|
243 |
+
need_weights=False,
|
244 |
+
attn_mask=self_attn_mask,
|
245 |
+
position_bias=pos_bias)
|
246 |
+
x = self.dropout1(x)
|
247 |
+
x = residual + x
|
248 |
+
|
249 |
+
residual = x
|
250 |
+
x = self.final_layer_norm(x)
|
251 |
+
if self.activation_name == "glu":
|
252 |
+
x = self.fc1(x)
|
253 |
+
else:
|
254 |
+
x = self.activation_fn(self.fc1(x))
|
255 |
+
x = self.dropout2(x)
|
256 |
+
x = self.fc2(x)
|
257 |
+
x = self.dropout3(x)
|
258 |
+
x = residual + x
|
259 |
+
else:
|
260 |
+
x, attn, pos_bias = self.self_attn(
|
261 |
+
query=x,
|
262 |
+
key=x,
|
263 |
+
value=x,
|
264 |
+
key_padding_mask=self_attn_padding_mask,
|
265 |
+
need_weights=need_weights,
|
266 |
+
attn_mask=self_attn_mask,
|
267 |
+
position_bias=pos_bias)
|
268 |
+
|
269 |
+
x = self.dropout1(x)
|
270 |
+
x = residual * self.deep_norm_alpha + x
|
271 |
+
|
272 |
+
x = self.self_attn_layer_norm(x)
|
273 |
+
|
274 |
+
residual = x
|
275 |
+
if self.activation_name == "glu":
|
276 |
+
x = self.fc1(x)
|
277 |
+
else:
|
278 |
+
x = self.activation_fn(self.fc1(x))
|
279 |
+
x = self.dropout2(x)
|
280 |
+
x = self.fc2(x)
|
281 |
+
x = self.dropout3(x)
|
282 |
+
x = residual * self.deep_norm_alpha + x
|
283 |
+
x = self.final_layer_norm(x)
|
284 |
+
|
285 |
+
return x, attn, pos_bias
|
286 |
+
|
287 |
+
|
288 |
+
class MultiheadAttention(nn.Module):
|
289 |
+
"""Multi-headed attention.
|
290 |
+
|
291 |
+
See "Attention Is All You Need" for more details.
|
292 |
+
"""
|
293 |
+
|
294 |
+
def __init__(
|
295 |
+
self,
|
296 |
+
embed_dim,
|
297 |
+
num_heads,
|
298 |
+
kdim=None,
|
299 |
+
vdim=None,
|
300 |
+
dropout=0.0,
|
301 |
+
bias=True,
|
302 |
+
add_bias_kv=False,
|
303 |
+
add_zero_attn=False,
|
304 |
+
self_attention=False,
|
305 |
+
encoder_decoder_attention=False,
|
306 |
+
q_noise=0.0,
|
307 |
+
qn_block_size=8,
|
308 |
+
has_relative_attention_bias=False,
|
309 |
+
num_buckets=32,
|
310 |
+
max_distance=128,
|
311 |
+
gru_rel_pos=False,
|
312 |
+
rescale_init=False, ):
|
313 |
+
super().__init__()
|
314 |
+
self.embed_dim = embed_dim
|
315 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
316 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
317 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
318 |
+
|
319 |
+
self.num_heads = num_heads
|
320 |
+
self.dropout_module = nn.Dropout(dropout)
|
321 |
+
|
322 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
323 |
+
self.num_buckets = num_buckets
|
324 |
+
self.max_distance = max_distance
|
325 |
+
if self.has_relative_attention_bias:
|
326 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
327 |
+
|
328 |
+
self.head_dim = embed_dim // num_heads
|
329 |
+
self.q_head_dim = self.head_dim
|
330 |
+
self.k_head_dim = self.head_dim
|
331 |
+
assert (self.head_dim * num_heads == self.embed_dim
|
332 |
+
), "embed_dim must be divisible by num_heads"
|
333 |
+
self.scaling = self.head_dim**-0.5
|
334 |
+
|
335 |
+
self.self_attention = self_attention
|
336 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
337 |
+
|
338 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
339 |
+
"Self-attention requires query, key and "
|
340 |
+
"value to be of the same size")
|
341 |
+
|
342 |
+
k_bias = True
|
343 |
+
if rescale_init:
|
344 |
+
k_bias = False
|
345 |
+
|
346 |
+
k_embed_dim = embed_dim
|
347 |
+
q_embed_dim = embed_dim
|
348 |
+
|
349 |
+
self.k_proj = quant_noise(
|
350 |
+
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise,
|
351 |
+
qn_block_size)
|
352 |
+
self.v_proj = quant_noise(
|
353 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
354 |
+
self.q_proj = quant_noise(
|
355 |
+
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise,
|
356 |
+
qn_block_size)
|
357 |
+
|
358 |
+
self.out_proj = quant_noise(
|
359 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
360 |
+
|
361 |
+
if add_bias_kv:
|
362 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
363 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
364 |
+
else:
|
365 |
+
self.bias_k = self.bias_v = None
|
366 |
+
|
367 |
+
self.add_zero_attn = add_zero_attn
|
368 |
+
|
369 |
+
self.gru_rel_pos = gru_rel_pos
|
370 |
+
if self.gru_rel_pos:
|
371 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
372 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
373 |
+
|
374 |
+
self.reset_parameters()
|
375 |
+
|
376 |
+
def reset_parameters(self):
|
377 |
+
if self.qkv_same_dim:
|
378 |
+
# Empirically observed the convergence to be much better with
|
379 |
+
# the scaled initialization
|
380 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
381 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
382 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
383 |
+
else:
|
384 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
385 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
386 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
387 |
+
|
388 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
389 |
+
if self.out_proj.bias is not None:
|
390 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
391 |
+
if self.bias_k is not None:
|
392 |
+
nn.init.xavier_normal_(self.bias_k)
|
393 |
+
if self.bias_v is not None:
|
394 |
+
nn.init.xavier_normal_(self.bias_v)
|
395 |
+
if self.has_relative_attention_bias:
|
396 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
397 |
+
|
398 |
+
def _relative_positions_bucket(self, relative_positions,
|
399 |
+
bidirectional=True):
|
400 |
+
num_buckets = self.num_buckets
|
401 |
+
max_distance = self.max_distance
|
402 |
+
relative_buckets = 0
|
403 |
+
|
404 |
+
if bidirectional:
|
405 |
+
num_buckets = num_buckets // 2
|
406 |
+
relative_buckets += (
|
407 |
+
relative_positions > 0).to(torch.long) * num_buckets
|
408 |
+
relative_positions = torch.abs(relative_positions)
|
409 |
+
else:
|
410 |
+
relative_positions = -torch.min(
|
411 |
+
relative_positions, torch.zeros_like(relative_positions))
|
412 |
+
|
413 |
+
max_exact = num_buckets // 2
|
414 |
+
is_small = relative_positions < max_exact
|
415 |
+
|
416 |
+
relative_postion_if_large = max_exact + (
|
417 |
+
torch.log(relative_positions.float() / max_exact) / math.log(
|
418 |
+
max_distance / max_exact) *
|
419 |
+
(num_buckets - max_exact)).to(torch.long)
|
420 |
+
relative_postion_if_large = torch.min(
|
421 |
+
relative_postion_if_large,
|
422 |
+
torch.full_like(relative_postion_if_large, num_buckets - 1))
|
423 |
+
|
424 |
+
relative_buckets += torch.where(is_small, relative_positions,
|
425 |
+
relative_postion_if_large)
|
426 |
+
return relative_buckets
|
427 |
+
|
428 |
+
def compute_bias(self, query_length, key_length):
|
429 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
430 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
431 |
+
relative_position = memory_position - context_position
|
432 |
+
relative_position_bucket = self._relative_positions_bucket(
|
433 |
+
relative_position, bidirectional=True)
|
434 |
+
relative_position_bucket = relative_position_bucket.to(
|
435 |
+
self.relative_attention_bias.weight.device)
|
436 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
437 |
+
values = values.permute([2, 0, 1])
|
438 |
+
return values
|
439 |
+
|
440 |
+
def forward(self,
|
441 |
+
query,
|
442 |
+
key: Optional[Tensor],
|
443 |
+
value: Optional[Tensor],
|
444 |
+
key_padding_mask: Optional[Tensor]=None,
|
445 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[
|
446 |
+
Tensor]]]]=None,
|
447 |
+
need_weights: bool=True,
|
448 |
+
static_kv: bool=False,
|
449 |
+
attn_mask: Optional[Tensor]=None,
|
450 |
+
before_softmax: bool=False,
|
451 |
+
need_head_weights: bool=False,
|
452 |
+
position_bias: Optional[Tensor]=None
|
453 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
454 |
+
"""Input shape: Time x Batch x Channel
|
455 |
+
|
456 |
+
Args:
|
457 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
458 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
459 |
+
padding elements are indicated by 1s.
|
460 |
+
need_weights (bool, optional): return the attention weights,
|
461 |
+
averaged over heads (default: False).
|
462 |
+
attn_mask (ByteTensor, optional): typically used to
|
463 |
+
implement causal attention, where the mask prevents the
|
464 |
+
attention from looking forward in time (default: None).
|
465 |
+
before_softmax (bool, optional): return the raw attention
|
466 |
+
weights and values before the attention softmax.
|
467 |
+
need_head_weights (bool, optional): return the attention
|
468 |
+
weights for each head. Implies *need_weights*. Default:
|
469 |
+
return the average attention weights over all heads.
|
470 |
+
"""
|
471 |
+
if need_head_weights:
|
472 |
+
need_weights = True
|
473 |
+
|
474 |
+
is_tpu = query.device.type == "xla"
|
475 |
+
|
476 |
+
tgt_len, bsz, embed_dim = query.size()
|
477 |
+
src_len = tgt_len
|
478 |
+
assert embed_dim == self.embed_dim
|
479 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
480 |
+
if key is not None:
|
481 |
+
src_len, key_bsz, _ = key.size()
|
482 |
+
if not torch.jit.is_scripting():
|
483 |
+
assert key_bsz == bsz
|
484 |
+
assert value is not None
|
485 |
+
assert src_len, bsz == value.shape[:2]
|
486 |
+
|
487 |
+
if self.has_relative_attention_bias and position_bias is None:
|
488 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
489 |
+
position_bias = position_bias.unsqueeze(0).repeat(
|
490 |
+
bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
491 |
+
|
492 |
+
if incremental_state is not None:
|
493 |
+
saved_state = self._get_input_buffer(incremental_state)
|
494 |
+
if saved_state is not None and "prev_key" in saved_state:
|
495 |
+
# previous time steps are cached - no need to recompute
|
496 |
+
# key and value if they are static
|
497 |
+
if static_kv:
|
498 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
499 |
+
key = value = None
|
500 |
+
else:
|
501 |
+
saved_state = None
|
502 |
+
|
503 |
+
if self.self_attention:
|
504 |
+
q = self.q_proj(query)
|
505 |
+
k = self.k_proj(query)
|
506 |
+
v = self.v_proj(query)
|
507 |
+
elif self.encoder_decoder_attention:
|
508 |
+
# encoder-decoder attention
|
509 |
+
q = self.q_proj(query)
|
510 |
+
if key is None:
|
511 |
+
assert value is None
|
512 |
+
k = v = None
|
513 |
+
else:
|
514 |
+
k = self.k_proj(key)
|
515 |
+
v = self.v_proj(key)
|
516 |
+
|
517 |
+
else:
|
518 |
+
assert key is not None and value is not None
|
519 |
+
q = self.q_proj(query)
|
520 |
+
k = self.k_proj(key)
|
521 |
+
v = self.v_proj(value)
|
522 |
+
q *= self.scaling
|
523 |
+
alpha = 32
|
524 |
+
q *= 1 / alpha
|
525 |
+
|
526 |
+
if self.bias_k is not None:
|
527 |
+
assert self.bias_v is not None
|
528 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
529 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
530 |
+
if attn_mask is not None:
|
531 |
+
attn_mask = torch.cat(
|
532 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
|
533 |
+
dim=1)
|
534 |
+
if key_padding_mask is not None:
|
535 |
+
key_padding_mask = torch.cat(
|
536 |
+
[
|
537 |
+
key_padding_mask,
|
538 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
539 |
+
],
|
540 |
+
dim=1, )
|
541 |
+
|
542 |
+
q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
543 |
+
.transpose(0, 1))
|
544 |
+
if k is not None:
|
545 |
+
k = (k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim)
|
546 |
+
.transpose(0, 1))
|
547 |
+
if v is not None:
|
548 |
+
v = (v.contiguous().view(-1, bsz * self.num_heads, self.head_dim)
|
549 |
+
.transpose(0, 1))
|
550 |
+
|
551 |
+
if saved_state is not None:
|
552 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
553 |
+
if "prev_key" in saved_state:
|
554 |
+
_prev_key = saved_state["prev_key"]
|
555 |
+
assert _prev_key is not None
|
556 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1,
|
557 |
+
self.head_dim)
|
558 |
+
if static_kv:
|
559 |
+
k = prev_key
|
560 |
+
else:
|
561 |
+
assert k is not None
|
562 |
+
k = torch.cat([prev_key, k], dim=1)
|
563 |
+
src_len = k.size(1)
|
564 |
+
if "prev_value" in saved_state:
|
565 |
+
_prev_value = saved_state["prev_value"]
|
566 |
+
assert _prev_value is not None
|
567 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1,
|
568 |
+
self.head_dim)
|
569 |
+
if static_kv:
|
570 |
+
v = prev_value
|
571 |
+
else:
|
572 |
+
assert v is not None
|
573 |
+
v = torch.cat([prev_value, v], dim=1)
|
574 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
575 |
+
if "prev_key_padding_mask" in saved_state:
|
576 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
577 |
+
assert k is not None and v is not None
|
578 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
579 |
+
key_padding_mask=key_padding_mask,
|
580 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
581 |
+
batch_size=bsz,
|
582 |
+
src_len=k.size(1),
|
583 |
+
static_kv=static_kv, )
|
584 |
+
|
585 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
|
586 |
+
self.head_dim)
|
587 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
|
588 |
+
self.head_dim)
|
589 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
590 |
+
# In this branch incremental_state is never None
|
591 |
+
assert incremental_state is not None
|
592 |
+
incremental_state = self._set_input_buffer(incremental_state,
|
593 |
+
saved_state)
|
594 |
+
assert k is not None
|
595 |
+
assert k.size(1) == src_len
|
596 |
+
|
597 |
+
# This is part of a workaround to get around fork/join parallelism
|
598 |
+
# not supporting Optional types.
|
599 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
600 |
+
key_padding_mask = None
|
601 |
+
|
602 |
+
if key_padding_mask is not None:
|
603 |
+
assert key_padding_mask.size(0) == bsz
|
604 |
+
assert key_padding_mask.size(1) == src_len
|
605 |
+
|
606 |
+
if self.add_zero_attn:
|
607 |
+
assert v is not None
|
608 |
+
src_len += 1
|
609 |
+
k = torch.cat(
|
610 |
+
[k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
611 |
+
v = torch.cat(
|
612 |
+
[v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
613 |
+
if attn_mask is not None:
|
614 |
+
attn_mask = torch.cat(
|
615 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
|
616 |
+
dim=1)
|
617 |
+
if key_padding_mask is not None:
|
618 |
+
key_padding_mask = torch.cat(
|
619 |
+
[
|
620 |
+
key_padding_mask,
|
621 |
+
torch.zeros(key_padding_mask.size(0),
|
622 |
+
1).type_as(key_padding_mask),
|
623 |
+
],
|
624 |
+
dim=1, )
|
625 |
+
|
626 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
627 |
+
attn_weights = (
|
628 |
+
attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
|
629 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
|
630 |
+
bsz)
|
631 |
+
|
632 |
+
assert list(
|
633 |
+
attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
634 |
+
|
635 |
+
if attn_mask is not None:
|
636 |
+
attn_mask = attn_mask.unsqueeze(0)
|
637 |
+
attn_weights += attn_mask
|
638 |
+
|
639 |
+
if key_padding_mask is not None:
|
640 |
+
# don't attend to padding symbols
|
641 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
|
642 |
+
src_len)
|
643 |
+
if not is_tpu:
|
644 |
+
attn_weights = attn_weights.masked_fill(
|
645 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
646 |
+
float("-inf"), )
|
647 |
+
else:
|
648 |
+
attn_weights = attn_weights.transpose(0, 2)
|
649 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask,
|
650 |
+
float("-inf"))
|
651 |
+
attn_weights = attn_weights.transpose(0, 2)
|
652 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
|
653 |
+
src_len)
|
654 |
+
|
655 |
+
if before_softmax:
|
656 |
+
return attn_weights, v, position_bias
|
657 |
+
|
658 |
+
if position_bias is not None:
|
659 |
+
attn_mask_rel_pos = position_bias
|
660 |
+
if self.gru_rel_pos == 1:
|
661 |
+
query_layer = q.view(bsz, self.num_heads, tgt_len,
|
662 |
+
self.q_head_dim) * alpha / self.scaling
|
663 |
+
_B, _H, _L, __ = query_layer.size()
|
664 |
+
gate_a, gate_b = torch.sigmoid(
|
665 |
+
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(
|
666 |
+
-1, keepdim=False)).chunk(
|
667 |
+
2, dim=-1)
|
668 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
669 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len,
|
670 |
+
1) * position_bias
|
671 |
+
|
672 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
673 |
+
|
674 |
+
attn_weights = attn_weights + attn_mask_rel_pos
|
675 |
+
|
676 |
+
attn_weights_float = F.softmax(attn_weights, dim=-1)
|
677 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
678 |
+
attn_probs = self.dropout_module(attn_weights)
|
679 |
+
|
680 |
+
assert v is not None
|
681 |
+
attn = torch.bmm(attn_probs, v)
|
682 |
+
assert list(
|
683 |
+
attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
684 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
685 |
+
attn = self.out_proj(attn)
|
686 |
+
attn_weights: Optional[Tensor] = None
|
687 |
+
if need_weights:
|
688 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
|
689 |
+
src_len).transpose(1, 0)
|
690 |
+
if not need_head_weights:
|
691 |
+
# average attention weights over heads
|
692 |
+
attn_weights = attn_weights.mean(dim=0)
|
693 |
+
|
694 |
+
return attn, attn_weights, position_bias
|
695 |
+
|
696 |
+
@staticmethod
|
697 |
+
def _append_prev_key_padding_mask(
|
698 |
+
key_padding_mask: Optional[Tensor],
|
699 |
+
prev_key_padding_mask: Optional[Tensor],
|
700 |
+
batch_size: int,
|
701 |
+
src_len: int,
|
702 |
+
static_kv: bool, ) -> Optional[Tensor]:
|
703 |
+
# saved key padding masks have shape (bsz, seq_len)
|
704 |
+
if prev_key_padding_mask is not None and static_kv:
|
705 |
+
new_key_padding_mask = prev_key_padding_mask
|
706 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
707 |
+
new_key_padding_mask = torch.cat(
|
708 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()],
|
709 |
+
dim=1)
|
710 |
+
# During incremental decoding, as the padding token enters and
|
711 |
+
# leaves the frame, there will be a time when prev or current
|
712 |
+
# is None
|
713 |
+
elif prev_key_padding_mask is not None:
|
714 |
+
if src_len > prev_key_padding_mask.size(1):
|
715 |
+
filler = torch.zeros(
|
716 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
717 |
+
device=prev_key_padding_mask.device, )
|
718 |
+
new_key_padding_mask = torch.cat(
|
719 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1)
|
720 |
+
else:
|
721 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
722 |
+
elif key_padding_mask is not None:
|
723 |
+
if src_len > key_padding_mask.size(1):
|
724 |
+
filler = torch.zeros(
|
725 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
726 |
+
device=key_padding_mask.device, )
|
727 |
+
new_key_padding_mask = torch.cat(
|
728 |
+
[filler.float(), key_padding_mask.float()], dim=1)
|
729 |
+
else:
|
730 |
+
new_key_padding_mask = key_padding_mask.float()
|
731 |
+
else:
|
732 |
+
new_key_padding_mask = prev_key_padding_mask
|
733 |
+
return new_key_padding_mask
|
734 |
+
|
735 |
+
def _get_input_buffer(
|
736 |
+
self,
|
737 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
738 |
+
) -> Dict[str, Optional[Tensor]]:
|
739 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
740 |
+
if result is not None:
|
741 |
+
return result
|
742 |
+
else:
|
743 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
744 |
+
return empty_result
|
745 |
+
|
746 |
+
def _set_input_buffer(
|
747 |
+
self,
|
748 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
749 |
+
buffer: Dict[str, Optional[Tensor]], ):
|
750 |
+
return self.set_incremental_state(incremental_state, "attn_state",
|
751 |
+
buffer)
|
752 |
+
|
753 |
+
def apply_sparse_mask(self,
|
754 |
+
attn_weights,
|
755 |
+
tgt_len: int,
|
756 |
+
src_len: int,
|
757 |
+
bsz: int):
|
758 |
+
return attn_weights
|
759 |
+
|
760 |
+
|
761 |
+
def init_bert_params(module):
|
762 |
+
"""
|
763 |
+
Initialize the weights specific to the BERT Model.
|
764 |
+
This overrides the default initializations depending on the specified arguments.
|
765 |
+
1. If normal_init_linear_weights is set then weights of linear
|
766 |
+
layer will be initialized using the normal distribution and
|
767 |
+
bais will be set to the specified value.
|
768 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
769 |
+
layer will be initialized using the normal distribution.
|
770 |
+
3. If normal_init_proj_weights is set then weights of
|
771 |
+
in_project_weight for MultiHeadAttention initialized using
|
772 |
+
the normal distribution (to be validated).
|
773 |
+
"""
|
774 |
+
|
775 |
+
def normal_(data):
|
776 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
777 |
+
# so that the RNG is consistent with and without FSDP
|
778 |
+
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
779 |
+
|
780 |
+
if isinstance(module, nn.Linear):
|
781 |
+
normal_(module.weight.data)
|
782 |
+
if module.bias is not None:
|
783 |
+
module.bias.data.zero_()
|
784 |
+
if isinstance(module, nn.Embedding):
|
785 |
+
normal_(module.weight.data)
|
786 |
+
if module.padding_idx is not None:
|
787 |
+
module.weight.data[module.padding_idx].zero_()
|
788 |
+
if isinstance(module, MultiheadAttention):
|
789 |
+
normal_(module.q_proj.weight.data)
|
790 |
+
normal_(module.k_proj.weight.data)
|
791 |
+
normal_(module.v_proj.weight.data)
|
AR/exps/beats/config.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
# 获取当前脚本的所在目录
|
5 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
6 |
+
|
7 |
+
# JSON 文件的文件名
|
8 |
+
json_filename = "ontology.json"
|
9 |
+
|
10 |
+
# 构建 JSON 文件的完整路径
|
11 |
+
json_path = os.path.join(script_dir, json_filename)
|
12 |
+
|
13 |
+
id_name_dict = {}
|
14 |
+
|
15 |
+
with open(json_path, 'r') as f:
|
16 |
+
json_items = json.load(f)
|
17 |
+
# '/m/0dgw9r' -> 'Human sounds' and etc.
|
18 |
+
for item in json_items:
|
19 |
+
id_name_dict[item['id']] = item['name']
|
AR/exps/beats/modules.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import math
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
class GradMultiply(torch.autograd.Function):
|
18 |
+
@staticmethod
|
19 |
+
def forward(ctx, x, scale):
|
20 |
+
ctx.scale = scale
|
21 |
+
res = x.new(x)
|
22 |
+
return res
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def backward(ctx, grad):
|
26 |
+
return grad * ctx.scale, None
|
27 |
+
|
28 |
+
|
29 |
+
class SamePad(nn.Module):
|
30 |
+
def __init__(self, kernel_size, causal=False):
|
31 |
+
super().__init__()
|
32 |
+
if causal:
|
33 |
+
self.remove = kernel_size - 1
|
34 |
+
else:
|
35 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
if self.remove > 0:
|
39 |
+
x = x[:, :, :-self.remove]
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class Swish(nn.Module):
|
44 |
+
def __init__(self):
|
45 |
+
super(Swish, self).__init__()
|
46 |
+
self.act = torch.nn.Sigmoid()
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
return x * self.act(x)
|
50 |
+
|
51 |
+
|
52 |
+
class GLU_Linear(nn.Module):
|
53 |
+
def __init__(self,
|
54 |
+
input_dim,
|
55 |
+
output_dim,
|
56 |
+
glu_type="sigmoid",
|
57 |
+
bias_in_glu=True):
|
58 |
+
super(GLU_Linear, self).__init__()
|
59 |
+
|
60 |
+
self.glu_type = glu_type
|
61 |
+
self.output_dim = output_dim
|
62 |
+
|
63 |
+
if glu_type == "sigmoid":
|
64 |
+
self.glu_act = torch.nn.Sigmoid()
|
65 |
+
elif glu_type == "swish":
|
66 |
+
self.glu_act = Swish()
|
67 |
+
elif glu_type == "relu":
|
68 |
+
self.glu_act = torch.nn.ReLU()
|
69 |
+
elif glu_type == "gelu":
|
70 |
+
self.glu_act = torch.nn.GELU()
|
71 |
+
|
72 |
+
if bias_in_glu:
|
73 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
74 |
+
else:
|
75 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
79 |
+
x = self.linear(x)
|
80 |
+
|
81 |
+
if self.glu_type == "bilinear":
|
82 |
+
x = (x[:, :, 0:self.output_dim] *
|
83 |
+
x[:, :, self.output_dim:self.output_dim * 2])
|
84 |
+
else:
|
85 |
+
x = (x[:, :, 0:self.output_dim] *
|
86 |
+
self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
87 |
+
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
def gelu_accurate(x):
|
92 |
+
if not hasattr(gelu_accurate, "_a"):
|
93 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
94 |
+
return (0.5 * x * (1 + torch.tanh(gelu_accurate._a *
|
95 |
+
(x + 0.044715 * torch.pow(x, 3)))))
|
96 |
+
|
97 |
+
|
98 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
99 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
100 |
+
|
101 |
+
|
102 |
+
def get_activation_fn(activation: str):
|
103 |
+
"""Returns the activation function corresponding to `activation`"""
|
104 |
+
|
105 |
+
if activation == "relu":
|
106 |
+
return F.relu
|
107 |
+
elif activation == "gelu":
|
108 |
+
return gelu
|
109 |
+
elif activation == "gelu_fast":
|
110 |
+
warnings.warn(
|
111 |
+
"--activation-fn=gelu_fast has been renamed to gelu_accurate")
|
112 |
+
return gelu_accurate
|
113 |
+
elif activation == "gelu_accurate":
|
114 |
+
return gelu_accurate
|
115 |
+
elif activation == "tanh":
|
116 |
+
return torch.tanh
|
117 |
+
elif activation == "linear":
|
118 |
+
return lambda x: x
|
119 |
+
elif activation == "glu":
|
120 |
+
return lambda x: x
|
121 |
+
else:
|
122 |
+
raise RuntimeError(
|
123 |
+
"--activation-fn {} not supported".format(activation))
|
124 |
+
|
125 |
+
|
126 |
+
def quant_noise(module, p, block_size):
|
127 |
+
"""
|
128 |
+
Wraps modules and applies quantization noise to the weights for
|
129 |
+
subsequent quantization with Iterative Product Quantization as
|
130 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
131 |
+
|
132 |
+
Args:
|
133 |
+
- module: nn.Module
|
134 |
+
- p: amount of Quantization Noise
|
135 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
136 |
+
|
137 |
+
Remarks:
|
138 |
+
- Module weights must have the right sizes wrt the block size
|
139 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
140 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
141 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
142 |
+
- We implement the simplest form of noise here as stated in the paper
|
143 |
+
which consists in randomly dropping blocks
|
144 |
+
"""
|
145 |
+
|
146 |
+
# if no quantization noise, don't register hook
|
147 |
+
if p <= 0:
|
148 |
+
return module
|
149 |
+
|
150 |
+
# supported modules
|
151 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
152 |
+
|
153 |
+
# test whether module.weight has the right sizes wrt block_size
|
154 |
+
is_conv = module.weight.ndim == 4
|
155 |
+
|
156 |
+
# 2D matrix
|
157 |
+
if not is_conv:
|
158 |
+
assert (
|
159 |
+
module.weight.size(1) %
|
160 |
+
block_size == 0), "Input features must be a multiple of block sizes"
|
161 |
+
|
162 |
+
# 4D matrix
|
163 |
+
else:
|
164 |
+
# 1x1 convolutions
|
165 |
+
if module.kernel_size == (1, 1):
|
166 |
+
assert (module.in_channels % block_size == 0
|
167 |
+
), "Input channels must be a multiple of block sizes"
|
168 |
+
# regular convolutions
|
169 |
+
else:
|
170 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
171 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
172 |
+
|
173 |
+
def _forward_pre_hook(mod, input):
|
174 |
+
# no noise for evaluation
|
175 |
+
if mod.training:
|
176 |
+
if not is_conv:
|
177 |
+
# gather weight and sizes
|
178 |
+
weight = mod.weight
|
179 |
+
in_features = weight.size(1)
|
180 |
+
out_features = weight.size(0)
|
181 |
+
|
182 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
183 |
+
mask = torch.zeros(
|
184 |
+
in_features // block_size * out_features,
|
185 |
+
device=weight.device)
|
186 |
+
mask.bernoulli_(p)
|
187 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1,
|
188 |
+
in_features)
|
189 |
+
|
190 |
+
else:
|
191 |
+
# gather weight and sizes
|
192 |
+
weight = mod.weight
|
193 |
+
in_channels = mod.in_channels
|
194 |
+
out_channels = mod.out_channels
|
195 |
+
|
196 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
197 |
+
if mod.kernel_size == (1, 1):
|
198 |
+
mask = torch.zeros(
|
199 |
+
int(in_channels // block_size * out_channels),
|
200 |
+
device=weight.device, )
|
201 |
+
mask.bernoulli_(p)
|
202 |
+
mask = mask.repeat_interleave(block_size, -1).view(
|
203 |
+
-1, in_channels)
|
204 |
+
else:
|
205 |
+
mask = torch.zeros(
|
206 |
+
weight.size(0), weight.size(1), device=weight.device)
|
207 |
+
mask.bernoulli_(p)
|
208 |
+
mask = (
|
209 |
+
mask.unsqueeze(2).unsqueeze(3)
|
210 |
+
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
|
211 |
+
|
212 |
+
# scale weights and apply mask
|
213 |
+
mask = mask.to(
|
214 |
+
torch.
|
215 |
+
bool) # x.bool() is not currently supported in TorchScript
|
216 |
+
s = 1 / (1 - p)
|
217 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
218 |
+
|
219 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
220 |
+
return module
|
AR/exps/beats/ontology.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
AR/exps/beats/quantizer.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on VQGAN code bases
|
7 |
+
# https://github.com/CompVis/taming-transformers
|
8 |
+
# --------------------------------------------------------'
|
9 |
+
import torch
|
10 |
+
import torch.distributed as distributed
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
try:
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
except ImportError:
|
17 |
+
pass
|
18 |
+
|
19 |
+
|
20 |
+
def l2norm(t):
|
21 |
+
return F.normalize(t, p=2, dim=-1)
|
22 |
+
|
23 |
+
|
24 |
+
def ema_inplace(moving_avg, new, decay):
|
25 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
26 |
+
|
27 |
+
|
28 |
+
def sample_vectors(samples, num):
|
29 |
+
num_samples, device = samples.shape[0], samples.device
|
30 |
+
|
31 |
+
if num_samples >= num:
|
32 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
33 |
+
else:
|
34 |
+
indices = torch.randint(0, num_samples, (num, ), device=device)
|
35 |
+
|
36 |
+
return samples[indices]
|
37 |
+
|
38 |
+
|
39 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
40 |
+
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
41 |
+
|
42 |
+
means = sample_vectors(samples, num_clusters)
|
43 |
+
|
44 |
+
for _ in range(num_iters):
|
45 |
+
if use_cosine_sim:
|
46 |
+
dists = samples @ means.t()
|
47 |
+
else:
|
48 |
+
diffs = rearrange(samples, 'n d -> n () d') \
|
49 |
+
- rearrange(means, 'c d -> () c d')
|
50 |
+
dists = -(diffs**2).sum(dim=-1)
|
51 |
+
|
52 |
+
buckets = dists.max(dim=-1).indices
|
53 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
54 |
+
zero_mask = bins == 0
|
55 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
56 |
+
|
57 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
58 |
+
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
|
59 |
+
new_means = new_means / bins_min_clamped[..., None]
|
60 |
+
|
61 |
+
if use_cosine_sim:
|
62 |
+
new_means = l2norm(new_means)
|
63 |
+
|
64 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
65 |
+
|
66 |
+
return means, bins
|
67 |
+
|
68 |
+
|
69 |
+
class EmbeddingEMA(nn.Module):
|
70 |
+
def __init__(self,
|
71 |
+
num_tokens,
|
72 |
+
codebook_dim,
|
73 |
+
decay=0.99,
|
74 |
+
eps=1e-5,
|
75 |
+
kmeans_init=True,
|
76 |
+
codebook_init_path=''):
|
77 |
+
super().__init__()
|
78 |
+
self.num_tokens = num_tokens
|
79 |
+
self.codebook_dim = codebook_dim
|
80 |
+
self.decay = decay
|
81 |
+
self.eps = eps
|
82 |
+
if codebook_init_path == '':
|
83 |
+
if not kmeans_init:
|
84 |
+
weight = torch.randn(num_tokens, codebook_dim)
|
85 |
+
weight = l2norm(weight)
|
86 |
+
else:
|
87 |
+
weight = torch.zeros(num_tokens, codebook_dim)
|
88 |
+
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
89 |
+
else:
|
90 |
+
print(f"load init codebook weight from {codebook_init_path}")
|
91 |
+
codebook_ckpt_weight = torch.load(
|
92 |
+
codebook_init_path, map_location='cpu')
|
93 |
+
weight = codebook_ckpt_weight.clone()
|
94 |
+
self.register_buffer('initted', torch.Tensor([True]))
|
95 |
+
|
96 |
+
self.weight = nn.Parameter(weight, requires_grad=False)
|
97 |
+
self.cluster_size = nn.Parameter(
|
98 |
+
torch.zeros(num_tokens), requires_grad=False)
|
99 |
+
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
100 |
+
# self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
101 |
+
self.update = True
|
102 |
+
|
103 |
+
@torch.jit.ignore
|
104 |
+
def init_embed_(self, data):
|
105 |
+
if self.initted:
|
106 |
+
return
|
107 |
+
print("Performing Kemans init for codebook")
|
108 |
+
embed, cluster_size = kmeans(
|
109 |
+
data, self.num_tokens, 10, use_cosine_sim=True)
|
110 |
+
self.weight.data.copy_(embed)
|
111 |
+
self.cluster_size.data.copy_(cluster_size)
|
112 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
113 |
+
|
114 |
+
def forward(self, embed_id):
|
115 |
+
return F.embedding(embed_id, self.weight)
|
116 |
+
|
117 |
+
def cluster_size_ema_update(self, new_cluster_size):
|
118 |
+
self.cluster_size.data.mul_(self.decay).add_(
|
119 |
+
new_cluster_size, alpha=1 - self.decay)
|
120 |
+
|
121 |
+
def embed_avg_ema_update(self, new_embed_avg):
|
122 |
+
self.embed_avg.data.mul_(self.decay).add_(
|
123 |
+
new_embed_avg, alpha=1 - self.decay)
|
124 |
+
|
125 |
+
def weight_update(self, num_tokens):
|
126 |
+
n = self.cluster_size.sum()
|
127 |
+
smoothed_cluster_size = (
|
128 |
+
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n)
|
129 |
+
# normalize embedding average with smoothed cluster size
|
130 |
+
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
131 |
+
# embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
|
132 |
+
self.weight.data.copy_(embed_normalized)
|
133 |
+
|
134 |
+
|
135 |
+
def norm_ema_inplace(moving_avg, new, decay):
|
136 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
137 |
+
moving_avg.data.copy_(l2norm(moving_avg.data))
|
138 |
+
|
139 |
+
|
140 |
+
class NormEMAVectorQuantizer(nn.Module):
|
141 |
+
def __init__(self,
|
142 |
+
n_embed,
|
143 |
+
embedding_dim,
|
144 |
+
beta,
|
145 |
+
decay=0.99,
|
146 |
+
eps=1e-5,
|
147 |
+
statistic_code_usage=True,
|
148 |
+
kmeans_init=False,
|
149 |
+
codebook_init_path=''):
|
150 |
+
super().__init__()
|
151 |
+
self.codebook_dim = embedding_dim
|
152 |
+
self.num_tokens = n_embed
|
153 |
+
self.beta = beta
|
154 |
+
self.decay = decay
|
155 |
+
|
156 |
+
# learnable = True if orthogonal_reg_weight > 0 else False
|
157 |
+
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay,
|
158 |
+
eps, kmeans_init, codebook_init_path)
|
159 |
+
|
160 |
+
self.statistic_code_usage = statistic_code_usage
|
161 |
+
if statistic_code_usage:
|
162 |
+
self.register_buffer('cluster_size', torch.zeros(n_embed))
|
163 |
+
if distributed.is_available() and distributed.is_initialized():
|
164 |
+
print(
|
165 |
+
"ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!"
|
166 |
+
)
|
167 |
+
self.all_reduce_fn = distributed.all_reduce
|
168 |
+
else:
|
169 |
+
self.all_reduce_fn = nn.Identity()
|
170 |
+
|
171 |
+
def reset_cluster_size(self, device):
|
172 |
+
if self.statistic_code_usage:
|
173 |
+
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
|
174 |
+
self.cluster_size = self.cluster_size.to(device)
|
175 |
+
|
176 |
+
def forward(self, z):
|
177 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
178 |
+
# z, 'b c h w -> b h w c'
|
179 |
+
# z = rearrange(z, 'b c h w -> b h w c')
|
180 |
+
# z = z.transpose(1, 2)
|
181 |
+
z = l2norm(z)
|
182 |
+
z_flattened = z.reshape(-1, self.codebook_dim)
|
183 |
+
|
184 |
+
self.embedding.init_embed_(z_flattened)
|
185 |
+
|
186 |
+
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
|
187 |
+
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
|
188 |
+
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
|
189 |
+
|
190 |
+
encoding_indices = torch.argmin(d, dim=1)
|
191 |
+
|
192 |
+
z_q = self.embedding(encoding_indices).view(z.shape)
|
193 |
+
|
194 |
+
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
195 |
+
|
196 |
+
if not self.training:
|
197 |
+
with torch.no_grad():
|
198 |
+
cluster_size = encodings.sum(0)
|
199 |
+
self.all_reduce_fn(cluster_size)
|
200 |
+
ema_inplace(self.cluster_size, cluster_size, self.decay)
|
201 |
+
|
202 |
+
if self.training and self.embedding.update:
|
203 |
+
# EMA cluster size
|
204 |
+
|
205 |
+
bins = encodings.sum(0)
|
206 |
+
self.all_reduce_fn(bins)
|
207 |
+
|
208 |
+
# self.embedding.cluster_size_ema_update(bins)
|
209 |
+
ema_inplace(self.cluster_size, bins, self.decay)
|
210 |
+
|
211 |
+
zero_mask = (bins == 0)
|
212 |
+
bins = bins.masked_fill(zero_mask, 1.)
|
213 |
+
|
214 |
+
embed_sum = z_flattened.t() @ encodings
|
215 |
+
self.all_reduce_fn(embed_sum)
|
216 |
+
|
217 |
+
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
|
218 |
+
embed_normalized = l2norm(embed_normalized)
|
219 |
+
|
220 |
+
embed_normalized = torch.where(
|
221 |
+
zero_mask[..., None], self.embedding.weight, embed_normalized)
|
222 |
+
norm_ema_inplace(self.embedding.weight, embed_normalized,
|
223 |
+
self.decay)
|
224 |
+
|
225 |
+
# compute loss for embedding
|
226 |
+
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
227 |
+
|
228 |
+
# preserve gradients
|
229 |
+
z_q = z + (z_q - z).detach()
|
230 |
+
|
231 |
+
# reshape back to match original input shape
|
232 |
+
# z_q, 'b h w c -> b c h w'
|
233 |
+
# z_q = rearrange(z_q, 'b h w c -> b c h w')
|
234 |
+
# z_q = z_q.transpose(1, 2)
|
235 |
+
return z_q, loss, encoding_indices
|
AR/exps/get_beats_librilight.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use AudioTag tool BEATs to filter out audios who's top1 tag is not 'speech'
|
2 |
+
# non_speech.npy, 存储一个 python dict 表示非 speech 类型的音频的 tag, 更小,加载和搜索速度更快
|
3 |
+
# audio_tag 目录存储 {utt_id}.txt, 第一行是小写的 top1 tag
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import traceback
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import librosa
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import tqdm
|
15 |
+
from AR.exps.beats.BEATs import BEATs
|
16 |
+
from AR.exps.beats.BEATs import BEATsConfig
|
17 |
+
from AR.exps.beats.config import id_name_dict
|
18 |
+
from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
|
19 |
+
from soundstorm.utils import check_txt_file
|
20 |
+
|
21 |
+
|
22 |
+
def get_BEATs_top1(wav,
|
23 |
+
BEATs_model,
|
24 |
+
BEATs_label_dict,
|
25 |
+
device: str='cpu',
|
26 |
+
topk: int=1):
|
27 |
+
wav = torch.tensor(wav).unsqueeze(0).to(device)
|
28 |
+
padding_mask = torch.zeros(wav.shape).bool().to(device)
|
29 |
+
probs = BEATs_model.extract_features(wav, padding_mask=padding_mask)[0]
|
30 |
+
# 单条推理
|
31 |
+
probs = probs[0]
|
32 |
+
topk_label_prob, topk_label_idx = probs.topk(k=topk)
|
33 |
+
topk_label = [
|
34 |
+
BEATs_label_dict[label_idx.item()] for label_idx in topk_label_idx
|
35 |
+
]
|
36 |
+
topk_label_name = [id_name_dict[label] for label in topk_label]
|
37 |
+
top1_label = topk_label_name[0]
|
38 |
+
return top1_label
|
39 |
+
|
40 |
+
|
41 |
+
def process_sentence(args,
|
42 |
+
fp: Path,
|
43 |
+
train_dump_dir: Path,
|
44 |
+
dev_dump_dir: Path,
|
45 |
+
test_dump_dir: Path,
|
46 |
+
VAD_dict,
|
47 |
+
BEATs_model,
|
48 |
+
BEATs_label_dict,
|
49 |
+
device: str='cpu'):
|
50 |
+
utt_id = fp.stem
|
51 |
+
sr = args.sr
|
52 |
+
record = []
|
53 |
+
train_audio_tag_dir = train_dump_dir / "audio_tag"
|
54 |
+
train_audio_tag_dir.mkdir(parents=True, exist_ok=True)
|
55 |
+
|
56 |
+
dev_audio_tag_dir = dev_dump_dir / "audio_tag"
|
57 |
+
dev_audio_tag_dir.mkdir(parents=True, exist_ok=True)
|
58 |
+
|
59 |
+
test_audio_tag_dir = test_dump_dir / "audio_tag"
|
60 |
+
test_audio_tag_dir.mkdir(parents=True, exist_ok=True)
|
61 |
+
|
62 |
+
try:
|
63 |
+
# get info for path
|
64 |
+
wav_path_list = str(fp).strip().split('/')
|
65 |
+
sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
|
66 |
+
-3], wav_path_list[-2]
|
67 |
+
wav_name = wav_path_list[-1][:-5]
|
68 |
+
assert wav_name == utt_id
|
69 |
+
# key_name for big wav
|
70 |
+
key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
|
71 |
+
# 判断 VAD 字典中不存在该条音频信息的情况
|
72 |
+
if key_name not in VAD_dict.keys():
|
73 |
+
print(key_name, 'not in VAD_dict !')
|
74 |
+
return record
|
75 |
+
wav = None
|
76 |
+
sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
|
77 |
+
len_dict = len(sorted_split_VAD_dict)
|
78 |
+
for index, item in enumerate(sorted_split_VAD_dict):
|
79 |
+
split_name, value = item
|
80 |
+
start, end = value
|
81 |
+
# train | dev | test
|
82 |
+
if index == len_dict - 1:
|
83 |
+
subset = 'test'
|
84 |
+
audio_tag_path = test_audio_tag_dir / (split_name + ".txt")
|
85 |
+
elif index == len_dict - 2:
|
86 |
+
subset = 'dev'
|
87 |
+
audio_tag_path = dev_audio_tag_dir / (split_name + ".txt")
|
88 |
+
else:
|
89 |
+
subset = 'train'
|
90 |
+
audio_tag_path = train_audio_tag_dir / (split_name + ".txt")
|
91 |
+
|
92 |
+
if os.path.exists(audio_tag_path) and check_txt_file(
|
93 |
+
audio_tag_path):
|
94 |
+
# print(audio_tag_path, 'exits!')
|
95 |
+
pass
|
96 |
+
else:
|
97 |
+
# 这里加判断保证在 sub wav 的循环中只 load 一次
|
98 |
+
if wav is None:
|
99 |
+
# load big wav
|
100 |
+
# 在最外层 load 如果 sub wav 的特征都存在了就会白白消耗 load 的时间
|
101 |
+
wav, _ = librosa.load(str(fp), sr=sr)
|
102 |
+
sub_wav = wav[int(start * sr):int(end * sr)]
|
103 |
+
audio_tag_top1 = get_BEATs_top1(
|
104 |
+
wav=sub_wav,
|
105 |
+
BEATs_model=BEATs_model,
|
106 |
+
BEATs_label_dict=BEATs_label_dict,
|
107 |
+
device=device)
|
108 |
+
|
109 |
+
with open(audio_tag_path, 'w') as f:
|
110 |
+
f.write(audio_tag_top1)
|
111 |
+
|
112 |
+
sub_record = {
|
113 |
+
"utt_id": split_name,
|
114 |
+
"audio_tag_path": audio_tag_path,
|
115 |
+
"subset": subset
|
116 |
+
}
|
117 |
+
# recodrd 变成 List of Dict
|
118 |
+
record.append(sub_record)
|
119 |
+
except Exception:
|
120 |
+
print("occur Exception")
|
121 |
+
traceback.print_exc()
|
122 |
+
# record 有可能是一个不完整的 List
|
123 |
+
return record
|
124 |
+
return record
|
125 |
+
|
126 |
+
|
127 |
+
def process_sentences(args,
|
128 |
+
fps: Path,
|
129 |
+
train_dump_dir: Path,
|
130 |
+
dev_dump_dir: Path,
|
131 |
+
test_dump_dir: Path,
|
132 |
+
VAD_dict,
|
133 |
+
BEATs_model,
|
134 |
+
BEATs_label_dict,
|
135 |
+
device: str='cpu',
|
136 |
+
nprocs: int=1):
|
137 |
+
print("nprocs:", nprocs)
|
138 |
+
if nprocs == 1:
|
139 |
+
results = []
|
140 |
+
for fp in tqdm.tqdm(fps, total=len(fps)):
|
141 |
+
record = process_sentence(
|
142 |
+
args=args,
|
143 |
+
fp=fp,
|
144 |
+
train_dump_dir=train_dump_dir,
|
145 |
+
dev_dump_dir=dev_dump_dir,
|
146 |
+
test_dump_dir=test_dump_dir,
|
147 |
+
VAD_dict=VAD_dict,
|
148 |
+
BEATs_model=BEATs_model,
|
149 |
+
BEATs_label_dict=BEATs_label_dict,
|
150 |
+
device=device)
|
151 |
+
if record:
|
152 |
+
results.append(record)
|
153 |
+
else:
|
154 |
+
with ThreadPoolExecutor(nprocs) as pool:
|
155 |
+
futures = []
|
156 |
+
with tqdm.tqdm(total=len(fps)) as progress:
|
157 |
+
for fp in fps:
|
158 |
+
future = pool.submit(process_sentence, args, fp,
|
159 |
+
train_dump_dir, dev_dump_dir,
|
160 |
+
test_dump_dir, VAD_dict, BEATs_model,
|
161 |
+
BEATs_label_dict, device)
|
162 |
+
future.add_done_callback(lambda p: progress.update())
|
163 |
+
futures.append(future)
|
164 |
+
|
165 |
+
results = []
|
166 |
+
for ft in futures:
|
167 |
+
record = ft.result()
|
168 |
+
if record:
|
169 |
+
results.append(record)
|
170 |
+
|
171 |
+
# torch.save() to a large `.pth` file
|
172 |
+
non_speech_dict = dict()
|
173 |
+
non_speech_dict['train'] = {}
|
174 |
+
non_speech_dict['dev'] = {}
|
175 |
+
non_speech_dict['test'] = {}
|
176 |
+
# record 是 List of Dict, 一条大 wav 一个 record,一条小 wav 一个 sub_recored
|
177 |
+
print(f"start to save {args.rank}_{args.nshard}.npy ...")
|
178 |
+
save_start_time = time.time()
|
179 |
+
for record in tqdm.tqdm(results, total=len(results), colour='green'):
|
180 |
+
for sub_record in record:
|
181 |
+
# 这里加 try, 因为 txt 文件可能损坏
|
182 |
+
try:
|
183 |
+
utt_id = sub_record["utt_id"]
|
184 |
+
subset = sub_record["subset"]
|
185 |
+
audio_tag_top1 = check_txt_file(sub_record["audio_tag_path"])
|
186 |
+
if audio_tag_top1 is not False:
|
187 |
+
if 'speech' not in audio_tag_top1.lower():
|
188 |
+
non_speech_dict[subset][utt_id] = audio_tag_top1
|
189 |
+
else:
|
190 |
+
# print(f'audio tag result of {utt_id} is speech')
|
191 |
+
pass
|
192 |
+
else:
|
193 |
+
print(f'audio tag result of {utt_id} is False')
|
194 |
+
except Exception:
|
195 |
+
print(f"{utt_id} occur Exception")
|
196 |
+
traceback.print_exc()
|
197 |
+
continue
|
198 |
+
|
199 |
+
train_filename = train_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
|
200 |
+
dev_filename = dev_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
|
201 |
+
test_filename = test_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
|
202 |
+
np.save(train_filename, non_speech_dict['train'])
|
203 |
+
print(f"npy file '{train_filename}' write down")
|
204 |
+
|
205 |
+
np.save(dev_filename, non_speech_dict['dev'])
|
206 |
+
print(f"npy file '{dev_filename}' write down")
|
207 |
+
|
208 |
+
np.save(test_filename, non_speech_dict['test'])
|
209 |
+
print(f"npy file '{test_filename}' write down")
|
210 |
+
print('time of save stage:', time.time() - save_start_time)
|
211 |
+
|
212 |
+
|
213 |
+
def main():
|
214 |
+
# parse config and args
|
215 |
+
parser = argparse.ArgumentParser(
|
216 |
+
description="Use AudioTag tool BEATs to filter out audios who's top1 tag is not 'speech'."
|
217 |
+
)
|
218 |
+
|
219 |
+
parser.add_argument(
|
220 |
+
"--data_dir", default=None, type=str, help="directory to dataset.")
|
221 |
+
|
222 |
+
parser.add_argument(
|
223 |
+
"--dump_dir",
|
224 |
+
type=str,
|
225 |
+
required=True,
|
226 |
+
help="directory to dump feature files.")
|
227 |
+
|
228 |
+
parser.add_argument(
|
229 |
+
"--num-cpu", type=int, default=1, help="number of process.")
|
230 |
+
|
231 |
+
parser.add_argument(
|
232 |
+
'--sr', type=int, default=16000, help='sample rate of model')
|
233 |
+
|
234 |
+
# For LibriLight dataset
|
235 |
+
parser.add_argument(
|
236 |
+
"--sub_dataset",
|
237 |
+
default="small",
|
238 |
+
type=str,
|
239 |
+
help="name of sub dataset of LibriLight",
|
240 |
+
choices=['small', 'medium', 'large', 'duplicate'], )
|
241 |
+
parser.add_argument(
|
242 |
+
"--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
|
243 |
+
parser.add_argument("--nshard", type=int, default=3)
|
244 |
+
parser.add_argument("--rank", type=int, default=0)
|
245 |
+
|
246 |
+
# for BEATs
|
247 |
+
parser.add_argument(
|
248 |
+
"--BEATs_ckpt_path",
|
249 |
+
type=str,
|
250 |
+
default='./pretrained_model/BEATs_iter1_finetuned_on_AS2M_cpt1.pt')
|
251 |
+
|
252 |
+
args = parser.parse_args()
|
253 |
+
|
254 |
+
data_dir = Path(args.data_dir).expanduser()
|
255 |
+
dump_dir = Path(args.dump_dir).expanduser()
|
256 |
+
# use absolute path
|
257 |
+
dump_dir = dump_dir.resolve()
|
258 |
+
dump_dir.mkdir(parents=True, exist_ok=True)
|
259 |
+
|
260 |
+
assert data_dir.is_dir()
|
261 |
+
|
262 |
+
# sub_dataset here
|
263 |
+
sub_dataset_dir = data_dir / args.sub_dataset
|
264 |
+
# olny spk_id in list, sort by lexicographical order
|
265 |
+
speaker_list = sorted(os.listdir(sub_dataset_dir))
|
266 |
+
start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
|
267 |
+
# speaker_list for this rank
|
268 |
+
speaker_list = speaker_list[start:end]
|
269 |
+
|
270 |
+
all_wav_files = []
|
271 |
+
|
272 |
+
for speaker in speaker_list:
|
273 |
+
wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
|
274 |
+
# filter out ._*.flac
|
275 |
+
wav_files = [
|
276 |
+
file for file in wav_files if not file.name.startswith('._')
|
277 |
+
]
|
278 |
+
all_wav_files += wav_files
|
279 |
+
|
280 |
+
print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
|
281 |
+
# get VAD info
|
282 |
+
VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
|
283 |
+
|
284 |
+
sub_dataset_dump_dir = dump_dir / args.sub_dataset
|
285 |
+
sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
|
286 |
+
train_dump_dir = sub_dataset_dump_dir / "train"
|
287 |
+
train_dump_dir.mkdir(parents=True, exist_ok=True)
|
288 |
+
dev_dump_dir = sub_dataset_dump_dir / "dev"
|
289 |
+
dev_dump_dir.mkdir(parents=True, exist_ok=True)
|
290 |
+
test_dump_dir = sub_dataset_dump_dir / "test"
|
291 |
+
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
292 |
+
|
293 |
+
BEATs_ckpt = torch.load(args.BEATs_ckpt_path)
|
294 |
+
|
295 |
+
BEATs_cfg = BEATsConfig(BEATs_ckpt['cfg'])
|
296 |
+
BEATs_model = BEATs(BEATs_cfg)
|
297 |
+
BEATs_model.load_state_dict(BEATs_ckpt['model'])
|
298 |
+
BEATs_model.eval()
|
299 |
+
# cpu or cuda
|
300 |
+
device = 'cpu'
|
301 |
+
BEATs_model.to(device)
|
302 |
+
|
303 |
+
BEATs_label_dict = BEATs_ckpt['label_dict']
|
304 |
+
|
305 |
+
# 每条大 wav 分出一个 dev 一个 test,比例大概是 96:2:2
|
306 |
+
if all_wav_files:
|
307 |
+
process_sentences(
|
308 |
+
args=args,
|
309 |
+
fps=all_wav_files,
|
310 |
+
train_dump_dir=train_dump_dir,
|
311 |
+
dev_dump_dir=dev_dump_dir,
|
312 |
+
test_dump_dir=test_dump_dir,
|
313 |
+
VAD_dict=VAD_dict,
|
314 |
+
BEATs_model=BEATs_model,
|
315 |
+
BEATs_label_dict=BEATs_label_dict,
|
316 |
+
device=device,
|
317 |
+
nprocs=args.num_cpu)
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == "__main__":
|
321 |
+
main()
|
AR/exps/get_phones.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
1. read text of dataset
|
3 |
+
2. text -> IPA by GruutPhonemizer
|
4 |
+
3. save out a *.npy dict for all text
|
5 |
+
my_dict = {"utt_id1": text1, "utt_id2": text2}
|
6 |
+
np.save(output_filename, my_dict)
|
7 |
+
my_dict = np.load(output_filename, allow_pickle=True).item()
|
8 |
+
"""
|
9 |
+
import argparse
|
10 |
+
import os
|
11 |
+
from concurrent.futures import ThreadPoolExecutor
|
12 |
+
from operator import itemgetter
|
13 |
+
from pathlib import Path
|
14 |
+
from typing import List
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import tqdm
|
18 |
+
from AR.text_processing.phonemizer import GruutPhonemizer
|
19 |
+
|
20 |
+
|
21 |
+
def read_txt(txt_file):
|
22 |
+
utt_name = txt_file.stem
|
23 |
+
utt_id = utt_name.split('.')[0]
|
24 |
+
try:
|
25 |
+
with open(txt_file, 'r') as file:
|
26 |
+
txt = file.readline()
|
27 |
+
record = {"utt_id": utt_id, "txt": txt}
|
28 |
+
except Exception:
|
29 |
+
print("occur Exception")
|
30 |
+
traceback.print_exc()
|
31 |
+
return None
|
32 |
+
return record
|
33 |
+
|
34 |
+
|
35 |
+
def read_txts(txt_files: List[Path], nprocs: int=1):
|
36 |
+
if nprocs == 1:
|
37 |
+
results = []
|
38 |
+
for txt_file in tqdm.tqdm(txt_files, total=len(txt_files)):
|
39 |
+
record = read_txt(txt_file=txt_file)
|
40 |
+
if record:
|
41 |
+
results.append(record)
|
42 |
+
else:
|
43 |
+
with ThreadPoolExecutor(nprocs) as pool:
|
44 |
+
futures = []
|
45 |
+
with tqdm.tqdm(total=len(txt_files)) as progress:
|
46 |
+
for txt_file in txt_files:
|
47 |
+
future = pool.submit(read_txt, txt_file)
|
48 |
+
future.add_done_callback(lambda p: progress.update())
|
49 |
+
futures.append(future)
|
50 |
+
|
51 |
+
results = []
|
52 |
+
for ft in futures:
|
53 |
+
record = ft.result()
|
54 |
+
if record:
|
55 |
+
results.append(record)
|
56 |
+
|
57 |
+
results.sort(key=itemgetter("utt_id"))
|
58 |
+
return_list = []
|
59 |
+
for item in results:
|
60 |
+
return_list.append((item["utt_id"], item["txt"]))
|
61 |
+
return return_list
|
62 |
+
|
63 |
+
|
64 |
+
def process_sentence(item, phonemizer):
|
65 |
+
utt_id, text = item
|
66 |
+
try:
|
67 |
+
phonemes = phonemizer.phonemize(text, espeak=False)
|
68 |
+
record = {"utt_id": utt_id, "phonemes": phonemes}
|
69 |
+
except Exception:
|
70 |
+
print("occur Exception")
|
71 |
+
traceback.print_exc()
|
72 |
+
return None
|
73 |
+
return record
|
74 |
+
|
75 |
+
|
76 |
+
def process_sentences(items, phonemizer, output_dir, nprocs: int=1):
|
77 |
+
if nprocs == 1:
|
78 |
+
results = []
|
79 |
+
for item in tqdm.tqdm(items, total=len(items)):
|
80 |
+
record = process_sentence(item=item, phonemizer=phonemizer)
|
81 |
+
if record:
|
82 |
+
results.append(record)
|
83 |
+
else:
|
84 |
+
with ThreadPoolExecutor(nprocs) as pool:
|
85 |
+
futures = []
|
86 |
+
with tqdm.tqdm(total=len(items)) as progress:
|
87 |
+
for item in items:
|
88 |
+
future = pool.submit(process_sentence, item, phonemizer)
|
89 |
+
future.add_done_callback(lambda p: progress.update())
|
90 |
+
futures.append(future)
|
91 |
+
|
92 |
+
results = []
|
93 |
+
for ft in futures:
|
94 |
+
record = ft.result()
|
95 |
+
if record:
|
96 |
+
results.append(record)
|
97 |
+
results.sort(key=itemgetter("utt_id"))
|
98 |
+
npy_dict = {}
|
99 |
+
for item in results:
|
100 |
+
utt_id = item["utt_id"]
|
101 |
+
phonemes = item["phonemes"]
|
102 |
+
npy_dict[utt_id] = phonemes
|
103 |
+
filename = output_dir / 'phonemes.npy'
|
104 |
+
np.save(filename, npy_dict)
|
105 |
+
print(f"npy file '{filename}' write down")
|
106 |
+
|
107 |
+
|
108 |
+
def main():
|
109 |
+
# parse config and args
|
110 |
+
parser = argparse.ArgumentParser(description="Get phones for datasets")
|
111 |
+
|
112 |
+
parser.add_argument(
|
113 |
+
"--dataset",
|
114 |
+
default="ljspeech",
|
115 |
+
type=str,
|
116 |
+
help="name of dataset, should in {ljspeech, libritts} now")
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--data_dir", default=None, type=str, help="directory to dataset.")
|
120 |
+
|
121 |
+
parser.add_argument(
|
122 |
+
"--dump_dir",
|
123 |
+
type=str,
|
124 |
+
required=True,
|
125 |
+
help="directory to dump feature files.")
|
126 |
+
parser.add_argument(
|
127 |
+
"--num-cpu", type=int, default=1, help="number of process.")
|
128 |
+
|
129 |
+
args = parser.parse_args()
|
130 |
+
|
131 |
+
data_dir = Path(args.data_dir).expanduser()
|
132 |
+
dump_dir = Path(args.dump_dir).expanduser()
|
133 |
+
# use absolute path
|
134 |
+
dump_dir = dump_dir.resolve()
|
135 |
+
dump_dir.mkdir(parents=True, exist_ok=True)
|
136 |
+
|
137 |
+
assert data_dir.is_dir()
|
138 |
+
|
139 |
+
if args.dataset == "ljspeech":
|
140 |
+
data_dict = {}
|
141 |
+
text_path = data_dir / 'metadata.csv'
|
142 |
+
with open(text_path, 'r') as rf:
|
143 |
+
for line in rf:
|
144 |
+
line_list = line.strip().split('|')
|
145 |
+
utt_id = line_list[0]
|
146 |
+
raw_text = line_list[-1]
|
147 |
+
data_dict[utt_id] = raw_text
|
148 |
+
|
149 |
+
sorted_dict = sorted(data_dict.items())
|
150 |
+
|
151 |
+
num_train = 12900
|
152 |
+
num_dev = 100
|
153 |
+
# (utt_id, txt)
|
154 |
+
train_txts = sorted_dict[:num_train]
|
155 |
+
dev_txts = sorted_dict[num_train:num_train + num_dev]
|
156 |
+
test_txts = sorted_dict[num_train + num_dev:]
|
157 |
+
|
158 |
+
elif args.dataset == "libritts":
|
159 |
+
'''
|
160 |
+
we use train-clean-100、train-clean-360、train-other-500 here
|
161 |
+
and split dev and test from them, don't use test-* and dev-* cause the speakers are disjoint
|
162 |
+
the file structure is LibriTTS_R/train-clean-100/spkid/*/*.wav
|
163 |
+
there are about 2311 in these subsets, we split 1 dev and 1 test wav out from each speaker
|
164 |
+
'''
|
165 |
+
txt_files = []
|
166 |
+
train_txt_files = []
|
167 |
+
dev_txt_files = []
|
168 |
+
test_txt_files = []
|
169 |
+
sub_num_dev = 1
|
170 |
+
for sub_dataset_name in {
|
171 |
+
"train-clean-100", "train-clean-360", "train-other-500"
|
172 |
+
}:
|
173 |
+
sub_dataset_dir = data_dir / sub_dataset_name
|
174 |
+
# filter out hidden files
|
175 |
+
speaker_list = [
|
176 |
+
file for file in os.listdir(sub_dataset_dir)
|
177 |
+
if not file.startswith('.')
|
178 |
+
]
|
179 |
+
for speaker in speaker_list:
|
180 |
+
txt_files = sorted(
|
181 |
+
list((sub_dataset_dir / speaker).rglob(
|
182 |
+
"*/*.normalized.txt")))
|
183 |
+
# filter out ._*.wav
|
184 |
+
txt_files = [
|
185 |
+
file for file in txt_files if not file.name.startswith('._')
|
186 |
+
]
|
187 |
+
train_txt_files += txt_files[:-sub_num_dev * 2]
|
188 |
+
dev_txt_files += txt_files[-sub_num_dev * 2:-sub_num_dev]
|
189 |
+
test_txt_files += txt_files[-sub_num_dev:]
|
190 |
+
print("len(train_txt_files):", len(train_txt_files))
|
191 |
+
print("len(dev_txt_files):", len(dev_txt_files))
|
192 |
+
print("len(test_txt_files):", len(test_txt_files))
|
193 |
+
|
194 |
+
train_txts = read_txts(train_txt_files)
|
195 |
+
dev_txts = read_txts(dev_txt_files)
|
196 |
+
test_txts = read_txts(test_txt_files)
|
197 |
+
|
198 |
+
else:
|
199 |
+
print("dataset should in {ljspeech, libritts} now!")
|
200 |
+
|
201 |
+
train_dump_dir = dump_dir / "train"
|
202 |
+
train_dump_dir.mkdir(parents=True, exist_ok=True)
|
203 |
+
dev_dump_dir = dump_dir / "dev"
|
204 |
+
dev_dump_dir.mkdir(parents=True, exist_ok=True)
|
205 |
+
test_dump_dir = dump_dir / "test"
|
206 |
+
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
207 |
+
|
208 |
+
phonemizer = GruutPhonemizer(language='en-us')
|
209 |
+
|
210 |
+
# process for the 3 sections
|
211 |
+
if train_txts:
|
212 |
+
process_sentences(
|
213 |
+
items=train_txts,
|
214 |
+
output_dir=train_dump_dir,
|
215 |
+
phonemizer=phonemizer,
|
216 |
+
nprocs=args.num_cpu)
|
217 |
+
if dev_txts:
|
218 |
+
process_sentences(
|
219 |
+
items=dev_txts,
|
220 |
+
output_dir=dev_dump_dir,
|
221 |
+
phonemizer=phonemizer,
|
222 |
+
nprocs=args.num_cpu)
|
223 |
+
if test_txts:
|
224 |
+
process_sentences(
|
225 |
+
items=test_txts,
|
226 |
+
output_dir=test_dump_dir,
|
227 |
+
phonemizer=phonemizer,
|
228 |
+
nprocs=args.num_cpu)
|
229 |
+
|
230 |
+
|
231 |
+
if __name__ == "__main__":
|
232 |
+
main()
|
AR/exps/get_phones_librilight.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
1. read text of dataset, for LibriLight read txt_*.npy -> 需要整理成 list(utt_id, txt) 的形式
|
3 |
+
2. text -> IPA by GruutPhonemizer
|
4 |
+
3. save out a *.npy dict for all text
|
5 |
+
4. LibriLight 每个 split 分开处理
|
6 |
+
my_dict = {"utt_id1": text1, "utt_id2": text2}
|
7 |
+
np.save(output_filename, my_dict)
|
8 |
+
my_dict = np.load(output_filename, allow_pickle=True).item()
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import traceback
|
14 |
+
from concurrent.futures import ThreadPoolExecutor
|
15 |
+
from operator import itemgetter
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import tqdm
|
20 |
+
from AR.text_processing.phonemizer import GruutPhonemizer
|
21 |
+
from soundstorm.utils import check_txt_file
|
22 |
+
|
23 |
+
|
24 |
+
def read_txts(txt_file: Path, nprocs: int=1):
|
25 |
+
'''
|
26 |
+
txt_file: path of npy dict, {"utt_id1": text1, "utt_id2": text2}
|
27 |
+
'''
|
28 |
+
txt_dict = np.load(txt_file, allow_pickle=True).item()
|
29 |
+
#[(utt_id, txt), ...]
|
30 |
+
return_list = list(txt_dict.items())
|
31 |
+
return return_list
|
32 |
+
|
33 |
+
|
34 |
+
def process_sentence(item, phonemizer, output_dir):
|
35 |
+
utt_id, text = item
|
36 |
+
phonemes_dir = output_dir / "phonemes"
|
37 |
+
phonemes_dir.mkdir(parents=True, exist_ok=True)
|
38 |
+
phonemes_path = phonemes_dir / (utt_id + ".txt")
|
39 |
+
try:
|
40 |
+
if os.path.exists(phonemes_path) and check_txt_file(phonemes_path):
|
41 |
+
# print(phonemes_path, 'exits!')
|
42 |
+
pass
|
43 |
+
else:
|
44 |
+
phonemes = phonemizer.phonemize(text, espeak=False)
|
45 |
+
with open(phonemes_path, 'w') as f:
|
46 |
+
f.write(phonemes)
|
47 |
+
record = {"utt_id": utt_id, "phonemes_path": phonemes_path}
|
48 |
+
except Exception:
|
49 |
+
print("occur Exception")
|
50 |
+
traceback.print_exc()
|
51 |
+
return None
|
52 |
+
return record
|
53 |
+
|
54 |
+
|
55 |
+
def process_sentences(args, items, phonemizer, output_dir, nprocs: int=1):
|
56 |
+
print("nprocs:", nprocs)
|
57 |
+
if nprocs == 1:
|
58 |
+
results = []
|
59 |
+
for item in tqdm.tqdm(items, total=len(items)):
|
60 |
+
record = process_sentence(
|
61 |
+
item=item, phonemizer=phonemizer, output_dir=output_dir)
|
62 |
+
if record:
|
63 |
+
results.append(record)
|
64 |
+
else:
|
65 |
+
with ThreadPoolExecutor(nprocs) as pool:
|
66 |
+
futures = []
|
67 |
+
with tqdm.tqdm(total=len(items)) as progress:
|
68 |
+
for item in items:
|
69 |
+
future = pool.submit(process_sentence, item, phonemizer,
|
70 |
+
output_dir)
|
71 |
+
future.add_done_callback(lambda p: progress.update())
|
72 |
+
futures.append(future)
|
73 |
+
|
74 |
+
results = []
|
75 |
+
for ft in futures:
|
76 |
+
record = ft.result()
|
77 |
+
if record:
|
78 |
+
results.append(record)
|
79 |
+
|
80 |
+
results.sort(key=itemgetter("utt_id"))
|
81 |
+
|
82 |
+
npy_dict = {}
|
83 |
+
print(f"start to save {args.rank}_{args.nshard}.npy ...")
|
84 |
+
save_start_time = time.time()
|
85 |
+
for item in tqdm.tqdm(results, total=len(results), colour='green'):
|
86 |
+
# 这里加 try, 因为 txt 文件可能损坏
|
87 |
+
try:
|
88 |
+
utt_id = item["utt_id"]
|
89 |
+
phonemes = check_txt_file(item["phonemes_path"])
|
90 |
+
if phonemes is not False:
|
91 |
+
npy_dict[utt_id] = phonemes
|
92 |
+
else:
|
93 |
+
print(f'phonemes of {utt_id} is False')
|
94 |
+
except Exception:
|
95 |
+
print(f"{utt_id} occur Exception")
|
96 |
+
traceback.print_exc()
|
97 |
+
continue
|
98 |
+
|
99 |
+
filename = output_dir / f'phonemes_{args.rank}_{args.nshard}.npy'
|
100 |
+
np.save(filename, npy_dict)
|
101 |
+
print(f"npy file '{filename}' write down")
|
102 |
+
print('time of save stage:', time.time() - save_start_time)
|
103 |
+
|
104 |
+
|
105 |
+
def main():
|
106 |
+
# parse config and args
|
107 |
+
parser = argparse.ArgumentParser(
|
108 |
+
description="Get phones for LibriLight dataset from txt_*.npy")
|
109 |
+
|
110 |
+
parser.add_argument(
|
111 |
+
"--dump_dir",
|
112 |
+
type=str,
|
113 |
+
required=True,
|
114 |
+
help="directory to dump feature files.")
|
115 |
+
parser.add_argument(
|
116 |
+
"--num-cpu", type=int, default=1, help="number of process.")
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
'--train_txt_dir',
|
120 |
+
type=str,
|
121 |
+
default='dump/small/train/',
|
122 |
+
help='dir of train txt files')
|
123 |
+
parser.add_argument(
|
124 |
+
'--dev_txt_dir',
|
125 |
+
type=str,
|
126 |
+
default='dump/small/dev/',
|
127 |
+
help='dir of dev txt files')
|
128 |
+
parser.add_argument(
|
129 |
+
'--test_txt_dir',
|
130 |
+
type=str,
|
131 |
+
default='dump/small/test/',
|
132 |
+
help='dir of test txt files')
|
133 |
+
|
134 |
+
parser.add_argument(
|
135 |
+
"--sub_dataset",
|
136 |
+
default="small",
|
137 |
+
type=str,
|
138 |
+
help="name of sub dataset of LibriLight",
|
139 |
+
choices=['small', 'medium', 'large', 'duplicate'], )
|
140 |
+
parser.add_argument("--nshard", type=int, default=3)
|
141 |
+
parser.add_argument("--rank", type=int, default=0)
|
142 |
+
|
143 |
+
args = parser.parse_args()
|
144 |
+
print(f"nshard: {args.nshard}, rank: {args.rank}")
|
145 |
+
|
146 |
+
train_txt_dir = Path(args.train_txt_dir)
|
147 |
+
dev_txt_dir = Path(args.dev_txt_dir)
|
148 |
+
test_txt_dir = Path(args.test_txt_dir)
|
149 |
+
|
150 |
+
dump_dir = Path(args.dump_dir).expanduser()
|
151 |
+
# use absolute path
|
152 |
+
dump_dir = dump_dir.resolve()
|
153 |
+
dump_dir.mkdir(parents=True, exist_ok=True)
|
154 |
+
|
155 |
+
train_txt_file = train_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
|
156 |
+
dev_txt_file = dev_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
|
157 |
+
test_txt_file = test_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
|
158 |
+
|
159 |
+
train_txts = read_txts(train_txt_file)
|
160 |
+
dev_txts = read_txts(dev_txt_file)
|
161 |
+
test_txts = read_txts(test_txt_file)
|
162 |
+
|
163 |
+
sub_dataset_dump_dir = dump_dir / args.sub_dataset
|
164 |
+
sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
|
165 |
+
train_dump_dir = sub_dataset_dump_dir / "train"
|
166 |
+
train_dump_dir.mkdir(parents=True, exist_ok=True)
|
167 |
+
dev_dump_dir = sub_dataset_dump_dir / "dev"
|
168 |
+
dev_dump_dir.mkdir(parents=True, exist_ok=True)
|
169 |
+
test_dump_dir = sub_dataset_dump_dir / "test"
|
170 |
+
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
171 |
+
phonemizer = GruutPhonemizer(language='en-us')
|
172 |
+
|
173 |
+
# process for the 3 sections
|
174 |
+
if train_txts:
|
175 |
+
process_sentences(
|
176 |
+
args=args,
|
177 |
+
items=train_txts,
|
178 |
+
output_dir=train_dump_dir,
|
179 |
+
phonemizer=phonemizer,
|
180 |
+
nprocs=args.num_cpu)
|
181 |
+
if dev_txts:
|
182 |
+
process_sentences(
|
183 |
+
args=args,
|
184 |
+
items=dev_txts,
|
185 |
+
output_dir=dev_dump_dir,
|
186 |
+
phonemizer=phonemizer,
|
187 |
+
nprocs=args.num_cpu)
|
188 |
+
if test_txts:
|
189 |
+
process_sentences(
|
190 |
+
args=args,
|
191 |
+
items=test_txts,
|
192 |
+
output_dir=test_dump_dir,
|
193 |
+
phonemizer=phonemizer,
|
194 |
+
nprocs=args.num_cpu)
|
195 |
+
|
196 |
+
|
197 |
+
if __name__ == "__main__":
|
198 |
+
main()
|
AR/exps/get_txt_librilight.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import traceback
|
5 |
+
from concurrent.futures import ThreadPoolExecutor
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
import tqdm
|
11 |
+
import whisper
|
12 |
+
from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
|
13 |
+
from soundstorm.utils import check_txt_file
|
14 |
+
|
15 |
+
|
16 |
+
def process_sentence(args,
|
17 |
+
fp: Path,
|
18 |
+
train_dump_dir: Path,
|
19 |
+
dev_dump_dir: Path,
|
20 |
+
test_dump_dir: Path,
|
21 |
+
VAD_dict):
|
22 |
+
asr_model = whisper.load_model("tiny.en")
|
23 |
+
utt_id = fp.stem
|
24 |
+
sr = args.sr
|
25 |
+
record = []
|
26 |
+
train_txt_dir = train_dump_dir / "txt"
|
27 |
+
train_txt_dir.mkdir(parents=True, exist_ok=True)
|
28 |
+
|
29 |
+
dev_txt_dir = dev_dump_dir / "txt"
|
30 |
+
dev_txt_dir.mkdir(parents=True, exist_ok=True)
|
31 |
+
|
32 |
+
test_txt_dir = test_dump_dir / "txt"
|
33 |
+
test_txt_dir.mkdir(parents=True, exist_ok=True)
|
34 |
+
|
35 |
+
try:
|
36 |
+
# get info for path
|
37 |
+
wav_path_list = str(fp).strip().split('/')
|
38 |
+
sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
|
39 |
+
-3], wav_path_list[-2]
|
40 |
+
wav_name = wav_path_list[-1][:-5]
|
41 |
+
assert wav_name == utt_id
|
42 |
+
# key_name for big wav
|
43 |
+
key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
|
44 |
+
# 判断 VAD 字典中不存在该条音频信息的情况
|
45 |
+
if key_name not in VAD_dict.keys():
|
46 |
+
print(key_name, 'not in VAD_dict !')
|
47 |
+
return record
|
48 |
+
wav = None
|
49 |
+
sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
|
50 |
+
len_dict = len(sorted_split_VAD_dict)
|
51 |
+
for index, item in enumerate(sorted_split_VAD_dict):
|
52 |
+
split_name, value = item
|
53 |
+
start, end = value
|
54 |
+
# train | dev | test
|
55 |
+
if index == len_dict - 1:
|
56 |
+
subset = 'test'
|
57 |
+
txt_path = test_txt_dir / (split_name + ".txt")
|
58 |
+
elif index == len_dict - 2:
|
59 |
+
subset = 'dev'
|
60 |
+
txt_path = dev_txt_dir / (split_name + ".txt")
|
61 |
+
else:
|
62 |
+
subset = 'train'
|
63 |
+
txt_path = train_txt_dir / (split_name + ".txt")
|
64 |
+
|
65 |
+
if os.path.exists(txt_path) and check_txt_file(txt_path):
|
66 |
+
# print(txt_path, 'exits!')
|
67 |
+
pass
|
68 |
+
else:
|
69 |
+
# 这里加判断保证在 sub wav 的循环中只 load 一次
|
70 |
+
if wav is None:
|
71 |
+
# load big wav
|
72 |
+
# 在最外层 load 如果 sub wav 的特征都存在了就会白白消耗 load 的时间
|
73 |
+
wav, _ = librosa.load(str(fp), sr=sr)
|
74 |
+
sub_wav = wav[int(start * sr):int(end * sr)]
|
75 |
+
asr_result = asr_model.transcribe(sub_wav)["text"]
|
76 |
+
with open(txt_path, 'w') as f:
|
77 |
+
f.write(asr_result)
|
78 |
+
|
79 |
+
sub_record = {
|
80 |
+
"utt_id": split_name,
|
81 |
+
"txt_path": txt_path,
|
82 |
+
"subset": subset
|
83 |
+
}
|
84 |
+
# recodrd 变成 List of Dict
|
85 |
+
record.append(sub_record)
|
86 |
+
except Exception:
|
87 |
+
print("occur Exception")
|
88 |
+
traceback.print_exc()
|
89 |
+
# record 有可能是一个不完整的 List
|
90 |
+
return record
|
91 |
+
return record
|
92 |
+
|
93 |
+
|
94 |
+
def process_sentences(args,
|
95 |
+
fps: Path,
|
96 |
+
train_dump_dir: Path,
|
97 |
+
dev_dump_dir: Path,
|
98 |
+
test_dump_dir: Path,
|
99 |
+
VAD_dict,
|
100 |
+
nprocs: int=1):
|
101 |
+
print("nprocs:", nprocs)
|
102 |
+
if nprocs == 1:
|
103 |
+
results = []
|
104 |
+
for fp in tqdm.tqdm(fps, total=len(fps)):
|
105 |
+
record = process_sentence(
|
106 |
+
args=args,
|
107 |
+
fp=fp,
|
108 |
+
train_dump_dir=train_dump_dir,
|
109 |
+
dev_dump_dir=dev_dump_dir,
|
110 |
+
test_dump_dir=test_dump_dir,
|
111 |
+
VAD_dict=VAD_dict)
|
112 |
+
if record:
|
113 |
+
results.append(record)
|
114 |
+
else:
|
115 |
+
with ThreadPoolExecutor(nprocs) as pool:
|
116 |
+
futures = []
|
117 |
+
with tqdm.tqdm(total=len(fps)) as progress:
|
118 |
+
for fp in fps:
|
119 |
+
future = pool.submit(process_sentence, args, fp,
|
120 |
+
train_dump_dir, dev_dump_dir,
|
121 |
+
test_dump_dir, VAD_dict)
|
122 |
+
future.add_done_callback(lambda p: progress.update())
|
123 |
+
futures.append(future)
|
124 |
+
|
125 |
+
results = []
|
126 |
+
for ft in futures:
|
127 |
+
record = ft.result()
|
128 |
+
if record:
|
129 |
+
results.append(record)
|
130 |
+
|
131 |
+
# torch.save() to a large `.pth` file
|
132 |
+
txt_dict = dict()
|
133 |
+
txt_dict['train'] = {}
|
134 |
+
txt_dict['dev'] = {}
|
135 |
+
txt_dict['test'] = {}
|
136 |
+
# record 是 List of Dict, 一条大 wav 一个 record,一条小 wav 一个 sub_recored
|
137 |
+
print(f"start to save {args.rank}_{args.nshard}.npy ...")
|
138 |
+
save_start_time = time.time()
|
139 |
+
for record in tqdm.tqdm(results, total=len(results), colour='green'):
|
140 |
+
for sub_record in record:
|
141 |
+
# 这里加 try, 因为 txt 文件可能损坏
|
142 |
+
try:
|
143 |
+
utt_id = sub_record["utt_id"]
|
144 |
+
subset = sub_record["subset"]
|
145 |
+
asr_result = check_txt_file(sub_record["txt_path"])
|
146 |
+
if asr_result is not False:
|
147 |
+
txt_dict[subset][utt_id] = asr_result
|
148 |
+
else:
|
149 |
+
print(f'asr result of {utt_id} is False')
|
150 |
+
except Exception:
|
151 |
+
print(f"{utt_id} occur Exception")
|
152 |
+
traceback.print_exc()
|
153 |
+
continue
|
154 |
+
|
155 |
+
train_filename = train_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
|
156 |
+
dev_filename = dev_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
|
157 |
+
test_filename = test_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
|
158 |
+
np.save(train_filename, txt_dict['train'])
|
159 |
+
print(f"npy file '{train_filename}' write down")
|
160 |
+
|
161 |
+
np.save(dev_filename, txt_dict['dev'])
|
162 |
+
print(f"npy file '{dev_filename}' write down")
|
163 |
+
|
164 |
+
np.save(test_filename, txt_dict['test'])
|
165 |
+
print(f"npy file '{test_filename}' write down")
|
166 |
+
print('time of save stage:', time.time() - save_start_time)
|
167 |
+
|
168 |
+
|
169 |
+
def main():
|
170 |
+
# parse config and args
|
171 |
+
parser = argparse.ArgumentParser(
|
172 |
+
description="Preprocess audio and then extract features for LibriLight.")
|
173 |
+
|
174 |
+
parser.add_argument(
|
175 |
+
"--data_dir", default=None, type=str, help="directory to dataset.")
|
176 |
+
|
177 |
+
parser.add_argument(
|
178 |
+
"--dump_dir",
|
179 |
+
type=str,
|
180 |
+
required=True,
|
181 |
+
help="directory to dump feature files.")
|
182 |
+
|
183 |
+
parser.add_argument(
|
184 |
+
"--num-cpu", type=int, default=1, help="number of process.")
|
185 |
+
|
186 |
+
parser.add_argument(
|
187 |
+
'--sr', type=int, default=16000, help='sample rate of model')
|
188 |
+
|
189 |
+
# For LibriLight dataset
|
190 |
+
parser.add_argument(
|
191 |
+
"--sub_dataset",
|
192 |
+
default="small",
|
193 |
+
type=str,
|
194 |
+
help="name of sub dataset of LibriLight",
|
195 |
+
choices=['small', 'medium', 'large', 'duplicate'], )
|
196 |
+
parser.add_argument(
|
197 |
+
"--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
|
198 |
+
parser.add_argument("--nshard", type=int, default=3)
|
199 |
+
parser.add_argument("--rank", type=int, default=0)
|
200 |
+
|
201 |
+
args = parser.parse_args()
|
202 |
+
|
203 |
+
data_dir = Path(args.data_dir).expanduser()
|
204 |
+
dump_dir = Path(args.dump_dir).expanduser()
|
205 |
+
# use absolute path
|
206 |
+
dump_dir = dump_dir.resolve()
|
207 |
+
dump_dir.mkdir(parents=True, exist_ok=True)
|
208 |
+
|
209 |
+
assert data_dir.is_dir()
|
210 |
+
|
211 |
+
# sub_dataset here
|
212 |
+
sub_dataset_dir = data_dir / args.sub_dataset
|
213 |
+
# olny spk_id in list, sort by lexicographical order
|
214 |
+
speaker_list = sorted(os.listdir(sub_dataset_dir))
|
215 |
+
start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
|
216 |
+
# speaker_list for this rank
|
217 |
+
speaker_list = speaker_list[start:end]
|
218 |
+
|
219 |
+
all_wav_files = []
|
220 |
+
|
221 |
+
for speaker in speaker_list:
|
222 |
+
wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
|
223 |
+
# filter out ._*.flac
|
224 |
+
wav_files = [
|
225 |
+
file for file in wav_files if not file.name.startswith('._')
|
226 |
+
]
|
227 |
+
all_wav_files += wav_files
|
228 |
+
|
229 |
+
print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
|
230 |
+
# get VAD info
|
231 |
+
VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
|
232 |
+
|
233 |
+
sub_dataset_dump_dir = dump_dir / args.sub_dataset
|
234 |
+
sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
|
235 |
+
train_dump_dir = sub_dataset_dump_dir / "train"
|
236 |
+
train_dump_dir.mkdir(parents=True, exist_ok=True)
|
237 |
+
dev_dump_dir = sub_dataset_dump_dir / "dev"
|
238 |
+
dev_dump_dir.mkdir(parents=True, exist_ok=True)
|
239 |
+
test_dump_dir = sub_dataset_dump_dir / "test"
|
240 |
+
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
241 |
+
|
242 |
+
# 每条大 wav 分出一个 dev 一个 test,比例大概是 96:2:2
|
243 |
+
if all_wav_files:
|
244 |
+
process_sentences(
|
245 |
+
args=args,
|
246 |
+
fps=all_wav_files,
|
247 |
+
train_dump_dir=train_dump_dir,
|
248 |
+
dev_dump_dir=dev_dump_dir,
|
249 |
+
test_dump_dir=test_dump_dir,
|
250 |
+
VAD_dict=VAD_dict,
|
251 |
+
nprocs=args.num_cpu)
|
252 |
+
|
253 |
+
|
254 |
+
if __name__ == "__main__":
|
255 |
+
main()
|
AR/exps/split_train_val.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import pandas
|
3 |
+
|
4 |
+
semantic_path = 'dump/semantic.tsv'
|
5 |
+
phoneme_path = 'dump/phoneme.npy'
|
6 |
+
train_semantic_path = 'dump/semantic_train.tsv'
|
7 |
+
train_phoneme_path = 'dump/phoneme_train.npy'
|
8 |
+
dev_semantic_path = 'dump/semantic_dev.tsv'
|
9 |
+
dev_phoneme_path = 'dump/phoneme_dev.npy'
|
10 |
+
|
11 |
+
# 读取dump/semantic.tsv
|
12 |
+
semantic_df = pandas.read_csv(semantic_path, sep='\t')
|
13 |
+
# pd.DataFrame(columns=["item_name", "semantic_audio"])
|
14 |
+
# # 读取dump/phoneme.npy
|
15 |
+
phoneme_dict = numpy.load(phoneme_path, allow_pickle=True).item()
|
16 |
+
|
17 |
+
dev_num = 20
|
18 |
+
# 随机从semantic_df中选取dev_num个
|
19 |
+
dev_df = semantic_df.sample(n=dev_num)
|
20 |
+
# 剩下的是train
|
21 |
+
train_df = semantic_df.drop(dev_df.index)
|
22 |
+
# 保存
|
23 |
+
dev_df.to_csv(dev_semantic_path, sep='\t', index=False)
|
24 |
+
train_df.to_csv(train_semantic_path, sep='\t', index=False)
|
25 |
+
|
26 |
+
# 将dev_df中的item_name取出来 作为dev_phoneme_dict的key
|
27 |
+
dev_item_names = dev_df['item_name'].tolist()
|
28 |
+
dev_phoneme_dict = {k: phoneme_dict[k] for k in dev_item_names if k in phoneme_dict}
|
29 |
+
train_phoneme_dict = {k: phoneme_dict[k] for k in phoneme_dict.keys() if k not in dev_item_names}
|
30 |
+
|
31 |
+
numpy.save(dev_phoneme_path, dev_phoneme_dict)
|
32 |
+
numpy.save(train_phoneme_path, train_phoneme_dict)
|
33 |
+
|
34 |
+
|
35 |
+
|
AR/exps/t2s.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# text to semantic
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import whisper
|
12 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
13 |
+
from AR.text_processing.phonemizer import GruutPhonemizer
|
14 |
+
from AR.utils.io import load_yaml_config
|
15 |
+
|
16 |
+
|
17 |
+
def get_batch(text, phonemizer):
|
18 |
+
# phoneme_ids 和 phoneme_ids_len 是需要的
|
19 |
+
phoneme = phonemizer.phonemize(text, espeak=False)
|
20 |
+
phoneme_ids = phonemizer.transform(phoneme)
|
21 |
+
phoneme_ids_len = len(phoneme_ids)
|
22 |
+
phoneme_ids = np.array(phoneme_ids)
|
23 |
+
# add batch axis here
|
24 |
+
phoneme_ids = torch.tensor(phoneme_ids).unsqueeze(0)
|
25 |
+
phoneme_ids_len = torch.tensor([phoneme_ids_len])
|
26 |
+
print("phoneme:", phoneme)
|
27 |
+
batch = {
|
28 |
+
# torch.Tensor (B, max_phoneme_length)
|
29 |
+
"phoneme_ids": phoneme_ids,
|
30 |
+
# torch.Tensor (B)
|
31 |
+
"phoneme_ids_len": phoneme_ids_len
|
32 |
+
}
|
33 |
+
return batch
|
34 |
+
|
35 |
+
|
36 |
+
def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer):
|
37 |
+
sample_rate = 16000
|
38 |
+
# to get prompt
|
39 |
+
prompt_name = os.path.basename(prompt_wav_path).split('.')[0]
|
40 |
+
wav, _ = librosa.load(prompt_wav_path, sr=sample_rate)
|
41 |
+
# 取末尾 3s, 但是不包含最后 0.1s 防止 AR S1 infer 提前停止
|
42 |
+
wav = wav[-sample_rate * 3:-int(sample_rate * 0.1)]
|
43 |
+
# wav 需要挪出末尾的静音否则也可能提前停住
|
44 |
+
prompt_text = asr_model.transcribe(wav)["text"]
|
45 |
+
# 移除最后的句点, 防止 AR S1 infer 提前停止, 加了句点可能会有停顿
|
46 |
+
prompt_text = prompt_text.replace(".", "")
|
47 |
+
prompt_phoneme = phonemizer.phonemize(prompt_text, espeak=False)
|
48 |
+
prompt_phoneme_ids = phonemizer.transform(prompt_phoneme)
|
49 |
+
prompt_phoneme_ids_len = len(prompt_phoneme_ids)
|
50 |
+
# get prompt_semantic
|
51 |
+
# (T) -> (1, T)
|
52 |
+
wav = torch.tensor(wav).unsqueeze(0)
|
53 |
+
wav = wav.cuda()
|
54 |
+
# (1, T)
|
55 |
+
prompt_semantic_tokens = semantic_tokenizer.tokenize(wav).to(torch.int32)
|
56 |
+
prompt_phoneme_ids = torch.tensor(prompt_phoneme_ids).unsqueeze(0)
|
57 |
+
prompt_phoneme_ids_len = torch.tensor([prompt_phoneme_ids_len])
|
58 |
+
|
59 |
+
result = {
|
60 |
+
'prompt_name': prompt_name,
|
61 |
+
'prompt_phoneme_ids': prompt_phoneme_ids,
|
62 |
+
'prompt_semantic_tokens': prompt_semantic_tokens,
|
63 |
+
'prompt_phoneme_ids_len': prompt_phoneme_ids_len
|
64 |
+
}
|
65 |
+
|
66 |
+
return result
|
67 |
+
|
68 |
+
|
69 |
+
def parse_args():
|
70 |
+
# parse args and config
|
71 |
+
parser = argparse.ArgumentParser(
|
72 |
+
description="Run SoundStorm AR S1 model for input text file")
|
73 |
+
|
74 |
+
parser.add_argument(
|
75 |
+
'--config_file',
|
76 |
+
type=str,
|
77 |
+
default='conf/default.yaml',
|
78 |
+
help='path of config file')
|
79 |
+
|
80 |
+
parser.add_argument(
|
81 |
+
"--text_file",
|
82 |
+
type=str,
|
83 |
+
help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line."
|
84 |
+
)
|
85 |
+
|
86 |
+
parser.add_argument(
|
87 |
+
'--ckpt_path',
|
88 |
+
type=str,
|
89 |
+
default='exp/default/ckpt/epoch=99-step=49000.ckpt',
|
90 |
+
help='Checkpoint file of SoundStorm AR S1 model.')
|
91 |
+
|
92 |
+
parser.add_argument(
|
93 |
+
'--prompt_wav_path',
|
94 |
+
type=str,
|
95 |
+
default=None,
|
96 |
+
help='extract prompt semantic and prompt phonemes from prompt wav')
|
97 |
+
|
98 |
+
# to get semantic tokens from prompt_wav
|
99 |
+
parser.add_argument("--hubert_path", type=str, default=None)
|
100 |
+
parser.add_argument("--quantizer_path", type=str, default=None)
|
101 |
+
|
102 |
+
parser.add_argument("--output_dir", type=str, help="output dir.")
|
103 |
+
|
104 |
+
args = parser.parse_args()
|
105 |
+
return args
|
106 |
+
|
107 |
+
|
108 |
+
def main():
|
109 |
+
args = parse_args()
|
110 |
+
config = load_yaml_config(args.config_file)
|
111 |
+
|
112 |
+
output_dir = Path(args.output_dir)
|
113 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
114 |
+
|
115 |
+
hz = 50
|
116 |
+
max_sec = config['data']['max_sec']
|
117 |
+
|
118 |
+
# get models
|
119 |
+
t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
|
120 |
+
checkpoint_path=args.ckpt_path, config=config)
|
121 |
+
t2s_model.cuda()
|
122 |
+
t2s_model.eval()
|
123 |
+
|
124 |
+
phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us')
|
125 |
+
|
126 |
+
# models for prompt
|
127 |
+
asr_model = whisper.load_model("tiny.en")
|
128 |
+
|
129 |
+
semantic_tokenizer = SemanticTokenizer(
|
130 |
+
hubert_path=args.hubert_path,
|
131 |
+
quantizer_path=args.quantizer_path,
|
132 |
+
duplicate=True)
|
133 |
+
|
134 |
+
prompt_result = get_prompt(
|
135 |
+
prompt_wav_path=args.prompt_wav_path,
|
136 |
+
asr_model=asr_model,
|
137 |
+
phonemizer=phonemizer,
|
138 |
+
semantic_tokenizer=semantic_tokenizer)
|
139 |
+
|
140 |
+
# zero prompt => 输出的 semantic 包含的内容是对的但是音色是乱的
|
141 |
+
# (B, 1)
|
142 |
+
# prompt = torch.ones(
|
143 |
+
# batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0
|
144 |
+
|
145 |
+
prompt = prompt_result['prompt_semantic_tokens']
|
146 |
+
prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len']
|
147 |
+
prompt_phoneme_ids = prompt_result['prompt_phoneme_ids']
|
148 |
+
|
149 |
+
sentences = []
|
150 |
+
with open(args.text_file, 'rt', encoding='utf-8') as f:
|
151 |
+
for line in f:
|
152 |
+
if line.strip() != "":
|
153 |
+
items = re.split(r"\s+", line.strip(), 1)
|
154 |
+
utt_id = items[0]
|
155 |
+
sentence = " ".join(items[1:])
|
156 |
+
sentences.append((utt_id, sentence))
|
157 |
+
semantic_data = [['item_name', 'semantic_audio']]
|
158 |
+
for utt_id, sentence in sentences[1:]:
|
159 |
+
# 需要自己构造伪 batch 输入给模型
|
160 |
+
batch = get_batch(sentence, phonemizer)
|
161 |
+
# prompt 和真正的输入拼接
|
162 |
+
all_phoneme_ids = torch.cat(
|
163 |
+
[prompt_phoneme_ids, batch['phoneme_ids']], dim=1)
|
164 |
+
# 或者可以直接求 all_phoneme_ids 的 shape[-1]
|
165 |
+
all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len']
|
166 |
+
st = time.time()
|
167 |
+
with torch.no_grad():
|
168 |
+
pred_semantic = t2s_model.model.infer(
|
169 |
+
all_phoneme_ids.cuda(),
|
170 |
+
all_phoneme_len.cuda(),
|
171 |
+
prompt.cuda(),
|
172 |
+
top_k=config['inference']['top_k'],
|
173 |
+
early_stop_num=hz * max_sec)
|
174 |
+
print(f'{time.time() - st} sec used in T2S')
|
175 |
+
|
176 |
+
# 删除 prompt 对应的部分
|
177 |
+
prompt_len = prompt.shape[-1]
|
178 |
+
pred_semantic = pred_semantic[:, prompt_len:]
|
179 |
+
|
180 |
+
# bs = 1
|
181 |
+
pred_semantic = pred_semantic[0]
|
182 |
+
semantic_token = pred_semantic.detach().cpu().numpy().tolist()
|
183 |
+
semantic_token_str = ' '.join(str(x) for x in semantic_token)
|
184 |
+
semantic_data.append([utt_id, semantic_token_str])
|
185 |
+
|
186 |
+
delimiter = '\t'
|
187 |
+
filename = output_dir / f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv'
|
188 |
+
with open(filename, 'w', encoding='utf-8') as writer:
|
189 |
+
for row in semantic_data:
|
190 |
+
line = delimiter.join(row)
|
191 |
+
writer.write(line + '\n')
|
192 |
+
# clean semantic token for next setence
|
193 |
+
semantic_data = [['item_name', 'semantic_audio']]
|
194 |
+
|
195 |
+
|
196 |
+
if __name__ == "__main__":
|
197 |
+
main()
|
AR/exps/test.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# test from dump file
|
2 |
+
import argparse
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from AR.data.dataset import Text2SemanticDataset
|
9 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
10 |
+
from AR.utils.io import load_yaml_config
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
# parse args and config
|
16 |
+
parser = argparse.ArgumentParser(
|
17 |
+
description="Run SoundStorm AR S1 model for test set.")
|
18 |
+
|
19 |
+
parser.add_argument(
|
20 |
+
'--config_file',
|
21 |
+
type=str,
|
22 |
+
default='conf/default.yaml',
|
23 |
+
help='path of config file')
|
24 |
+
|
25 |
+
# args for dataset
|
26 |
+
parser.add_argument(
|
27 |
+
'--test_semantic_path',
|
28 |
+
type=str,
|
29 |
+
default='dump/test/semantic_token.tsv')
|
30 |
+
parser.add_argument(
|
31 |
+
'--test_phoneme_path', type=str, default='dump/test/phonemes.npy')
|
32 |
+
|
33 |
+
parser.add_argument(
|
34 |
+
'--ckpt_path',
|
35 |
+
type=str,
|
36 |
+
default='exp/default/ckpt/epoch=99-step=49000.ckpt',
|
37 |
+
help='Checkpoint file of SoundStorm AR S1 model.')
|
38 |
+
|
39 |
+
parser.add_argument("--output_dir", type=str, help="output dir.")
|
40 |
+
|
41 |
+
args = parser.parse_args()
|
42 |
+
return args
|
43 |
+
|
44 |
+
|
45 |
+
def main():
|
46 |
+
args = parse_args()
|
47 |
+
|
48 |
+
config = load_yaml_config(args.config_file)
|
49 |
+
|
50 |
+
output_dir = Path(args.output_dir)
|
51 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
52 |
+
|
53 |
+
batch_size = 1
|
54 |
+
hz = 50
|
55 |
+
max_sec = config['data']['max_sec']
|
56 |
+
|
57 |
+
# get dataset
|
58 |
+
test_dataset = Text2SemanticDataset(
|
59 |
+
phoneme_path=args.test_phoneme_path,
|
60 |
+
semantic_path=args.test_semantic_path,
|
61 |
+
# max_sec 需要与训练时保持一致,不然可能会效果不好,重复漏字等
|
62 |
+
# 但是这里设置太短又会直接过滤掉太长的样本,为了防止被过滤掉,可以在 infer 的时候截断
|
63 |
+
max_sec=100,
|
64 |
+
max_sample=8,
|
65 |
+
pad_val=config['data']['pad_val'])
|
66 |
+
# get model
|
67 |
+
t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
|
68 |
+
checkpoint_path=args.ckpt_path, config=config)
|
69 |
+
t2s_model.cuda()
|
70 |
+
t2s_model.eval()
|
71 |
+
|
72 |
+
# 获取 batch_size 条
|
73 |
+
# 创建 DataLoader,并指定 collate_fn 函数
|
74 |
+
dataloader = DataLoader(
|
75 |
+
test_dataset,
|
76 |
+
batch_size=batch_size,
|
77 |
+
shuffle=False,
|
78 |
+
collate_fn=test_dataset.collate)
|
79 |
+
|
80 |
+
item_names = test_dataset.__get_item_names__()
|
81 |
+
|
82 |
+
# 逐批次读取数据, bs=1、shuffle=False 时可以用 __get_item_names__ 对应
|
83 |
+
semantic_data = [['item_name', 'semantic_audio']]
|
84 |
+
for i, batch in enumerate(dataloader):
|
85 |
+
# 要保证 bs = 1
|
86 |
+
utt_id = item_names[i]
|
87 |
+
if i == 0:
|
88 |
+
print("utt_id:", utt_id)
|
89 |
+
# bs > 1 时会补零
|
90 |
+
# 与 validation_step() 保持一致
|
91 |
+
semantic_len = batch['semantic_ids'].size(1)
|
92 |
+
# 以 batch['semantic_ids'] 的前 150 个为 prompt
|
93 |
+
# 多次合成,前 prompt_len 个是一样的,而且和 prompt 一样
|
94 |
+
prompt_len = min(int(semantic_len * 0.5), 150)
|
95 |
+
# 输入纯文本时 prompt 该输入什么?=> see t2s.py
|
96 |
+
prompt = batch['semantic_ids'][:, :prompt_len]
|
97 |
+
# # zero prompt => 也可以输出文本内容正确的 semantic token, 但是音色是乱的
|
98 |
+
# 证明 semantic token 中还是包含了音色信息
|
99 |
+
# prompt = torch.ones(
|
100 |
+
# batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
|
101 |
+
# print("prompt:", prompt)
|
102 |
+
# print("prompt.shape:", prompt.shape)
|
103 |
+
np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())
|
104 |
+
|
105 |
+
st = time.time()
|
106 |
+
with torch.no_grad():
|
107 |
+
# calculate acc for test
|
108 |
+
loss, acc = t2s_model.model.forward(
|
109 |
+
batch['phoneme_ids'].cuda(),
|
110 |
+
batch['phoneme_ids_len'].cuda(),
|
111 |
+
batch['semantic_ids'].cuda(),
|
112 |
+
batch['semantic_ids_len'].cuda())
|
113 |
+
print("top_3_acc of this batch:", acc)
|
114 |
+
pred_semantic = t2s_model.model.infer(
|
115 |
+
batch['phoneme_ids'].cuda(),
|
116 |
+
batch['phoneme_ids_len'].cuda(),
|
117 |
+
prompt.cuda(),
|
118 |
+
top_k=config['inference']['top_k'],
|
119 |
+
# hz * max_sec in train dataloader
|
120 |
+
# 生成的长度是 1002 应该是有一些 pad
|
121 |
+
early_stop_num=hz * max_sec)
|
122 |
+
# bs = 1
|
123 |
+
pred_semantic = pred_semantic[0]
|
124 |
+
print(f'{time.time() - st} sec used in T2S')
|
125 |
+
semantic_token = pred_semantic.detach().cpu().numpy().tolist()
|
126 |
+
semantic_token_str = ' '.join(str(x) for x in semantic_token)
|
127 |
+
semantic_data.append([utt_id, semantic_token_str])
|
128 |
+
else:
|
129 |
+
break
|
130 |
+
delimiter = '\t'
|
131 |
+
filename = output_dir / "semantic_token.tsv"
|
132 |
+
with open(filename, 'w', encoding='utf-8') as writer:
|
133 |
+
for row in semantic_data:
|
134 |
+
line = delimiter.join(row)
|
135 |
+
writer.write(line + '\n')
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
|
139 |
+
main()
|
AR/exps/text.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
001 Life was like a box of chocolates, you never know what you're gonna get.
|
2 |
+
002 With great power there must come great responsibility.
|
3 |
+
003 To be or not to be, that’s a question.
|
4 |
+
004 A man can be destroyed but not defeated
|
5 |
+
005 Do not, for one repulse, give up the purpose that you resolved to effort.
|
6 |
+
006 Death is just a part of life, something we're all destined to do.
|
7 |
+
007 I think it's hard winning a war with words.
|
8 |
+
008 Don’t argue with the people of strong determination, because they may change the fact!
|
9 |
+
009 Love you three thousand times.
|
10 |
+
010 tidy tiger tied a tie tighter to tidy her tiny tall.
|
AR/exps/train.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
2 |
+
import argparse
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from pytorch_lightning import Trainer
|
10 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
11 |
+
from pytorch_lightning.loggers import WandbLogger
|
12 |
+
from pytorch_lightning.strategies import DDPStrategy
|
13 |
+
from AR.data.data_module import Text2SemanticDataModule
|
14 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
15 |
+
from soundstorm.utils.io import load_yaml_config
|
16 |
+
logging.getLogger('numba').setLevel(logging.WARNING)
|
17 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
18 |
+
torch.set_float32_matmul_precision('high')
|
19 |
+
from soundstorm.utils import get_newest_ckpt
|
20 |
+
|
21 |
+
|
22 |
+
def main(args):
|
23 |
+
output_dir = Path(args.output_dir)
|
24 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
25 |
+
|
26 |
+
ckpt_dir = output_dir / 'ckpt'
|
27 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
28 |
+
|
29 |
+
config = load_yaml_config(args.config_file)
|
30 |
+
|
31 |
+
seed_everything(config["train"]["seed"], workers=True)
|
32 |
+
ckpt_callback: ModelCheckpoint = ModelCheckpoint(
|
33 |
+
save_top_k=-1,
|
34 |
+
save_on_train_epoch_end=False,
|
35 |
+
every_n_epochs=config["train"]["save_every_n_epoch"],
|
36 |
+
dirpath=ckpt_dir)
|
37 |
+
logger = WandbLogger(
|
38 |
+
project="AR_S1",
|
39 |
+
name=output_dir.stem,
|
40 |
+
save_dir=output_dir,
|
41 |
+
# resume the loss curve
|
42 |
+
resume=True,
|
43 |
+
# id='k19kvsq8'
|
44 |
+
)
|
45 |
+
trainer: Trainer = Trainer(
|
46 |
+
max_epochs=config["train"]["epochs"],
|
47 |
+
accelerator='gpu',
|
48 |
+
devices=-1,
|
49 |
+
benchmark=False,
|
50 |
+
fast_dev_run=False,
|
51 |
+
strategy=DDPStrategy(find_unused_parameters=True),
|
52 |
+
precision=config["train"]["precision"],
|
53 |
+
logger=logger,
|
54 |
+
callbacks=[ckpt_callback])
|
55 |
+
|
56 |
+
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
|
57 |
+
config, output_dir)
|
58 |
+
|
59 |
+
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
|
60 |
+
config,
|
61 |
+
train_semantic_path=args.train_semantic_path,
|
62 |
+
train_phoneme_path=args.train_phoneme_path,
|
63 |
+
dev_semantic_path=args.dev_semantic_path,
|
64 |
+
dev_phoneme_path=args.dev_phoneme_path)
|
65 |
+
|
66 |
+
try:
|
67 |
+
# 使用正则表达式匹配文件名中的数字部分,并按数字大小进行排序
|
68 |
+
newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
|
69 |
+
ckpt_path = ckpt_dir / newest_ckpt_name
|
70 |
+
except Exception:
|
71 |
+
ckpt_path = None
|
72 |
+
print("ckpt_path:", ckpt_path)
|
73 |
+
trainer.fit(model, data_module, ckpt_path=ckpt_path)
|
74 |
+
|
75 |
+
|
76 |
+
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
|
77 |
+
if __name__ == '__main__':
|
78 |
+
parser = argparse.ArgumentParser()
|
79 |
+
parser.add_argument(
|
80 |
+
'--config_file',
|
81 |
+
type=str,
|
82 |
+
default='conf/default.yaml',
|
83 |
+
help='path of config file')
|
84 |
+
# args for dataset
|
85 |
+
parser.add_argument(
|
86 |
+
'--train_semantic_path',
|
87 |
+
type=str,
|
88 |
+
default='dump/train/semantic_token.tsv')
|
89 |
+
parser.add_argument(
|
90 |
+
'--train_phoneme_path', type=str, default='dump/train/phonemes.npy')
|
91 |
+
parser.add_argument(
|
92 |
+
'--dev_semantic_path', type=str, default='dump/dev/semantic_token.tsv')
|
93 |
+
parser.add_argument(
|
94 |
+
'--dev_phoneme_path', type=str, default='dump/dev/phonemes.npy')
|
95 |
+
parser.add_argument(
|
96 |
+
'--output_dir',
|
97 |
+
type=str,
|
98 |
+
default='exp/default',
|
99 |
+
help='directory to save the results')
|
100 |
+
|
101 |
+
args = parser.parse_args()
|
102 |
+
logging.info(str(args))
|
103 |
+
main(args)
|
AR/exps/train_librilight_6k.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
2 |
+
import argparse
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from pytorch_lightning import Trainer
|
10 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
11 |
+
from pytorch_lightning.loggers import WandbLogger
|
12 |
+
from pytorch_lightning.strategies import DDPStrategy
|
13 |
+
from AR.data.data_module_librilight_6k import Text2SemanticDataModule
|
14 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
15 |
+
from soundstorm.utils import get_newest_ckpt
|
16 |
+
from soundstorm.utils.io import load_yaml_config
|
17 |
+
|
18 |
+
logging.getLogger('numba').setLevel(logging.WARNING)
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
torch.set_float32_matmul_precision('high')
|
21 |
+
|
22 |
+
|
23 |
+
def main(args):
|
24 |
+
output_dir = Path(args.output_dir)
|
25 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
26 |
+
|
27 |
+
ckpt_dir = output_dir / 'ckpt'
|
28 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
29 |
+
|
30 |
+
config = load_yaml_config(args.config_file)
|
31 |
+
|
32 |
+
seed_everything(config["train"]["seed"], workers=True)
|
33 |
+
|
34 |
+
ckpt_callback: ModelCheckpoint = ModelCheckpoint(
|
35 |
+
save_top_k=-1,
|
36 |
+
save_on_train_epoch_end=False,
|
37 |
+
every_n_train_steps=config["train"]["every_n_train_steps"],
|
38 |
+
dirpath=ckpt_dir)
|
39 |
+
logger = WandbLogger(
|
40 |
+
project="AR_S1_LibriLight",
|
41 |
+
name=output_dir.stem,
|
42 |
+
save_dir=output_dir,
|
43 |
+
# resume the loss curve
|
44 |
+
resume=True,
|
45 |
+
# id='k19kvsq8'
|
46 |
+
)
|
47 |
+
trainer: Trainer = Trainer(
|
48 |
+
max_epochs=config["train"]["epochs"],
|
49 |
+
accelerator='gpu',
|
50 |
+
devices=-1,
|
51 |
+
benchmark=False,
|
52 |
+
fast_dev_run=False,
|
53 |
+
strategy=DDPStrategy(find_unused_parameters=True),
|
54 |
+
precision=config["train"]["precision"],
|
55 |
+
logger=logger,
|
56 |
+
callbacks=[ckpt_callback])
|
57 |
+
|
58 |
+
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
|
59 |
+
config, output_dir)
|
60 |
+
|
61 |
+
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
|
62 |
+
config,
|
63 |
+
train_semantic_dirs=args.train_semantic_dirs,
|
64 |
+
train_phoneme_dirs=args.train_phoneme_dirs,
|
65 |
+
dev_semantic_dirs=args.dev_semantic_dirs,
|
66 |
+
dev_phoneme_dirs=args.dev_phoneme_dirs,
|
67 |
+
train_non_speech_dirs=args.train_non_speech_dirs,
|
68 |
+
dev_non_speech_dirs=args.dev_non_speech_dirs)
|
69 |
+
try:
|
70 |
+
newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
|
71 |
+
ckpt_path = ckpt_dir / newest_ckpt_name
|
72 |
+
except Exception:
|
73 |
+
ckpt_path = None
|
74 |
+
|
75 |
+
print("ckpt_path:", ckpt_path)
|
76 |
+
trainer.fit(model, data_module, ckpt_path=ckpt_path)
|
77 |
+
|
78 |
+
|
79 |
+
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
|
80 |
+
if __name__ == '__main__':
|
81 |
+
parser = argparse.ArgumentParser()
|
82 |
+
parser.add_argument(
|
83 |
+
'--config_file',
|
84 |
+
type=str,
|
85 |
+
default='conf/default.yaml',
|
86 |
+
help='path of config file')
|
87 |
+
# args for dataset
|
88 |
+
parser.add_argument(
|
89 |
+
'--train_semantic_dirs',
|
90 |
+
type=list,
|
91 |
+
nargs='+',
|
92 |
+
default=["dump/small/train/"],
|
93 |
+
help='dirs of train semantic')
|
94 |
+
parser.add_argument(
|
95 |
+
'--train_phoneme_dirs',
|
96 |
+
type=list,
|
97 |
+
nargs='+',
|
98 |
+
default=["dump/small/train/"],
|
99 |
+
help='dirs of train phoneme')
|
100 |
+
parser.add_argument(
|
101 |
+
'--dev_semantic_dirs',
|
102 |
+
type=list,
|
103 |
+
nargs='+',
|
104 |
+
default=["dump/small/dev/"],
|
105 |
+
help='dirs of dev semantic')
|
106 |
+
parser.add_argument(
|
107 |
+
'--dev_phoneme_dirs',
|
108 |
+
type=list,
|
109 |
+
nargs='+',
|
110 |
+
default=["dump/small/dev/"],
|
111 |
+
help='dirs of dev phoneme')
|
112 |
+
parser.add_argument(
|
113 |
+
'--output_dir',
|
114 |
+
type=str,
|
115 |
+
default='exp/default',
|
116 |
+
help='directory to save the results')
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
'--train_non_speech_dirs',
|
120 |
+
type=list,
|
121 |
+
nargs='+',
|
122 |
+
default=None,
|
123 |
+
help='dirs of train non_speech data')
|
124 |
+
|
125 |
+
parser.add_argument(
|
126 |
+
'--dev_non_speech_dirs',
|
127 |
+
type=list,
|
128 |
+
nargs='+',
|
129 |
+
default=None,
|
130 |
+
help='dirs of dev non_speech data')
|
131 |
+
|
132 |
+
args = parser.parse_args()
|
133 |
+
|
134 |
+
new_train_semantic_dirs = []
|
135 |
+
new_train_phoneme_dirs = []
|
136 |
+
new_dev_semantic_dirs = []
|
137 |
+
new_dev_phoneme_dirs = []
|
138 |
+
|
139 |
+
new_train_non_speech_dirs = []
|
140 |
+
new_dev_non_speech_dirs = []
|
141 |
+
|
142 |
+
# format dataset dirs
|
143 |
+
for item in args.train_semantic_dirs:
|
144 |
+
new_train_semantic_dirs.append(''.join(item))
|
145 |
+
args.train_semantic_dirs = new_train_semantic_dirs
|
146 |
+
|
147 |
+
for item in args.train_phoneme_dirs:
|
148 |
+
new_train_phoneme_dirs.append(''.join(item))
|
149 |
+
args.train_phoneme_dirs = new_train_phoneme_dirs
|
150 |
+
|
151 |
+
for item in args.dev_semantic_dirs:
|
152 |
+
new_dev_semantic_dirs.append(''.join(item))
|
153 |
+
args.dev_semantic_dirs = new_dev_semantic_dirs
|
154 |
+
|
155 |
+
for item in args.dev_phoneme_dirs:
|
156 |
+
new_dev_phoneme_dirs.append(''.join(item))
|
157 |
+
args.dev_phoneme_dirs = new_dev_phoneme_dirs
|
158 |
+
|
159 |
+
if args.train_non_speech_dirs is not None:
|
160 |
+
for item in args.train_non_speech_dirs:
|
161 |
+
new_train_non_speech_dirs.append(''.join(item))
|
162 |
+
args.train_non_speech_dirs = new_train_non_speech_dirs
|
163 |
+
|
164 |
+
if args.dev_non_speech_dirs is not None:
|
165 |
+
for item in args.dev_non_speech_dirs:
|
166 |
+
new_dev_non_speech_dirs.append(''.join(item))
|
167 |
+
args.dev_non_speech_dirs = new_dev_non_speech_dirs
|
168 |
+
|
169 |
+
logging.info(str(args))
|
170 |
+
main(args)
|
AR/models/__init__.py
ADDED
File without changes
|
AR/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (142 Bytes). View file
|
|
AR/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (145 Bytes). View file
|
|
AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc
ADDED
Binary file (3.14 kB). View file
|
|
AR/models/__pycache__/t2s_lightning_module.cpython-39.pyc
ADDED
Binary file (3.15 kB). View file
|
|
AR/models/__pycache__/t2s_model.cpython-310.pyc
ADDED
Binary file (6.84 kB). View file
|
|
AR/models/__pycache__/t2s_model.cpython-39.pyc
ADDED
Binary file (6.84 kB). View file
|
|
AR/models/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.5 kB). View file
|
|
AR/models/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (4.48 kB). View file
|
|
AR/models/t2s_lightning_module.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
|
2 |
+
import os,sys
|
3 |
+
now_dir = os.getcwd()
|
4 |
+
sys.path.append(now_dir)
|
5 |
+
from typing import Dict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pytorch_lightning import LightningModule
|
9 |
+
from AR.models.t2s_model import Text2SemanticDecoder
|
10 |
+
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
11 |
+
from AR.modules.optim import ScaledAdam
|
12 |
+
|
13 |
+
|
14 |
+
class Text2SemanticLightningModule(LightningModule):
|
15 |
+
def __init__(self, config, output_dir,is_train=True):
|
16 |
+
super().__init__()
|
17 |
+
self.config = config
|
18 |
+
self.top_k = 3
|
19 |
+
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
|
20 |
+
pretrained_s1=config.get("pretrained_s1")
|
21 |
+
if(pretrained_s1 and is_train):
|
22 |
+
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
23 |
+
print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
|
24 |
+
if is_train:
|
25 |
+
self.automatic_optimization = False
|
26 |
+
self.save_hyperparameters()
|
27 |
+
self.eval_dir = output_dir / 'eval'
|
28 |
+
self.eval_dir.mkdir(parents=True, exist_ok=True)
|
29 |
+
|
30 |
+
def training_step(self, batch: Dict, batch_idx: int):
|
31 |
+
|
32 |
+
opt = self.optimizers()
|
33 |
+
scheduler = self.lr_schedulers()
|
34 |
+
loss, acc = self.model.forward(
|
35 |
+
batch['phoneme_ids'], batch['phoneme_ids_len'],
|
36 |
+
batch['semantic_ids'], batch['semantic_ids_len'],
|
37 |
+
batch['bert_feature'])
|
38 |
+
self.manual_backward(loss)
|
39 |
+
if batch_idx > 0 and batch_idx % 4 == 0:
|
40 |
+
opt.step()
|
41 |
+
opt.zero_grad()
|
42 |
+
scheduler.step()
|
43 |
+
|
44 |
+
self.log(
|
45 |
+
"total_loss",
|
46 |
+
loss,
|
47 |
+
on_step=True,
|
48 |
+
on_epoch=True,
|
49 |
+
prog_bar=True,
|
50 |
+
sync_dist=True)
|
51 |
+
self.log(
|
52 |
+
"lr",
|
53 |
+
scheduler.get_last_lr()[0],
|
54 |
+
on_epoch=True,
|
55 |
+
prog_bar=True,
|
56 |
+
sync_dist=True)
|
57 |
+
self.log(
|
58 |
+
f"top_{self.top_k}_acc",
|
59 |
+
acc,
|
60 |
+
on_step=True,
|
61 |
+
on_epoch=True,
|
62 |
+
prog_bar=True,
|
63 |
+
sync_dist=True)
|
64 |
+
|
65 |
+
def validation_step(self, batch: Dict, batch_idx: int):return
|
66 |
+
# # get loss
|
67 |
+
# loss, acc = self.model.forward(
|
68 |
+
# batch['phoneme_ids'], batch['phoneme_ids_len'],
|
69 |
+
# batch['semantic_ids'], batch['semantic_ids_len'],
|
70 |
+
# batch['bert_feature']
|
71 |
+
# )
|
72 |
+
#
|
73 |
+
# self.log(
|
74 |
+
# "val_total_loss",
|
75 |
+
# loss,
|
76 |
+
# on_step=True,
|
77 |
+
# on_epoch=True,
|
78 |
+
# prog_bar=True,
|
79 |
+
# sync_dist=True)
|
80 |
+
# self.log(
|
81 |
+
# f"val_top_{self.top_k}_acc",
|
82 |
+
# acc,
|
83 |
+
# on_step=True,
|
84 |
+
# on_epoch=True,
|
85 |
+
# prog_bar=True,
|
86 |
+
# sync_dist=True)
|
87 |
+
#
|
88 |
+
# # get infer output
|
89 |
+
# semantic_len = batch['semantic_ids'].size(1)
|
90 |
+
# prompt_len = min(int(semantic_len * 0.5), 150)
|
91 |
+
# prompt = batch['semantic_ids'][:, :prompt_len]
|
92 |
+
# pred_semantic = self.model.infer(batch['phoneme_ids'],
|
93 |
+
# batch['phoneme_ids_len'], prompt,
|
94 |
+
# batch['bert_feature']
|
95 |
+
# )
|
96 |
+
# save_name = f'semantic_toks_{batch_idx}.pt'
|
97 |
+
# save_path = os.path.join(self.eval_dir, save_name)
|
98 |
+
# torch.save(pred_semantic.detach().cpu(), save_path)
|
99 |
+
|
100 |
+
def configure_optimizers(self):
|
101 |
+
model_parameters = self.model.parameters()
|
102 |
+
parameters_names = []
|
103 |
+
parameters_names.append([
|
104 |
+
name_param_pair[0]
|
105 |
+
for name_param_pair in self.model.named_parameters()
|
106 |
+
])
|
107 |
+
lm_opt = ScaledAdam(
|
108 |
+
model_parameters,
|
109 |
+
lr=0.01,
|
110 |
+
betas=(0.9, 0.95),
|
111 |
+
clipping_scale=2.0,
|
112 |
+
parameters_names=parameters_names,
|
113 |
+
show_dominant_parameters=False,
|
114 |
+
clipping_update_period=1000, )
|
115 |
+
|
116 |
+
return {
|
117 |
+
"optimizer": lm_opt,
|
118 |
+
"lr_scheduler": {
|
119 |
+
"scheduler":
|
120 |
+
WarmupCosineLRSchedule(
|
121 |
+
lm_opt,
|
122 |
+
init_lr=self.config['optimizer']['lr_init'],
|
123 |
+
peak_lr=self.config['optimizer']['lr'],
|
124 |
+
end_lr=self.config['optimizer']['lr_end'],
|
125 |
+
warmup_steps=self.config['optimizer']['warmup_steps'],
|
126 |
+
total_steps=self.config['optimizer']['decay_steps'])
|
127 |
+
}
|
128 |
+
}
|
AR/models/t2s_model.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from AR.models.utils import make_pad_mask
|
6 |
+
from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
|
7 |
+
from AR.modules.embedding import SinePositionalEmbedding
|
8 |
+
from AR.modules.embedding import TokenEmbedding
|
9 |
+
from AR.modules.transformer import LayerNorm
|
10 |
+
from AR.modules.transformer import TransformerEncoder
|
11 |
+
from AR.modules.transformer import TransformerEncoderLayer
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
from torchmetrics.classification import MulticlassAccuracy
|
15 |
+
|
16 |
+
default_config = {
|
17 |
+
"embedding_dim": 512,
|
18 |
+
"hidden_dim": 512,
|
19 |
+
"num_head": 8,
|
20 |
+
"num_layers": 12,
|
21 |
+
"num_codebook": 8,
|
22 |
+
"p_dropout": 0.0,
|
23 |
+
"vocab_size": 1024 + 1,
|
24 |
+
"phoneme_vocab_size": 512,
|
25 |
+
"EOS": 1024
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
class Text2SemanticDecoder(nn.Module):
|
30 |
+
def __init__(self, config, norm_first=False, top_k=3):
|
31 |
+
super(Text2SemanticDecoder, self).__init__()
|
32 |
+
self.model_dim = config['model']["hidden_dim"]
|
33 |
+
self.embedding_dim = config['model']["embedding_dim"]
|
34 |
+
self.num_head = config['model']["head"]
|
35 |
+
self.num_layers = config['model']["n_layer"]
|
36 |
+
self.norm_first = norm_first
|
37 |
+
self.vocab_size = config['model']["vocab_size"]
|
38 |
+
self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
|
39 |
+
self.p_dropout = config['model']["dropout"]
|
40 |
+
self.EOS = config['model']["EOS"]
|
41 |
+
self.norm_first = norm_first
|
42 |
+
assert self.EOS == self.vocab_size - 1
|
43 |
+
# should be same as num of kmeans bin
|
44 |
+
# assert self.EOS == 1024
|
45 |
+
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
46 |
+
self.ar_text_embedding = TokenEmbedding(
|
47 |
+
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
|
48 |
+
self.ar_text_position = SinePositionalEmbedding(
|
49 |
+
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
|
50 |
+
self.ar_audio_embedding = TokenEmbedding(
|
51 |
+
self.embedding_dim, self.vocab_size, self.p_dropout)
|
52 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
53 |
+
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
|
54 |
+
|
55 |
+
self.h = TransformerEncoder(
|
56 |
+
TransformerEncoderLayer(
|
57 |
+
d_model=self.model_dim,
|
58 |
+
nhead=self.num_head,
|
59 |
+
dim_feedforward=self.model_dim * 4,
|
60 |
+
dropout=0.1,
|
61 |
+
batch_first=True,
|
62 |
+
norm_first=norm_first, ),
|
63 |
+
num_layers=self.num_layers,
|
64 |
+
norm=LayerNorm(self.model_dim) if norm_first else None, )
|
65 |
+
|
66 |
+
self.ar_predict_layer = nn.Linear(
|
67 |
+
self.model_dim, self.vocab_size, bias=False)
|
68 |
+
self.loss_fct = nn.CrossEntropyLoss(reduction='sum')
|
69 |
+
|
70 |
+
self.ar_accuracy_metric = MulticlassAccuracy(
|
71 |
+
self.vocab_size,
|
72 |
+
top_k=top_k,
|
73 |
+
average="micro",
|
74 |
+
multidim_average="global",
|
75 |
+
ignore_index=self.EOS, )
|
76 |
+
|
77 |
+
def forward(self, x, x_lens, y, y_lens, bert_feature):
|
78 |
+
'''
|
79 |
+
x: phoneme_ids
|
80 |
+
y: semantic_ids
|
81 |
+
'''
|
82 |
+
x = self.ar_text_embedding(x)
|
83 |
+
x = x + self.bert_proj(bert_feature.transpose(1,2))
|
84 |
+
x = self.ar_text_position(x)
|
85 |
+
x_mask = make_pad_mask(x_lens)
|
86 |
+
|
87 |
+
y_mask = make_pad_mask(y_lens)
|
88 |
+
y_mask_int = y_mask.type(torch.int64)
|
89 |
+
codes = y.type(torch.int64) * (1 - y_mask_int)
|
90 |
+
|
91 |
+
# Training
|
92 |
+
# AR Decoder
|
93 |
+
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
|
94 |
+
x_len = x_lens.max()
|
95 |
+
y_len = y_lens.max()
|
96 |
+
y_emb = self.ar_audio_embedding(y)
|
97 |
+
y_pos = self.ar_audio_position(y_emb)
|
98 |
+
|
99 |
+
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
|
100 |
+
ar_xy_padding_mask = xy_padding_mask
|
101 |
+
|
102 |
+
x_attn_mask = F.pad(
|
103 |
+
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
|
104 |
+
(0, y_len),
|
105 |
+
value=True, )
|
106 |
+
y_attn_mask = F.pad(
|
107 |
+
torch.triu(
|
108 |
+
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
109 |
+
diagonal=1, ),
|
110 |
+
(x_len, 0),
|
111 |
+
value=False, )
|
112 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
113 |
+
bsz, src_len = x.shape[0], x_len + y_len
|
114 |
+
_xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
|
115 |
+
.expand(-1, self.num_head, -1, -1)
|
116 |
+
.reshape(bsz * self.num_head, 1, src_len))
|
117 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
118 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
119 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
120 |
+
xy_attn_mask = new_attn_mask
|
121 |
+
# x 和完整的 y 一次性输入模型
|
122 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
123 |
+
xy_dec, _ = self.h(
|
124 |
+
(xy_pos, None),
|
125 |
+
mask=xy_attn_mask, )
|
126 |
+
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
|
127 |
+
# loss
|
128 |
+
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
|
129 |
+
loss = F.cross_entropy(logits, targets, reduction='sum')
|
130 |
+
acc = self.ar_accuracy_metric(logits.detach(), targets).item()
|
131 |
+
return loss, acc
|
132 |
+
|
133 |
+
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
134 |
+
def infer(self,
|
135 |
+
x,
|
136 |
+
x_lens,
|
137 |
+
prompts,
|
138 |
+
bert_feature,
|
139 |
+
top_k: int=-100,
|
140 |
+
early_stop_num: int=-1,
|
141 |
+
temperature: float=1.0):
|
142 |
+
|
143 |
+
x = self.ar_text_embedding(x)
|
144 |
+
x = x + self.bert_proj(bert_feature.transpose(1,2))
|
145 |
+
x = self.ar_text_position(x)
|
146 |
+
|
147 |
+
# AR Decoder
|
148 |
+
y = prompts
|
149 |
+
prefix_len = y.shape[1]
|
150 |
+
x_len = x.shape[1]
|
151 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
152 |
+
stop = False
|
153 |
+
for _ in tqdm(range(1500)):
|
154 |
+
y_emb = self.ar_audio_embedding(y)
|
155 |
+
y_pos = self.ar_audio_position(y_emb)
|
156 |
+
# x 和逐渐增长的 y 一起输入给模型
|
157 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
158 |
+
y_len = y.shape[1]
|
159 |
+
x_attn_mask_pad = F.pad(
|
160 |
+
x_attn_mask,
|
161 |
+
(0, y_len),
|
162 |
+
value=True, )
|
163 |
+
y_attn_mask = F.pad(
|
164 |
+
torch.triu(
|
165 |
+
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
166 |
+
(x_len, 0),
|
167 |
+
value=False, )
|
168 |
+
xy_attn_mask = torch.concat(
|
169 |
+
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
|
170 |
+
|
171 |
+
xy_dec, _ = self.h(
|
172 |
+
(xy_pos, None),
|
173 |
+
mask=xy_attn_mask, )
|
174 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
175 |
+
samples = topk_sampling(
|
176 |
+
logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
177 |
+
|
178 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len
|
179 |
+
) > early_stop_num:
|
180 |
+
print("use early stop num:", early_stop_num)
|
181 |
+
stop = True
|
182 |
+
|
183 |
+
if torch.argmax(
|
184 |
+
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
185 |
+
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
186 |
+
stop = True
|
187 |
+
if stop:
|
188 |
+
if prompts.shape[1] == y.shape[1]:
|
189 |
+
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
190 |
+
print('bad zero prediction')
|
191 |
+
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
192 |
+
break
|
193 |
+
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
194 |
+
# print(samples.shape)#[1,1]#第一个1是bs
|
195 |
+
# import os
|
196 |
+
# os._exit(2333)
|
197 |
+
y = torch.concat([y, samples], dim=1)
|
198 |
+
return y
|
199 |
+
|
200 |
+
def pad_y_eos(self, y, y_mask_int, eos_id):
|
201 |
+
targets = F.pad(
|
202 |
+
y, (0, 1), value=0) + eos_id * F.pad(
|
203 |
+
y_mask_int, (0, 1), value=1)
|
204 |
+
# 错位
|
205 |
+
return targets[:, :-1], targets[:, 1:]
|
206 |
+
|
207 |
+
def infer_panel(self,
|
208 |
+
x,#####全部文本token
|
209 |
+
x_lens,
|
210 |
+
prompts,####参考音频token
|
211 |
+
bert_feature,
|
212 |
+
top_k: int=-100,
|
213 |
+
early_stop_num: int=-1,
|
214 |
+
temperature: float=1.0):
|
215 |
+
|
216 |
+
x = self.ar_text_embedding(x)
|
217 |
+
x = x + self.bert_proj(bert_feature.transpose(1,2))
|
218 |
+
x = self.ar_text_position(x)
|
219 |
+
|
220 |
+
# AR Decoder
|
221 |
+
y = prompts
|
222 |
+
prefix_len = y.shape[1]
|
223 |
+
x_len = x.shape[1]
|
224 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
225 |
+
stop = False
|
226 |
+
# print(1111111,self.num_layers)
|
227 |
+
cache={
|
228 |
+
"all_stage":self.num_layers,
|
229 |
+
"k":[None]*self.num_layers,###根据配置自己手写
|
230 |
+
"v":[None]*self.num_layers,
|
231 |
+
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
|
232 |
+
"y_emb":None,##只需要对最新的samples求emb,再拼历史的就行
|
233 |
+
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
|
234 |
+
# "xy_dec":None,###不需要,本来只需要最后一个做logits
|
235 |
+
"first_infer":1,
|
236 |
+
"stage":0
|
237 |
+
}
|
238 |
+
for idx in tqdm(range(1500)):
|
239 |
+
if(cache["first_infer"]==1):
|
240 |
+
y_emb = self.ar_audio_embedding(y)
|
241 |
+
else:
|
242 |
+
y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
|
243 |
+
cache["y_emb"]=y_emb
|
244 |
+
y_pos = self.ar_audio_position(y_emb)
|
245 |
+
# x 和逐渐增长的 y 一起输入给模型
|
246 |
+
if(cache["first_infer"]==1):
|
247 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
248 |
+
else:
|
249 |
+
xy_pos=y_pos[:,-1:]
|
250 |
+
y_len = y_pos.shape[1]
|
251 |
+
###以下3个不做缓存
|
252 |
+
if (cache["first_infer"] == 1):
|
253 |
+
x_attn_mask_pad = F.pad(
|
254 |
+
x_attn_mask,
|
255 |
+
(0, y_len),###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
256 |
+
value=True, )
|
257 |
+
y_attn_mask = F.pad(###yy的右上1扩展到左边xy的0,(y,x+y)
|
258 |
+
torch.triu(
|
259 |
+
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
260 |
+
(x_len, 0),
|
261 |
+
value=False, )
|
262 |
+
xy_attn_mask = torch.concat(
|
263 |
+
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
|
264 |
+
else:
|
265 |
+
###最右边一列(是错的)
|
266 |
+
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
|
267 |
+
# xy_attn_mask[:,-1]=False
|
268 |
+
###最下面一行(是对的)
|
269 |
+
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
|
270 |
+
# pdb.set_trace()
|
271 |
+
###缓存重头戏
|
272 |
+
# print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
|
273 |
+
xy_dec, _ = self.h(
|
274 |
+
(xy_pos, None),
|
275 |
+
mask=xy_attn_mask,cache=cache )
|
276 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
|
277 |
+
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
278 |
+
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
|
279 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len
|
280 |
+
) > early_stop_num:
|
281 |
+
print("use early stop num:", early_stop_num)
|
282 |
+
stop = True
|
283 |
+
|
284 |
+
if torch.argmax(
|
285 |
+
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
286 |
+
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
287 |
+
stop = True
|
288 |
+
if stop:
|
289 |
+
if prompts.shape[1] == y.shape[1]:
|
290 |
+
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
291 |
+
print('bad zero prediction')
|
292 |
+
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
293 |
+
break
|
294 |
+
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
295 |
+
# print(samples.shape)#[1,1]#第一个1是bs
|
296 |
+
y = torch.concat([y, samples], dim=1)
|
297 |
+
cache["first_infer"]=0
|
298 |
+
return y,idx
|
AR/models/utils.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchaudio
|
5 |
+
|
6 |
+
|
7 |
+
def sequence_mask(length, max_length=None):
|
8 |
+
if max_length is None:
|
9 |
+
max_length = length.max()
|
10 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
11 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
12 |
+
|
13 |
+
|
14 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
|
15 |
+
"""
|
16 |
+
Args:
|
17 |
+
lengths:
|
18 |
+
A 1-D tensor containing sentence lengths.
|
19 |
+
max_len:
|
20 |
+
The length of masks.
|
21 |
+
Returns:
|
22 |
+
Return a 2-D bool tensor, where masked positions
|
23 |
+
are filled with `True` and non-masked positions are
|
24 |
+
filled with `False`.
|
25 |
+
|
26 |
+
#>>> lengths = torch.tensor([1, 3, 2, 5])
|
27 |
+
#>>> make_pad_mask(lengths)
|
28 |
+
tensor([[False, True, True, True, True],
|
29 |
+
[False, False, False, True, True],
|
30 |
+
[False, False, True, True, True],
|
31 |
+
[False, False, False, False, False]])
|
32 |
+
"""
|
33 |
+
assert lengths.ndim == 1, lengths.ndim
|
34 |
+
max_len = max(max_len, lengths.max())
|
35 |
+
n = lengths.size(0)
|
36 |
+
seq_range = torch.arange(0, max_len, device=lengths.device)
|
37 |
+
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
|
38 |
+
|
39 |
+
return expaned_lengths >= lengths.unsqueeze(-1)
|
40 |
+
|
41 |
+
|
42 |
+
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
43 |
+
def top_k_top_p_filtering(logits,
|
44 |
+
top_k=0,
|
45 |
+
top_p=1.0,
|
46 |
+
filter_value=-float("Inf"),
|
47 |
+
min_tokens_to_keep=1):
|
48 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
49 |
+
Args:
|
50 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
51 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
52 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
53 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
54 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
55 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
56 |
+
"""
|
57 |
+
if top_k > 0:
|
58 |
+
top_k = min(max(top_k, min_tokens_to_keep),
|
59 |
+
logits.size(-1)) # Safety check
|
60 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
61 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
62 |
+
logits[indices_to_remove] = filter_value
|
63 |
+
|
64 |
+
if top_p < 1.0:
|
65 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
66 |
+
cumulative_probs = torch.cumsum(
|
67 |
+
F.softmax(sorted_logits, dim=-1), dim=-1)
|
68 |
+
|
69 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
70 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
71 |
+
if min_tokens_to_keep > 1:
|
72 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
73 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
74 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
75 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
76 |
+
..., :-1].clone()
|
77 |
+
sorted_indices_to_remove[..., 0] = 0
|
78 |
+
|
79 |
+
# scatter sorted tensors to original indexing
|
80 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
81 |
+
1, sorted_indices, sorted_indices_to_remove)
|
82 |
+
logits[indices_to_remove] = filter_value
|
83 |
+
return logits
|
84 |
+
|
85 |
+
|
86 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
87 |
+
# temperature: (`optional`) float
|
88 |
+
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
89 |
+
# top_k: (`optional`) int
|
90 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
91 |
+
# top_p: (`optional`) float
|
92 |
+
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
93 |
+
|
94 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
95 |
+
if temperature != 1.0:
|
96 |
+
logits = logits / temperature
|
97 |
+
# Top-p/top-k filtering
|
98 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
99 |
+
# Sample
|
100 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
101 |
+
return token
|
102 |
+
|
103 |
+
|
104 |
+
from typing import Optional, Tuple
|
105 |
+
def multinomial_sample_one_no_sync(
|
106 |
+
probs_sort,
|
107 |
+
): # Does multinomial sampling without a cuda synchronization
|
108 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
109 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
110 |
+
|
111 |
+
|
112 |
+
def logits_to_probs(
|
113 |
+
logits,
|
114 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
115 |
+
temperature: float = 1.0,
|
116 |
+
top_k: Optional[int] = None,
|
117 |
+
top_p: Optional[int] = None,
|
118 |
+
repetition_penalty: float = 1.0,
|
119 |
+
):
|
120 |
+
previous_tokens=previous_tokens.squeeze()
|
121 |
+
# print(logits.shape,previous_tokens.shape)
|
122 |
+
# pdb.set_trace()
|
123 |
+
if previous_tokens is not None and repetition_penalty != 1.0:
|
124 |
+
previous_tokens = previous_tokens.long()
|
125 |
+
score = torch.gather(logits, dim=0, index=previous_tokens)
|
126 |
+
score = torch.where(
|
127 |
+
score < 0, score * repetition_penalty, score / repetition_penalty
|
128 |
+
)
|
129 |
+
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
130 |
+
|
131 |
+
if top_p is not None and top_p < 1.0:
|
132 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
133 |
+
cum_probs = torch.cumsum(
|
134 |
+
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
135 |
+
)
|
136 |
+
sorted_indices_to_remove = cum_probs > top_p
|
137 |
+
sorted_indices_to_remove[0] = False # keep at least one option
|
138 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
139 |
+
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
140 |
+
)
|
141 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
142 |
+
|
143 |
+
logits = logits / max(temperature, 1e-5)
|
144 |
+
|
145 |
+
if top_k is not None:
|
146 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
147 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
148 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
149 |
+
|
150 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
151 |
+
return probs
|
152 |
+
|
153 |
+
|
154 |
+
def sample(
|
155 |
+
logits,
|
156 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
157 |
+
**sampling_kwargs,
|
158 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
159 |
+
probs = logits_to_probs(
|
160 |
+
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
161 |
+
)
|
162 |
+
idx_next = multinomial_sample_one_no_sync(probs)
|
163 |
+
return idx_next, probs
|
164 |
+
|
AR/modules/__init__.py
ADDED
File without changes
|
AR/modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes). View file
|
|
AR/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes). View file
|
|