Update model_hf.py
Browse files- model_hf.py +5 -7
model_hf.py
CHANGED
@@ -377,14 +377,15 @@ class rotary(nn.Module):
|
|
377 |
else:
|
378 |
f0 = f0.view(-1)
|
379 |
|
380 |
-
if f0 is not None
|
381 |
f0_mean = f0.mean()
|
382 |
theta = f0_mean + self.theta
|
383 |
else:
|
384 |
theta = self.theta
|
385 |
freqs = self.theta_freqs(theta)
|
386 |
freqs = t[:, None] * freqs[None, :]
|
387 |
-
|
|
|
388 |
radius = f0.to(device, dtype)
|
389 |
L = radius.shape[0]
|
390 |
if L != ctx:
|
@@ -392,12 +393,9 @@ class rotary(nn.Module):
|
|
392 |
idx = torch.arange(ctx, device=f0.device)
|
393 |
idx = (idx * F).long().clamp(0, L - 1)
|
394 |
radius = radius[idx]
|
395 |
-
|
396 |
-
radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
|
397 |
-
# radius = torch.sigmoid(radius)
|
398 |
else:
|
399 |
-
|
400 |
-
freqs = torch.polar(radius, freqs)
|
401 |
|
402 |
if "radius" in self.debug and self.counter % 100 == 0:
|
403 |
theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
|
|
|
377 |
else:
|
378 |
f0 = f0.view(-1)
|
379 |
|
380 |
+
if f0 is not None:
|
381 |
f0_mean = f0.mean()
|
382 |
theta = f0_mean + self.theta
|
383 |
else:
|
384 |
theta = self.theta
|
385 |
freqs = self.theta_freqs(theta)
|
386 |
freqs = t[:, None] * freqs[None, :]
|
387 |
+
|
388 |
+
if self.radii and f0 is not None:
|
389 |
radius = f0.to(device, dtype)
|
390 |
L = radius.shape[0]
|
391 |
if L != ctx:
|
|
|
393 |
idx = torch.arange(ctx, device=f0.device)
|
394 |
idx = (idx * F).long().clamp(0, L - 1)
|
395 |
radius = radius[idx]
|
396 |
+
freqs = torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs)
|
|
|
|
|
397 |
else:
|
398 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
|
|
399 |
|
400 |
if "radius" in self.debug and self.counter % 100 == 0:
|
401 |
theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
|