Update model.py
Browse files
model.py
CHANGED
@@ -130,7 +130,7 @@ def sinusoids(length, channels, max_timescale=10000):
|
|
130 |
|
131 |
class rotary(nn.Module):
|
132 |
_seen = set()
|
133 |
-
def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False,
|
134 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
|
135 |
super().__init__()
|
136 |
self.use_pbias = False
|
@@ -143,24 +143,27 @@ class rotary(nn.Module):
|
|
143 |
self._counter = 0
|
144 |
self.dims = dims
|
145 |
self.max_ctx = max_ctx
|
146 |
-
self.
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
self.theta = nn.Parameter(
|
152 |
-
torch.tensor(float(theta)), requires_grad=learned_theta)
|
153 |
self.min_theta = nn.Parameter(
|
154 |
-
torch.tensor(
|
155 |
self.max_theta = nn.Parameter(
|
156 |
-
torch.tensor(
|
|
|
|
|
|
|
157 |
|
158 |
-
self.pitch_scale = nn.Parameter(torch.tensor(
|
159 |
requires_grad=learned_pitch)
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
164 |
requires_grad=learned_radius)
|
165 |
|
166 |
def get_pitch_bias(self, f0):
|
@@ -189,49 +192,76 @@ class rotary(nn.Module):
|
|
189 |
rotary.get_sim = get_sim
|
190 |
rotary.fwd_sim = fwd_sim
|
191 |
|
192 |
-
def
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
def forward(self, x=None, f0=None, stage=None) -> Tensor:
|
199 |
if isinstance(x, int):
|
200 |
-
|
201 |
else:
|
202 |
-
|
|
|
203 |
|
204 |
if f0 is not None:
|
205 |
f0_mean = f0.mean()
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
freqs = torch.einsum('i,j->ij', t,
|
213 |
-
|
214 |
freqs = freqs.float()
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
#
|
219 |
-
|
220 |
-
#
|
221 |
-
|
222 |
-
|
223 |
-
# freqs = torch.polar(radius, freqs)
|
224 |
-
# else:
|
225 |
-
|
226 |
-
# freqs = torch.polar(torch.ones_like(freqs), freqs)
|
227 |
-
# freqs = freqs.unsqueeze(0)
|
228 |
-
|
229 |
-
radius = F.softplus(self.radius)
|
230 |
-
freqs = torch.polar(radius.unsqueeze(0).expand_as(freqs), freqs)
|
231 |
else:
|
232 |
-
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
233 |
-
|
234 |
-
|
235 |
if "rotary" in self.debug:
|
236 |
if f0 is not None:
|
237 |
key = f"{self._counter}_{f0_theta:.2f}"
|
@@ -239,12 +269,11 @@ class rotary(nn.Module):
|
|
239 |
if not hasattr(self, '_prev_f0_theta'):
|
240 |
self._prev_f0_theta = f0_theta
|
241 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
242 |
-
elif abs(self._prev_f0_theta - f0_theta) >
|
243 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
244 |
self._prev_f0_theta = f0_theta
|
245 |
rotary._seen.add(key)
|
246 |
self._counter += 1
|
247 |
-
|
248 |
return freqs
|
249 |
|
250 |
@staticmethod
|
@@ -258,13 +287,11 @@ class rotary(nn.Module):
|
|
258 |
x1 = x1 * freqs
|
259 |
x1 = torch.view_as_real(x1).flatten(-2)
|
260 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
261 |
-
|
262 |
else:
|
263 |
x1 = x[..., :freqs.shape[-1]*2]
|
264 |
x2 = x[..., freqs.shape[-1]*2:]
|
265 |
|
266 |
if x.ndim == 2:
|
267 |
-
|
268 |
x1 = x1.unsqueeze(0)
|
269 |
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
|
270 |
x1 = torch.view_as_complex(x1)
|
|
|
130 |
|
131 |
class rotary(nn.Module):
|
132 |
_seen = set()
|
133 |
+
def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
|
134 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
|
135 |
super().__init__()
|
136 |
self.use_pbias = False
|
|
|
143 |
self._counter = 0
|
144 |
self.dims = dims
|
145 |
self.max_ctx = max_ctx
|
146 |
+
self.radii = radii
|
147 |
+
pitch_scale = 1.0
|
148 |
+
# theta_rescale = 1.0
|
149 |
+
# theta *= theta_rescale ** (dims / (dims - 2))
|
150 |
+
|
|
|
|
|
151 |
self.min_theta = nn.Parameter(
|
152 |
+
torch.tensor(20.0), requires_grad=learned_theta)
|
153 |
self.max_theta = nn.Parameter(
|
154 |
+
torch.tensor(400.0), requires_grad=learned_theta)
|
155 |
+
|
156 |
+
self.theta = nn.Parameter(
|
157 |
+
torch.tensor(float(theta)), requires_grad=learned_theta)
|
158 |
|
159 |
+
self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale),
|
160 |
requires_grad=learned_pitch)
|
161 |
|
162 |
+
freqs = 1. / (theta ** (torch.arange(0, dims, 2)[:(dims // 2)].float() / dims))
|
163 |
+
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
|
164 |
+
|
165 |
+
if radii:
|
166 |
+
self.radius = nn.Parameter(torch.ones(dims // 2),
|
167 |
requires_grad=learned_radius)
|
168 |
|
169 |
def get_pitch_bias(self, f0):
|
|
|
192 |
rotary.get_sim = get_sim
|
193 |
rotary.fwd_sim = fwd_sim
|
194 |
|
195 |
+
def align_f0(self, f0, token_length):
    """Resample a 2-D pitch track along dim 1 to ``token_length`` steps.

    Args:
        f0: Tensor with exactly two dimensions; the second one is the
            axis that gets resampled (the shape is unpacked into two
            values, so any other rank raises).
        token_length: Number of positions wanted along dim 1.

    Returns:
        ``f0`` itself when its dim-1 length already equals
        ``token_length``; otherwise a tensor of shape
        ``(f0.shape[0], token_length)`` where output position ``i`` takes
        the input frame at index ``floor(i * f0_length / token_length)``.
    """
    _, f0_length = f0.shape
    if f0_length == token_length:
        # Lengths already agree: hand the input back untouched.
        return f0

    positions = torch.arange(token_length, device=f0.device)
    # Integer floor division keeps the index exact and strictly below
    # f0_length. A float multiply followed by truncation can land on the
    # wrong side of an integer boundary through rounding.
    indices = torch.div(positions * f0_length, token_length, rounding_mode="floor")
    # Every row uses the same source indices, so a single column gather
    # replaces the per-batch index grid.
    return f0[:, indices]
|
206 |
+
|
207 |
+
def scale_f0(self, f0):
    """Min-max normalise ``f0`` along dim 1 into the range [0, 1].

    Each row is shifted by its own minimum and divided by its own
    (max - min) span. A small epsilon is added to the span so a constant
    row divides by a non-zero value and comes out as zeros. The result is
    clamped to [0, 1] before being returned.
    """
    row_low = f0.min(dim=1, keepdim=True).values
    row_high = f0.max(dim=1, keepdim=True).values
    span = row_high - row_low + 1e-8  # epsilon keeps the division finite
    scaled = (f0 - row_low) / span
    return torch.clamp(scaled, 0.0, 1.0)
|
215 |
+
|
216 |
+
def process_f0(f0, threshold=0.05):
|
217 |
+
thresholded_f0 = torch.where(f0 < threshold, torch.zeros_like(f0), f0)
|
218 |
+
return thresholded_f0
|
219 |
+
|
220 |
+
def map_perceptual(self, f0_mean, theta=10000.0):
    """Return the natural log of the ratio ``f0_mean / theta``.

    The result is positive when ``f0_mean`` is above ``theta``, zero when
    they are equal and negative below. The two-branch form
    (``log(f0/theta)`` vs ``-log(theta/f0)``) describes the same value, so
    one expression covers both cases. Dropping the Python-level
    comparison also lets ``f0_mean`` hold more than one element.

    Args:
        f0_mean: Tensor (or plain number) holding the value(s) to map.
        theta: Reference value the input is compared against.

    Returns:
        Tensor containing ``log(f0_mean / theta)``.
    """
    # as_tensor lets plain Python numbers through; tensors pass unchanged.
    ratio = torch.as_tensor(f0_mean) / theta
    return torch.log(ratio)
|
225 |
+
|
226 |
+
def linear_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
    """Linearly rescale ``freq`` from [min_freq, max_freq] to [0, target_max].

    ``min_freq`` maps to 0 and ``max_freq`` maps to ``target_max``. Inputs
    outside that interval are extrapolated on the same line; nothing is
    clamped.
    """
    span = max_freq - min_freq
    fraction = (freq - min_freq) / span
    return fraction * target_max
|
229 |
+
|
230 |
+
def log_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
    """Map ``freq`` onto a log scale spanning ``[0, log(target_max)]``.

    The position of ``log(freq)`` between ``log(min_freq)`` and
    ``log(max_freq)`` is computed as a fraction and multiplied by
    ``log(target_max)``. The returned value stays in the log domain; it
    is not exponentiated back to a frequency.

    The bounds and ``target_max`` may be plain numbers or tensors. They
    are converted to tensors on ``freq``'s dtype and device before
    ``torch.log`` is applied, because ``torch.log`` does not accept
    Python floats such as the default arguments.

    Args:
        freq: Frequency value(s) to map.
        min_freq: Frequency mapped to 0.
        max_freq: Frequency mapped to ``log(target_max)``.
        target_max: Upper end of the target range before taking the log.

    Returns:
        Tensor with the mapped log-domain value(s), shaped like ``freq``
        when the bounds are scalars.
    """
    freq = torch.as_tensor(freq)
    if not freq.is_floating_point():
        # torch.log needs a floating dtype.
        freq = freq.float()

    def _log(value):
        # Put scalars and tensors alike on freq's dtype/device, then log.
        return torch.log(torch.as_tensor(value, dtype=freq.dtype, device=freq.device))

    log_low = _log(min_freq)
    log_high = _log(max_freq)
    fraction = (torch.log(freq) - log_low) / (log_high - log_low)
    return fraction * _log(target_max)
|
237 |
|
238 |
def forward(self, x=None, f0=None, stage=None) -> Tensor:
|
239 |
if isinstance(x, int):
|
240 |
+
seq_len = x
|
241 |
else:
|
242 |
+
batch, seq_len, _ = x.shape
|
243 |
+
t = torch.arange(seq_len, device=self.device).float()
|
244 |
|
245 |
if f0 is not None:
|
246 |
f0_mean = f0.mean()
|
247 |
+
theta = self.theta
|
248 |
+
f0_theta = theta * (f0_mean * 1e-2 + 1.0)
|
249 |
+
freqs = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
250 |
+
else:
|
251 |
+
freqs = self.freqs
|
252 |
+
|
253 |
+
freqs = torch.einsum('i,j->ij', t, freqs)
|
|
|
254 |
freqs = freqs.float()
|
255 |
+
|
256 |
+
if self.radii and f0 is not None:
|
257 |
+
radius = self.align_f0(f0, seq_len)
|
258 |
+
# radius = self.scale_f0(radius)
|
259 |
+
radius = F.softplus(self.radius) * radius
|
260 |
+
# radius = radius.unsqueeze(-1)
|
261 |
+
freqs = torch.polar(radius.unsqueeze(-1), freqs.unsqueeze(0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
else:
|
263 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
|
264 |
+
# print(f"Step {self._counter}: Block: {stage}: Radius: {radius}")
|
|
|
265 |
if "rotary" in self.debug:
|
266 |
if f0 is not None:
|
267 |
key = f"{self._counter}_{f0_theta:.2f}"
|
|
|
269 |
if not hasattr(self, '_prev_f0_theta'):
|
270 |
self._prev_f0_theta = f0_theta
|
271 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
272 |
+
elif abs(self._prev_f0_theta - f0_theta) > 1000.0:
|
273 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
274 |
self._prev_f0_theta = f0_theta
|
275 |
rotary._seen.add(key)
|
276 |
self._counter += 1
|
|
|
277 |
return freqs
|
278 |
|
279 |
@staticmethod
|
|
|
287 |
x1 = x1 * freqs
|
288 |
x1 = torch.view_as_real(x1).flatten(-2)
|
289 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
|
|
290 |
else:
|
291 |
x1 = x[..., :freqs.shape[-1]*2]
|
292 |
x2 = x[..., freqs.shape[-1]*2:]
|
293 |
|
294 |
if x.ndim == 2:
|
|
|
295 |
x1 = x1.unsqueeze(0)
|
296 |
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
|
297 |
x1 = torch.view_as_complex(x1)
|