Update model.py
Browse files
model.py
CHANGED
@@ -285,21 +285,20 @@ class rotary(nn.Module):
|
|
285 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
|
286 |
super().__init__()
|
287 |
|
288 |
-
self.dims = dims
|
289 |
self.use_pbias = use_pbias
|
290 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
291 |
self.dtype = torch.float32
|
292 |
self.debug = debug
|
293 |
self._counter = 0
|
294 |
-
|
295 |
self.max_ctx = max_ctx
|
296 |
self.radii = radii
|
297 |
f0_factor = 0.5
|
298 |
-
self.
|
299 |
pitch_scale = 1.0
|
300 |
radius = 1
|
301 |
|
302 |
-
if self.
|
303 |
self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=self.device, dtype=self.dtype), requires_grad=True)
|
304 |
else:
|
305 |
self.register_buffer('f0_scale', torch.tensor(f0_factor))
|
@@ -310,36 +309,44 @@ class rotary(nn.Module):
|
|
310 |
self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
|
311 |
self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
|
312 |
|
313 |
-
def forward(self, x=None,
|
314 |
-
|
|
|
315 |
if isinstance(x, int):
|
316 |
ctx = x
|
317 |
else:
|
318 |
batch, ctx, dims = x.shape
|
319 |
t = torch.arange(ctx, device=self.device).float()
|
|
|
320 |
if f0 is not None:
|
321 |
f0_mean=f0.mean()+1e-8
|
322 |
theta=f0_mean*self.pitch_scale
|
323 |
freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
|
324 |
else:
|
325 |
freqs = self.freqs
|
|
|
326 |
freqs = torch.einsum('i,j->ij', t, freqs)
|
327 |
freqs = freqs.float()
|
328 |
-
|
329 |
if self.radii:
|
330 |
-
radius =
|
|
|
331 |
radius = radius.float()
|
|
|
332 |
else:
|
333 |
radius = self.radius
|
334 |
-
freqs = torch.polar(radius.unsqueeze(-1), freqs)
|
335 |
-
|
|
|
336 |
if "rotary" in self.debug:
|
337 |
if f0 is not None:
|
338 |
key = f"{self._counter}_{theta:.2f}"
|
339 |
if key not in rotary._seen:
|
340 |
if not hasattr(self, '_prev_f0_theta'):
|
341 |
self._prev_f0_theta = theta
|
|
|
342 |
elif abs(self._prev_f0_theta - theta) > 100.0:
|
|
|
343 |
print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
|
344 |
if self.radii:
|
345 |
print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
|
|
|
285 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
|
286 |
super().__init__()
|
287 |
|
|
|
288 |
self.use_pbias = use_pbias
|
289 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
290 |
self.dtype = torch.float32
|
291 |
self.debug = debug
|
292 |
self._counter = 0
|
293 |
+
self.dims = dims
|
294 |
self.max_ctx = max_ctx
|
295 |
self.radii = radii
|
296 |
f0_factor = 0.5
|
297 |
+
self.learned_adaptation: bool = False
|
298 |
pitch_scale = 1.0
|
299 |
radius = 1
|
300 |
|
301 |
+
if self.learned_adaptation:
|
302 |
self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=self.device, dtype=self.dtype), requires_grad=True)
|
303 |
else:
|
304 |
self.register_buffer('f0_scale', torch.tensor(f0_factor))
|
|
|
309 |
self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
|
310 |
self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
|
311 |
|
312 |
+
def forward(self, x=None, layer=None, enc=None) -> Tensor:
|
313 |
+
|
314 |
+
f0 = enc.get("f0") if enc else None
|
315 |
if isinstance(x, int):
|
316 |
ctx = x
|
317 |
else:
|
318 |
batch, ctx, dims = x.shape
|
319 |
t = torch.arange(ctx, device=self.device).float()
|
320 |
+
|
321 |
if f0 is not None:
|
322 |
f0_mean=f0.mean()+1e-8
|
323 |
theta=f0_mean*self.pitch_scale
|
324 |
freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
|
325 |
else:
|
326 |
freqs = self.freqs
|
327 |
+
|
328 |
freqs = torch.einsum('i,j->ij', t, freqs)
|
329 |
freqs = freqs.float()
|
330 |
+
# print(f"{layer} : {f0_mean} : {theta:.2f} : {ctx} ")
|
331 |
if self.radii:
|
332 |
+
# radius = self.align_f0(f0, ctx)
|
333 |
+
radius = enc.get("f0d") if enc else self.radius
|
334 |
radius = radius.float()
|
335 |
+
|
336 |
else:
|
337 |
radius = self.radius
|
338 |
+
# freqs = torch.polar(self.radius.unsqueeze(-1), freqs)
|
339 |
+
freqs = torch.polar(radius.unsqueeze(-1), freqs)
|
340 |
+
|
341 |
if "rotary" in self.debug:
|
342 |
if f0 is not None:
|
343 |
key = f"{self._counter}_{theta:.2f}"
|
344 |
if key not in rotary._seen:
|
345 |
if not hasattr(self, '_prev_f0_theta'):
|
346 |
self._prev_f0_theta = theta
|
347 |
+
# print(f"Step {self._counter}: Theta: {theta:.2f} Hz")
|
348 |
elif abs(self._prev_f0_theta - theta) > 100.0:
|
349 |
+
# print(f"Step {self._counter}: Theta: {theta:.2f} Hz, freqs: {freqs.shape}")
|
350 |
print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
|
351 |
if self.radii:
|
352 |
print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
|