---
license: apache-2.0
datasets:
- google/fleurs
metrics:
- wer
- accuracy
- cer
pipeline_tag: automatic-speech-recognition
tags:
- pitch
- f0
- echo
- whisper
- waveform
- spectrogram
- hilbert
- asr
- nlp
- new
---
ASR model with pitch-aware relative positional embeddings.

### Mel-scaled frequencies decrease WER significantly compared to standard inverse frequencies: `eval_wer: 35.3`

```python
def _compute_freqs_base(self):
    mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
    return 200 * mel_scale / 1000
```

### Standard inverse frequencies: `eval_wer: 61.6`

```python
freqs = 1.0 / (self.theta ** (torch.arange(0, self.head_dim, 2, device=device, dtype=dtype) / (self.head_dim // 2)))
```
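The two schedules can be compared with a torch-free sketch in plain Python. `head_dim=128` and `theta=10000` are illustrative assumptions, and the `(self.theta / 220.0)` pitch scaling used elsewhere in this card is omitted here; the inverse-frequency exponent mirrors the snippet above as written.

```python
import math

def mel_freqs(head_dim, fmin=200.0, fmax=4000.0):
    """Mel-spaced frequencies in kHz, mirroring _compute_freqs_base above."""
    n = head_dim // 2
    mel_max = 2595.0 * math.log10(1.0 + fmax / fmin)  # top of the mel band
    return [fmin * (10 ** (mel_max * i / (n - 1) / 2595.0) - 1.0) / 1000.0
            for i in range(n)]

def inv_freqs(head_dim, theta=10000.0):
    """Inverse-frequency schedule, exponent as written in the snippet above."""
    return [1.0 / (theta ** (2.0 * i / (head_dim // 2)))
            for i in range(head_dim // 2)]

mel = mel_freqs(128)
std = inv_freqs(128)
print(mel[0], mel[-1])  # band runs from 0 up to 4.0 (kHz, i.e. 4000 Hz)
print(std[0], std[-1])
```

The mel schedule spaces rotary frequencies the way human hearing spaces pitch, which is the intuition behind the WER gap reported above.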
<img width="1363" height="732" alt="pitch_spectrogram" src="https://github.com/user-attachments/assets/ceb65e94-7df4-41b7-aa3d-c4aa4c6c0717" />
<img width="233" height="77" alt="legend" src="https://github.com/user-attachments/assets/fad84550-a199-43b3-8471-d011a9fd6f94" />
https://huggingface.co/Sin2pi/asr-model/tensorboard
Questions:
- How can we make attention mechanisms aware of speech-specific properties?
- Can we incorporate acoustic information directly into positional encodings?
- Does pitch-conditioning improve speech recognition?
---
To explore the relationship between pitch and rotary embeddings, the model implements three complementary pitch-based enhancements:
1. Pitch-modulated theta: pitch (f0) is used to modify the theta parameter, dynamically adjusting the rotary frequency.
2. Direct similarity bias: a pitch-based similarity bias is added directly to the attention mechanism.
3. Variable radii in torch.polar: the unit-circle radius (1.0) in the torch.polar calculation is replaced with variable radii derived from f0. This creates acoustically weighted positional encodings, so each position in the embedding space reflects the acoustic prominence in the original speech. This approach effectively adds phase and amplitude information without significant computational overhead.
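Enhancement 1 can be sketched as a scalar helper. This is a hypothetical illustration, not the model's exact code: the 220 Hz reference comes from the `(self.theta / 220.0)` factor used in the frequency snippets below, and the function name is made up here.

```python
def pitch_modulated_theta(theta: float, f0_mean: float, ref_hz: float = 220.0) -> float:
    """Scale the rotary base theta by mean pitch relative to a 220 Hz
    reference, so higher-pitched speech rotates faster (illustrative sketch)."""
    if f0_mean <= 0.0:  # unvoiced / no pitch detected: leave theta unchanged
        return theta
    return theta * (f0_mean / ref_hz)

print(pitch_modulated_theta(10000.0, 220.0))  # reference pitch: theta unchanged
print(pitch_modulated_theta(10000.0, 440.0))  # an octave up doubles theta
```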
The function `torch.polar` constructs a complex tensor from polar coordinates:
```python
# torch.polar(magnitude, angle) returns:
result = magnitude * (torch.cos(angle) + 1j * torch.sin(angle))
```
So, for each element:
- magnitude is the modulus (radius r)
- angle is the phase (theta), in radians
- The result is: `r * exp(i*theta) = r * (cos(theta) + i*sin(theta))`
Reference: [PyTorch Documentation - torch.polar](https://pytorch.org/docs/stable/generated/torch.polar.html)
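The identity does not depend on torch, so it can be checked with the standard library: `cmath.rect` computes exactly what `torch.polar` does for one element.

```python
import cmath
import math

r, theta = 2.5, 0.49           # sample modulus and phase
z = cmath.rect(r, theta)        # stdlib stand-in for torch.polar on scalars

# r * (cos(theta) + i*sin(theta)) matches the polar construction exactly
assert abs(z - r * (math.cos(theta) + 1j * math.sin(theta))) < 1e-12
assert abs(abs(z) - r) < 1e-12           # modulus recovers the radius
assert abs(cmath.phase(z) - theta) < 1e-12  # argument recovers the angle
print(z)
```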
<img width="1370" height="576" alt="123123" src="https://github.com/user-attachments/assets/17031084-48aa-46db-8b12-c025417f3074" />
```python
# Modified freq calculation:
pos = torch.arange(ctx, device=device, dtype=dtype)
freqs = (self.theta / 220.0) * 200 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
freqs = pos[:, None] * freqs
# standard
pos = torch.arange(ctx, dtype=torch.float32, device=device)
freqs = 1.0 / (self.theta ** (torch.arange(0, self.head_dim, 2, device=device, dtype=dtype) / (self.head_dim // 2)))
freqs = pos[:, None] * freqs
```
```python
# 200 Hz - 4000 Hz (covers 95% of speech content)
freqs = (self.theta / 220.0) * 200 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
# 150 Hz - 6000 Hz (covers speech + some emotion/intonation)
freqs = (self.theta / 220.0) * 150 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 6000/150)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
# 80 Hz - 2000 Hz (focus on fundamental frequencies + first few harmonics)
freqs = (self.theta / 220.0) * 80 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 2000/80)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
# original (full 700 Hz / 8000 Hz mel range)
freqs = (self.theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
```

Standard RoPE: 1, 0.1, 0.01, 0.001, ... (arbitrary geometric spacing)
This RoPE: 80 Hz, 100 Hz, 140 Hz, ... (perceptually meaningful spacing)
----
```python
def _apply_radii(self, freqs, f0, ctx):
    if self.radii and f0 is not None:
        radius = f0.to(device, dtype)
        return torch.polar(radius.unsqueeze(-1), freqs), radius
    else:
        return torch.polar(torch.ones_like(freqs), freqs), None
```
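Per element, the effect of swapping the unit radius for f0 can be seen with a scalar sketch (a hypothetical stand-in for one element of the torch computation): unvoiced frames with f0 = 0 collapse to the zero vector, while voiced frames keep their rotary phase but carry the pitch value as magnitude.

```python
import cmath

def polar_embed(f0_frame: float, angle: float) -> complex:
    """One element of the variable-radius encoding: radius = f0 (0 for
    silence), phase = the rotary angle. Scalar sketch of torch.polar."""
    return cmath.rect(f0_frame, angle)

voiced = polar_embed(283.759, 0.5)   # voiced frame: large modulus, rotated
silence = polar_embed(0.0, 0.5)      # unvoiced frame: collapses to 0
print(abs(voiced), abs(silence))
```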
```python
def accumulate_phase(self, f0, t_frame, phi0=0.0):
omega = 2 * torch.pi * f0
dphi = omega * t_frame
phi = torch.cumsum(dphi, dim=0) + phi0
phi = torch.remainder(phi, 2 * torch.pi)
return phi
```
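A plain-Python equivalent of `accumulate_phase` (wrapping at each step rather than once after the cumulative sum, which is the same modulo 2π) makes the behavior easy to check: a 100 Hz tone at a 10 ms hop advances exactly one full cycle per frame, so the wrapped phase stays at ~0.

```python
import math

def accumulate_phase(f0_hz, t_frame, phi0=0.0):
    """Integrate instantaneous angular frequency (2*pi*f0) over frames
    and wrap the running phase to [0, 2*pi)."""
    phi, out = phi0, []
    for f0 in f0_hz:
        phi = (phi + 2 * math.pi * f0 * t_frame) % (2 * math.pi)
        out.append(phi)
    return out

# 100 Hz at a 10 ms frame hop: one full cycle per frame, phase wraps to ~0
phases = accumulate_phase([100.0] * 4, 0.01)
print(phases)
```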
A closer look at what's going on. Here is a slice of the actual radius values for one step:
```
[encoder] [Radius] torch.Size([454]) 92.32 [Theta] 10092.01 [f0] torch.Size([454]) [Freqs] torch.Size([454, 64]) 2.17+1.17j [ctx] 454
[encoder] [Radius] tensor([  0.0000,   0.0000,   0.0000,  ...,   0.0000,
        283.7590, 260.6043, 257.3410, 261.8319, 249.1852, 257.2541, 263.8165,
        272.2421, 277.6960, 286.9628, 303.8460, 305.1561, 319.5129, 330.6942,
        362.0888, 355.8571, 352.8432, 336.9354, 313.0566, 319.9086, 303.4355,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 220.4299,
        254.7619, 239.6506, 228.8830, 227.3063, 225.6784, 225.7169, 211.7767,
        223.6572, 223.4174, 222.4496, 225.1645, 228.7840, 231.8760, 228.9148,
        230.6227,   0.0000,  ...,
        ...
        195.8783, 145.3826,   0.0000,  ...,   0.0000], device='cuda:0')
```

(The middle of the sequence, elided here, alternates voiced clusters and zero runs in the same way.)
What the radius values tell us:
1. Speech structure is clear
   - Zeros: silence/unvoiced segments (no pitch detected)
   - Non-zero values: voiced speech segments with pitch
   - Pattern: `0.0000 → 283.7590 → 0.0000 → 220.4299`
2. Pitch range is realistic
   - Range: ~145-365 Hz
   - Typical speech: 80-400 Hz for most speakers
   - Model values: 145-365 Hz
3. Temporal dynamics
   - Clusters: pitch values cluster together, as in natural speech
   - Transitions: smooth changes between neighboring values
   - Silence gaps: natural pauses in speech

In short:
- Silence detection: 0.0000 = no pitch (silence/unvoiced)
- Pitch extraction: e.g. 283.7590 = an actual f0 value in Hz
- Speech segmentation: clear boundaries between voiced and unvoiced frames
- Realistic values: 145-365 Hz is a normal speech range
- Proper structure: matches natural speech patterns
- Variable radius: working as intended
The complex frequency result:

`[Freqs] torch.Size([454, 64]) 2.17+1.17j`

- Magnitude: sqrt(2.17² + 1.17²) ≈ 2.5
- Phase: atan2(1.17, 2.17) ≈ 0.49 radians
- Variable radius: each frame has a different magnitude
- Silence frames: radius ≈ 0 → freqs ≈ 0
- Voiced frames: radius ≈ 200-300 → freqs ≈ 2-3
- Variable attention: acoustically prominent frames get more attention
- Silence: no acoustic prominence → low radius
- Speech: high acoustic prominence → high radius
- Transitions: natural pitch changes
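The magnitude and phase read off the sample element above can be verified with the standard library (`abs` and `cmath.phase` playing the role of `torch.abs` and `torch.angle`):

```python
import cmath

z = 2.17 + 1.17j        # sample element from the [Freqs] tensor above
mag = abs(z)            # sqrt(2.17**2 + 1.17**2)
phase = cmath.phase(z)  # atan2(1.17, 2.17)
print(round(mag, 2), round(phase, 2))  # ~2.47 and ~0.49 rad
```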
----