Sin2pi committed
Commit 4487bd6 · verified · 1 Parent(s): 05b483e

Update model_simple.py

Files changed (1)
  1. model_simple.py +39 -53
model_simple.py CHANGED
@@ -21,6 +21,14 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
+PATH = 'E:/hf'
+os.environ['HF_HOME'] = PATH
+os.environ['HF_DATASETS_CACHE'] = PATH
+os.environ['TORCH_HOME'] = PATH
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+
 @dataclass
 class Dimensions:
     vocab: int
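These variables only take effect if they are set before the libraries that read them initialise their caches and allocator, which is why the commit places them at the top of the file. A minimal sketch of that ordering (illustration only, not part of the commit; the E:/hf path is the commit's own choice and any writable directory works):

import os

# Set cache locations and the CUDA allocator option first, so torch and
# datasets see them when they initialise.
PATH = 'E:/hf'
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch      # allocator / TORCH_HOME settings are visible from here on
import datasets   # HF_HOME / HF_DATASETS_CACHE are visible from here on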
@@ -63,13 +71,12 @@ class rotary(nn.Module):
 
 class MultiheadA(nn.Module):
 
-    def __init__(self, dims: int, head: int, debug: List[str] = []):
+    def __init__(self, dims: int, head: int):
         super(MultiheadA, self).__init__()
 
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.debug = debug
 
         self.q = nn.Linear(dims, dims).to(device, dtype)
         self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
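Besides simplifying the signature, dropping debug: List[str] = [] removes a mutable default argument, which Python evaluates once at definition time and shares across every call that relies on the default. A small illustration of that pitfall, separate from the commit:

def tag(item, seen=[]):          # the [] is created once, at definition time
    seen.append(item)
    return seen

print(tag("a"))                  # ['a']
print(tag("b"))                  # ['a', 'b']  -- state leaks between calls

def tag_safe(item, seen=None):   # idiomatic fix: default to None
    seen = [] if seen is None else seen
    seen.append(item)
    return seen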
@@ -119,7 +126,6 @@ class Residual(nn.Module):
         self.ctx = ctx
         self.head_dim = dims // head
 
-
         self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
         self.attn = MultiheadA(dims, head)
@@ -144,65 +150,36 @@ class Residual(nn.Module):
         x = x + gate * mlp_out
         return x
 
-
-class feature_encoder(nn.Module):
-    def __init__(self, mels, dims, head, layer, act="gelu"):
-        super().__init__()
-
+class processor(nn.Module):
+    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
+        super(processor, self).__init__()
         self.dims = dims
         self.head = head
-        self.head_dim = dims // head
+        self.layer = layer
+        self.ctx = ctx
+        self.act = act
         self.dropout = 0.01
         act_fn = get_activation(act)
 
+        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
+        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
+        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
+        self.positional_sin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+
         # pitch
         # self.encoder = nn.Sequential(
         #     Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
         #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
         #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
-        # spectrogram
+
         self.encoder = nn.Sequential(
             Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
-
-        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
-        self.norm = RMSNorm(dims)
-
-    def forward(self, x, xa=None, mask=None, max_tscale=36000):
-        if x.dim() == 2:
-            x = x.unsqueeze(0)
-        # x = self.pitch(x).permute(0, 2, 1)
-        x = self.encoder(x).permute(0, 2, 1)
-        max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
-        x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self.norm(x)
-        return x
-
-class processor(nn.Module):
-    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
-        super(processor, self).__init__()
-        self.dims = dims
-        self.head = head
-        self.layer = layer
-        self.ctx = ctx
-        self.act = act
-        self.dropout = 0.01
-        act_fn = get_activation(act)
-
-        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
-        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
-        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
-
-        self.bA = nn.ModuleList(
-            [feature_encoder(mels=mels, dims=dims, head=head, layer=layer, act=act_fn)] +
-            [Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
-        self.bB = nn.ModuleList([
-            Residual(ctx=ctx, dims=dims, head=head, act=act_fn)
-            for _ in range(layer)])
+        self.bA = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
+        self.bB = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
 
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
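The registered buffer is the standard additive causal mask: fill_(-np.inf) makes every entry -inf, then triu_(1) keeps -inf only strictly above the diagonal and zeroes the rest, so adding it to attention scores before a softmax blocks attention to future positions. A tiny standalone check with ctx=4 (not part of the file):

import numpy as np
import torch

ctx = 4
mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
# rows are query positions, columns are key positions:
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]

scores = torch.zeros(ctx, ctx) + mask        # uniform scores plus the mask
print(torch.softmax(scores, dim=-1)[1])      # tensor([0.5, 0.5, 0., 0.])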
@@ -211,6 +188,9 @@ class processor(nn.Module):
     def forward(self, x, xa, sequential=False) -> Tensor:
         x = self.token(x.long()) + self.positional[:x.shape[1]]
 
+        xa = self.encoder(xa).permute(0, 2, 1)
+        xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 36000).to(device, dtype)
+
         for b in chain(self.bA or []):
             xa = b(x=xa, xa=None, mask=None)
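The spectrogram enters as (batch, mels, frames); the three Conv1d layers all use kernel 3, stride 1, padding 1, so the frame count is preserved while the channel count becomes dims, and permute(0, 2, 1) yields one dims-sized vector per frame so the sinusoidal table can be added along the time axis. A shape sketch with invented sizes (GELU stands in for act_fn, and sinusoids is assumed to return a (frames, dims) table):

import torch
import torch.nn as nn

batch, mels, frames, dims = 2, 128, 300, 512
xa = torch.randn(batch, mels, frames)                 # (B, mels, T)

encoder = nn.Sequential(
    nn.Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), nn.GELU(),
    nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), nn.GELU(),
    nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), nn.GELU())

feats = encoder(xa).permute(0, 2, 1)                  # (B, T, dims)
print(feats.shape)                                    # torch.Size([2, 300, 512])
pos = torch.zeros(frames, dims)                       # stand-in for sinusoids(T, dims, 36000)
feats = feats + pos                                   # broadcasts over the batch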
 
@@ -222,7 +202,17 @@ class processor(nn.Module):
         else:
             a = torch.sigmoid(self.blend)
             x = a * xc + (1 - a) * x
-
+
+        # for b in chain(self.bB or []):
+        #     xd = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None)
+        #     xm = b(x=xd[:, :x.shape[1]], xa=xd[:, x.shape[1]:], mask=None)
+        #     if sequential:
+        #         x = xm
+        #     else:
+        #         a = torch.sigmoid(self.blend)
+        #         x = a * x + (1 - a) * xm
+
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = self.norm(x)
         x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
         return x
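Two details in the tail of forward: blend is a single learned scalar squashed by a sigmoid, so the output is a soft interpolation of the cross-attended stream xc and the text stream x rather than a hard switch, and the final projection reuses the token embedding matrix as the output head (weight tying). A minimal sketch with invented sizes, not the commit's code:

import torch
import torch.nn as nn

dims, vocab = 512, 1000
blend = nn.Parameter(torch.tensor(0.5))
token = nn.Embedding(vocab, dims)

x  = torch.randn(2, 7, dims)      # text stream
xc = torch.randn(2, 7, dims)      # cross-attended stream

a = torch.sigmoid(blend)          # scalar gate in (0, 1), learned end to end
x = a * xc + (1 - a) * x

logits = x @ token.weight.T       # weight-tied output head: (2, 7, vocab)
print(logits.shape)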
@@ -252,10 +242,8 @@ class Echo(nn.Module):
         enc= {}
         if pitch is not None:
             xa = pitch
-            enc["pitch"] = pitch
         if spectrogram is not None:
             xa = spectrogram
-            enc["spectrogram"] = spectrogram
 
         x = input_ids
         logits = self.processor(x, xa)
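With the enc bookkeeping removed, the forward pass simply passes along whichever feature was supplied, and because the spectrogram branch is checked last it wins when both are given. The same selection in isolation (illustration only, not the commit's code):

def pick_features(pitch=None, spectrogram=None):
    xa = None
    if pitch is not None:
        xa = pitch
    if spectrogram is not None:
        xa = spectrogram          # later branch wins when both are given
    return xa

assert pick_features(pitch="p", spectrogram="s") == "s"
assert pick_features(pitch="p") == "p"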
@@ -306,8 +294,6 @@ class Echo(nn.Module):
                 self.init_counts["MultiheadA"] += 1
             elif isinstance(module, Residual):
                 self.init_counts["Residual"] += 1
-            elif isinstance(module, feature_encoder):
-                self.init_counts["feature_encoder"] += 1
             elif isinstance(module, processor):
                 self.init_counts["processor"] += 1
             elif isinstance(module, Echo):
@@ -336,10 +322,10 @@ def main():
 
     extract_args = {
         "waveform": False,
-        "spec": False,
+        "spec": True,
         "f0": False,
         "f0t": False,
-        "pitch": True,
+        "pitch": False,
         "harmonics": False,
         "aperiodics": False,
         "phase_mod": False,
 