Update model_simple.py

model_simple.py  CHANGED  (+39 -53)
@@ -21,6 +21,14 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
+PATH = 'E:/hf'
+os.environ['HF_HOME'] = PATH
+os.environ['HF_DATASETS_CACHE'] = PATH
+os.environ['TORCH_HOME'] = PATH
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+
 @dataclass
 class Dimensions:
     vocab: int
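Note on the new environment block: HF_HOME, HF_DATASETS_CACHE, TORCH_HOME, and PYTORCH_CUDA_ALLOC_CONF are read when the corresponding libraries initialize, so the block only takes effect if it runs before `import torch` / `import datasets` in the final file. A minimal ordering sketch (the E:/hf path is the commit's own value):

    import os

    PATH = 'E:/hf'
    os.environ['HF_HOME'] = PATH                  # hub cache root
    os.environ['HF_DATASETS_CACHE'] = PATH        # datasets cache
    os.environ['TORCH_HOME'] = PATH               # torch hub / pretrained checkpoints
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # CUDA allocator option

    import torch     # safest after the env vars, so the allocator config is visible
    import datasets  # same for HF_DATASETS_CACHE, which is read at import time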
@@ -63,13 +71,12 @@ class rotary(nn.Module):
 
 class MultiheadA(nn.Module):
 
-    def __init__(self, dims: int, head: int
+    def __init__(self, dims: int, head: int):
         super(MultiheadA, self).__init__()
 
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.debug = debug
 
         self.q = nn.Linear(dims, dims).to(device, dtype)
         self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
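For reference, `head_dim = dims // head` is the per-head slice of the model width; the projections created here are full dims -> dims linears whose outputs get reshaped per head. A small sketch of that split (shapes only, not code from this file), assuming `dims` divides evenly by `head`:

    import torch

    B, T, dims, head = 2, 10, 64, 4
    head_dim = dims // head                            # 16
    q = torch.randn(B, T, dims)                        # e.g. the output of self.q(x)
    q = q.view(B, T, head, head_dim).transpose(1, 2)   # (B, head, T, head_dim)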
@@ -119,7 +126,6 @@ class Residual(nn.Module):
         self.ctx = ctx
         self.head_dim = dims // head
 
-
         self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
         self.attn = MultiheadA(dims, head)
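`self.blend` is a learned scalar gate; elsewhere in this file it is passed through a sigmoid and used to mix two branches. The pattern, as a self-contained sketch:

    import torch

    branch_a, branch_b = torch.randn(2, 8), torch.randn(2, 8)
    blend = torch.nn.Parameter(torch.tensor(0.5))
    a = torch.sigmoid(blend)                  # ≈ 0.622 at init, always in (0, 1)
    out = a * branch_a + (1 - a) * branch_b   # convex mix; gradients flow into `blend`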
@@ -144,65 +150,36 @@ class Residual(nn.Module):
         x = x + gate * mlp_out
         return x
 
-class feature_encoder(nn.Module):
-    def __init__(self, mels, dims, head, layer, act):
-
-        super().__init__()
-
+class processor(nn.Module):
+    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
+        super(processor, self).__init__()
         self.dims = dims
         self.head = head
-        self.
+        self.layer = layer
+        self.ctx = ctx
+        self.act = act
         self.dropout = 0.01
         act_fn = get_activation(act)
 
+        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
+        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
+        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
+        self.positional_sin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+
         # pitch
         # self.encoder = nn.Sequential(
         #     Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
         #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
-
+
         self.encoder = nn.Sequential(
             Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
-
-        self.
-        self.norm = RMSNorm(dims)
-
-    def forward(self, x, xa=None, mask=None, max_tscale=36000):
-        if x.dim() == 2:
-            x = x.unsqueeze(0)
-        # x = self.pitch(x).permute(0, 2, 1)
-        x = self.encoder(x).permute(0, 2, 1)
-        max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
-        x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self.norm(x)
-        return x
-
-class processor(nn.Module):
-    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
-        super(processor, self).__init__()
-        self.dims = dims
-        self.head = head
-        self.layer = layer
-        self.ctx = ctx
-        self.act = act
-        self.dropout = 0.01
-        act_fn = get_activation(act)
-
-        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
-        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
-        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
-
-        self.bA = nn.ModuleList(
-            [feature_encoder(mels=mels, dims=dims, head=head, layer=layer, act=act_fn)] +
-            [Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
-        self.bB = nn.ModuleList([
-            Residual(ctx=ctx, dims=dims, head=head, act=act_fn)
-            for _ in range(layer)])
+        self.bA = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
+        self.bB = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
 
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
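Two details of the merged `processor.__init__` are easy to miss: `groups=dims` makes the last convolution depthwise (one independent kernel per channel, no cross-channel mixing), and the registered buffer is the standard additive causal mask. Both can be checked in isolation:

    import numpy as np
    import torch
    import torch.nn as nn

    # depthwise: groups == channels gives one 3-tap kernel per channel
    dw = nn.Conv1d(8, 8, kernel_size=3, stride=1, padding=1, groups=8)
    print(dw.weight.shape)                    # torch.Size([8, 1, 3])
    print(dw(torch.randn(1, 8, 100)).shape)   # torch.Size([1, 8, 100])

    # additive causal mask: 0 on/below the diagonal, -inf strictly above
    mask = torch.empty(4, 4).fill_(-np.inf).triu_(1)
    print(mask[0])                            # tensor([0., -inf, -inf, -inf])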
@@ -211,6 +188,9 @@ class processor(nn.Module):
     def forward(self, x, xa, sequential=False) -> Tensor:
         x = self.token(x.long()) + self.positional[:x.shape[1]]
 
+        xa = self.encoder(xa).permute(0, 2, 1)
+        xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 36000).to(device, dtype)
+
         for b in chain(self.bA or []):
             xa = b(x=xa, xa=None, mask=None)
 
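`self.positional_sin` wraps a `sinusoids(length, dims, max_tscale)` helper that is not shown in this diff; the call signature matches the Whisper-style fixed sinusoidal embedding, which looks roughly like the sketch below (an assumption, not code from this commit):

    import math
    import torch

    def sinusoids(length, channels, max_tscale=10000):
        # fixed sin/cos positional table, shape (length, channels); assumes channels is even
        scale = math.log(max_tscale) / (channels // 2 - 1)
        inv = torch.exp(-scale * torch.arange(channels // 2))
        t = torch.arange(length)[:, None] * inv[None, :]
        return torch.cat([t.sin(), t.cos()], dim=1)

Note also the `permute(0, 2, 1)`: Conv1d emits (batch, dims, frames) while the attention blocks expect (batch, frames, dims).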
@@ -222,7 +202,17 @@ class processor(nn.Module):
         else:
             a = torch.sigmoid(self.blend)
             x = a * xc + (1 - a) * x
-
+
+        # for b in chain(self.bB or []):
+        #     xd = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None)
+        #     xm = b(x=xd[:, :x.shape[1]], xa=xd[:, x.shape[1]:], mask=None)
+        #     if sequential:
+        #         x = xm
+        #     else:
+        #         a = torch.sigmoid(self.blend)
+        #         x = a * x + (1 - a) * xm
+
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = self.norm(x)
         x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
         return x
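The output projection reuses the embedding matrix (weight tying): multiplying by the transposed `token.weight` maps (B, T, dims) back to (B, T, vocab) without a separate lm_head. Self-contained:

    import torch
    import torch.nn as nn

    token = nn.Embedding(100, 16)   # vocab=100, dims=16
    x = torch.randn(2, 5, 16)       # decoder output, (B, T, dims)
    logits = x @ token.weight.T     # same trick as the transpose() call above
    print(logits.shape)             # torch.Size([2, 5, 100])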
@@ -252,10 +242,8 @@ class Echo(nn.Module):
         enc= {}
         if pitch is not None:
             xa = pitch
-            enc["pitch"] = pitch
         if spectrogram is not None:
             xa = spectrogram
-            enc["spectrogram"] = spectrogram
 
         x = input_ids
         logits = self.processor(x, xa)
@@ -306,8 +294,6 @@ class Echo(nn.Module):
             self.init_counts["MultiheadA"] += 1
         elif isinstance(module, Residual):
             self.init_counts["Residual"] += 1
-        elif isinstance(module, feature_encoder):
-            self.init_counts["feature_encoder"] += 1
         elif isinstance(module, processor):
             self.init_counts["processor"] += 1
         elif isinstance(module, Echo):
@@ -336,10 +322,10 @@ def main():
 
     extract_args = {
         "waveform": False,
-        "spec":
+        "spec": True,
         "f0": False,
         "f0t": False,
-        "pitch":
+        "pitch": False,
         "harmonics": False,
        "aperiodics": False,
         "phase_mod": False,
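Since only "spec" is enabled here, the spectrogram branch in `Echo.forward` above is presumably the one that supplies `xa` to the processor; the pitch path stays dormant until "pitch" is switched back on.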