Sin2pi commited on
Commit
71514c3
·
verified ·
1 Parent(s): 64551de

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -10
model.py CHANGED
@@ -285,21 +285,20 @@ class rotary(nn.Module):
285
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
286
  super().__init__()
287
 
288
- self.dims = dims
289
  self.use_pbias = use_pbias
290
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
291
  self.dtype = torch.float32
292
  self.debug = debug
293
  self._counter = 0
294
-
295
  self.max_ctx = max_ctx
296
  self.radii = radii
297
  f0_factor = 0.5
298
- self.adaptation: bool = False
299
  pitch_scale = 1.0
300
  radius = 1
301
 
302
- if self.adaptation:
303
  self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=self.device, dtype=self.dtype), requires_grad=True)
304
  else:
305
  self.register_buffer('f0_scale', torch.tensor(f0_factor))
@@ -310,36 +309,44 @@ class rotary(nn.Module):
310
  self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
311
  self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
312
 
313
- def forward(self, x=None, feat=None, layer=None) -> Tensor:
314
- f0 = feat.get("f0") if feat else None
 
315
  if isinstance(x, int):
316
  ctx = x
317
  else:
318
  batch, ctx, dims = x.shape
319
  t = torch.arange(ctx, device=self.device).float()
 
320
  if f0 is not None:
321
  f0_mean=f0.mean()+1e-8
322
  theta=f0_mean*self.pitch_scale
323
  freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
324
  else:
325
  freqs = self.freqs
 
326
  freqs = torch.einsum('i,j->ij', t, freqs)
327
  freqs = freqs.float()
328
-
329
  if self.radii:
330
- radius = feat.get("f0d") if feat else self.radius
 
331
  radius = radius.float()
 
332
  else:
333
  radius = self.radius
334
- freqs = torch.polar(radius.unsqueeze(-1), freqs) # freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
335
-
 
336
  if "rotary" in self.debug:
337
  if f0 is not None:
338
  key = f"{self._counter}_{theta:.2f}"
339
  if key not in rotary._seen:
340
  if not hasattr(self, '_prev_f0_theta'):
341
  self._prev_f0_theta = theta
 
342
  elif abs(self._prev_f0_theta - theta) > 100.0:
 
343
  print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
344
  if self.radii:
345
  print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
 
285
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
286
  super().__init__()
287
 
 
288
  self.use_pbias = use_pbias
289
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
290
  self.dtype = torch.float32
291
  self.debug = debug
292
  self._counter = 0
293
+ self.dims = dims
294
  self.max_ctx = max_ctx
295
  self.radii = radii
296
  f0_factor = 0.5
297
+ self.learned_adaptation: bool = False
298
  pitch_scale = 1.0
299
  radius = 1
300
 
301
+ if self.learned_adaptation:
302
  self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=self.device, dtype=self.dtype), requires_grad=True)
303
  else:
304
  self.register_buffer('f0_scale', torch.tensor(f0_factor))
 
309
  self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
310
  self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
311
 
312
+ def forward(self, x=None, layer=None, enc=None) -> Tensor:
313
+
314
+ f0 = enc.get("f0") if enc else None
315
  if isinstance(x, int):
316
  ctx = x
317
  else:
318
  batch, ctx, dims = x.shape
319
  t = torch.arange(ctx, device=self.device).float()
320
+
321
  if f0 is not None:
322
  f0_mean=f0.mean()+1e-8
323
  theta=f0_mean*self.pitch_scale
324
  freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
325
  else:
326
  freqs = self.freqs
327
+
328
  freqs = torch.einsum('i,j->ij', t, freqs)
329
  freqs = freqs.float()
330
+ # print(f"{layer} : {f0_mean} : {theta:.2f} : {ctx} ")
331
  if self.radii:
332
+ # radius = self.align_f0(f0, ctx)
333
+ radius = enc.get("f0d") if enc else self.radius
334
  radius = radius.float()
335
+
336
  else:
337
  radius = self.radius
338
+ # freqs = torch.polar(self.radius.unsqueeze(-1), freqs)
339
+ freqs = torch.polar(radius.unsqueeze(-1), freqs)
340
+
341
  if "rotary" in self.debug:
342
  if f0 is not None:
343
  key = f"{self._counter}_{theta:.2f}"
344
  if key not in rotary._seen:
345
  if not hasattr(self, '_prev_f0_theta'):
346
  self._prev_f0_theta = theta
347
+ # print(f"Step {self._counter}: Theta: {theta:.2f} Hz")
348
  elif abs(self._prev_f0_theta - theta) > 100.0:
349
+ # print(f"Step {self._counter}: Theta: {theta:.2f} Hz, freqs: {freqs.shape}")
350
  print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
351
  if self.radii:
352
  print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")