Sin2pi committed on
Commit
82bee02
·
verified ·
1 Parent(s): 5219f06

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +79 -52
model.py CHANGED
@@ -130,7 +130,7 @@ def sinusoids(length, channels, max_timescale=10000):
130
 
131
  class rotary(nn.Module):
132
  _seen = set()
133
- def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False,
134
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
135
  super().__init__()
136
  self.use_pbias = False
@@ -143,24 +143,27 @@ class rotary(nn.Module):
143
  self._counter = 0
144
  self.dims = dims
145
  self.max_ctx = max_ctx
146
- self.variable_radius = variable_radius
147
-
148
- self.inv_freq = nn.Parameter(
149
- 1.0 / (10000 ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
150
- requires_grad=learned_freq)
151
- self.theta = nn.Parameter(
152
- torch.tensor(float(theta)), requires_grad=learned_theta)
153
  self.min_theta = nn.Parameter(
154
- torch.tensor(600.0), requires_grad=learned_theta)
155
  self.max_theta = nn.Parameter(
156
- torch.tensor(2400.0), requires_grad=learned_theta)
 
 
 
157
 
158
- self.pitch_scale = nn.Parameter(torch.tensor(1.0),
159
  requires_grad=learned_pitch)
160
 
161
- if variable_radius:
162
- self.radius = nn.Parameter(
163
- torch.ones(dims // 2),
 
 
164
  requires_grad=learned_radius)
165
 
166
  def get_pitch_bias(self, f0):
@@ -189,49 +192,76 @@ class rotary(nn.Module):
189
  rotary.get_sim = get_sim
190
  rotary.fwd_sim = fwd_sim
191
 
192
- def align_f0_to_tokens(self, f0, token_length):
193
- ratio = len(f0) / token_length
194
- indices = [int(i * ratio) for i in range(token_length)]
195
- indices = [min(i, len(f0) - 1) for i in indices]
196
- return f0[indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  def forward(self, x=None, f0=None, stage=None) -> Tensor:
199
  if isinstance(x, int):
200
- t = torch.arange(x, device=self.device).float()
201
  else:
202
- t = x.float().to(self.inv_freq.device)
 
203
 
204
  if f0 is not None:
205
  f0_mean = f0.mean()
206
- f0_mean = torch.clamp(f0_mean, min=80.0, max=600.0)
207
- perceptual_factor = torch.log(1 + f0_mean / 700.0) / torch.log(torch.tensor(1 + 300.0 / 700.0))
208
- f0_theta = self.min_theta + perceptual_factor * (self.max_theta - self.min_theta)
209
- inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
210
- else:
211
- inv_freq = self.inv_freq
212
- freqs = torch.einsum('i,j->ij', t, inv_freq)
213
-
214
  freqs = freqs.float()
215
- if self.variable_radius:
216
-
217
- # if f0 is not None:
218
- # f0 = f0[0]
219
- # seq_len = x
220
- # f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
221
- # radius = f0
222
-
223
- # freqs = torch.polar(radius, freqs)
224
- # else:
225
-
226
- # freqs = torch.polar(torch.ones_like(freqs), freqs)
227
- # freqs = freqs.unsqueeze(0)
228
-
229
- radius = F.softplus(self.radius)
230
- freqs = torch.polar(radius.unsqueeze(0).expand_as(freqs), freqs)
231
  else:
232
- freqs = torch.polar(torch.ones_like(freqs), freqs)
233
- freqs = freqs.unsqueeze(0)
234
-
235
  if "rotary" in self.debug:
236
  if f0 is not None:
237
  key = f"{self._counter}_{f0_theta:.2f}"
@@ -239,12 +269,11 @@ class rotary(nn.Module):
239
  if not hasattr(self, '_prev_f0_theta'):
240
  self._prev_f0_theta = f0_theta
241
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
242
- elif abs(self._prev_f0_theta - f0_theta) > 200.0:
243
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
244
  self._prev_f0_theta = f0_theta
245
  rotary._seen.add(key)
246
  self._counter += 1
247
-
248
  return freqs
249
 
250
  @staticmethod
@@ -258,13 +287,11 @@ class rotary(nn.Module):
258
  x1 = x1 * freqs
259
  x1 = torch.view_as_real(x1).flatten(-2)
260
  return torch.cat([x1.type_as(x), x2], dim=-1)
261
-
262
  else:
263
  x1 = x[..., :freqs.shape[-1]*2]
264
  x2 = x[..., freqs.shape[-1]*2:]
265
 
266
  if x.ndim == 2:
267
-
268
  x1 = x1.unsqueeze(0)
269
  x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
270
  x1 = torch.view_as_complex(x1)
 
130
 
131
  class rotary(nn.Module):
132
  _seen = set()
133
+ def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
134
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
135
  super().__init__()
136
  self.use_pbias = False
 
143
  self._counter = 0
144
  self.dims = dims
145
  self.max_ctx = max_ctx
146
+ self.radii = radii
147
+ pitch_scale = 1.0
148
+ # theta_rescale = 1.0
149
+ # theta *= theta_rescale ** (dims / (dims - 2))
150
+
 
 
151
  self.min_theta = nn.Parameter(
152
+ torch.tensor(20.0), requires_grad=learned_theta)
153
  self.max_theta = nn.Parameter(
154
+ torch.tensor(400.0), requires_grad=learned_theta)
155
+
156
+ self.theta = nn.Parameter(
157
+ torch.tensor(float(theta)), requires_grad=learned_theta)
158
 
159
+ self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale),
160
  requires_grad=learned_pitch)
161
 
162
+ freqs = 1. / (theta ** (torch.arange(0, dims, 2)[:(dims // 2)].float() / dims))
163
+ self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
164
+
165
+ if radii:
166
+ self.radius = nn.Parameter(torch.ones(dims // 2),
167
  requires_grad=learned_radius)
168
 
169
  def get_pitch_bias(self, f0):
 
192
  rotary.get_sim = get_sim
193
  rotary.fwd_sim = fwd_sim
194
 
195
+ def align_f0(self, f0, token_length):
196
+ batch_size, f0_length = f0.shape
197
+ if f0_length == token_length:
198
+ return f0 # No resampling needed (encoder path - audio features)
199
+ frames_per_token = f0_length / token_length
200
+
201
+ indices = torch.arange(token_length, device=f0.device)
202
+ indices = (indices * frames_per_token).long()#.clamp(max=f0_length-1)
203
+ #center_positions = ((indices + 0.5) * frames_per_token).long()
204
+ batch_indices = torch.arange(batch_size, device=f0.device).unsqueeze(1)
205
+ return f0[batch_indices, indices.unsqueeze(0).expand(batch_size, -1)]
206
+
207
+ def scale_f0(self, f0):
208
+ f0_min = f0.min(dim=1, keepdim=True)[0]
209
+ f0_max = f0.max(dim=1, keepdim=True)[0]
210
+ denom = f0_max - f0_min + 1e-8
211
+ normalized_f0 = (f0 - f0_min) / denom
212
+ # normalized_f0 = (f0 - f0_min) / (f0_max - f0_min)
213
+ normalized_f0 = torch.clamp(normalized_f0, 0.0, 1.0)
214
+ return normalized_f0
215
+
216
+ def process_f0(f0, threshold=0.05):
217
+ thresholded_f0 = torch.where(f0 < threshold, torch.zeros_like(f0), f0)
218
+ return thresholded_f0
219
+
220
+ def map_perceptual(self, f0_mean, theta=10000.0):
221
+ if f0_mean >= theta:
222
+ return torch.log(f0_mean / theta)
223
+ else:
224
+ return -torch.log(theta / f0_mean)
225
+
226
+ def linear_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
227
+ mapped_freq = ((freq - min_freq) / (max_freq - min_freq)) * target_max
228
+ return mapped_freq
229
+
230
+ def log_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
231
+ log_freq = torch.log(freq)
232
+ log_min_freq = torch.log(min_freq)
233
+ log_max_freq = torch.log(max_freq)
234
+
235
+ mapped_log_freq = ((log_freq - log_min_freq) / (log_max_freq - log_min_freq)) * torch.log(torch.tensor(target_max, device=self.device))
236
+ return mapped_log_freq
237
 
238
  def forward(self, x=None, f0=None, stage=None) -> Tensor:
239
  if isinstance(x, int):
240
+ seq_len = x
241
  else:
242
+ batch, seq_len, _ = x.shape
243
+ t = torch.arange(seq_len, device=self.device).float()
244
 
245
  if f0 is not None:
246
  f0_mean = f0.mean()
247
+ theta = self.theta
248
+ f0_theta = theta * (f0_mean * 1e-2 + 1.0)
249
+ freqs = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
250
+ else:
251
+ freqs = self.freqs
252
+
253
+ freqs = torch.einsum('i,j->ij', t, freqs)
 
254
  freqs = freqs.float()
255
+
256
+ if self.radii and f0 is not None:
257
+ radius = self.align_f0(f0, seq_len)
258
+ # radius = self.scale_f0(radius)
259
+ radius = F.softplus(self.radius) * radius
260
+ # radius = radius.unsqueeze(-1)
261
+ freqs = torch.polar(radius.unsqueeze(-1), freqs.unsqueeze(0))
 
 
 
 
 
 
 
 
 
262
  else:
263
+ freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
264
+ # print(f"Step {self._counter}: Block: {stage}: Radius: {radius}")
 
265
  if "rotary" in self.debug:
266
  if f0 is not None:
267
  key = f"{self._counter}_{f0_theta:.2f}"
 
269
  if not hasattr(self, '_prev_f0_theta'):
270
  self._prev_f0_theta = f0_theta
271
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
272
+ elif abs(self._prev_f0_theta - f0_theta) > 1000.0:
273
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
274
  self._prev_f0_theta = f0_theta
275
  rotary._seen.add(key)
276
  self._counter += 1
 
277
  return freqs
278
 
279
  @staticmethod
 
287
  x1 = x1 * freqs
288
  x1 = torch.view_as_real(x1).flatten(-2)
289
  return torch.cat([x1.type_as(x), x2], dim=-1)
 
290
  else:
291
  x1 = x[..., :freqs.shape[-1]*2]
292
  x2 = x[..., freqs.shape[-1]*2:]
293
 
294
  if x.ndim == 2:
 
295
  x1 = x1.unsqueeze(0)
296
  x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
297
  x1 = torch.view_as_complex(x1)