---
license: apache-2.0
datasets:
- google/fleurs
metrics:
- wer
- accuracy
- cer
pipeline_tag: automatic-speech-recognition
tags:
- pitch
- f0
- echo
- whisper
- waveform
- spectrogram
- hilbert
- asr
- nlp
- new

---

ASR model with pitch-aware relative positional embeddings.

### Mel-scale frequencies: significantly lower WER than standard inverse frequencies ('eval_wer': 35.3)

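```python
def _compute_freqs_base(self):
    # Mel-spaced frequencies (200 Hz corner, up to 4000 Hz) in kHz; `device`/`dtype` come from the enclosing scope
    mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
    return 200 * mel_scale / 1000
```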

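### Standard inverse frequencies: 'eval_wer': 61.6

```python
# Note: the common RoPE formulation divides the exponent by head_dim, not head_dim // 2.
freqs = 1.0 / (self.theta ** (torch.arange(0, self.head_dim, 2, device=device, dtype=dtype) / (self.head_dim // 2)))
```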
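To see how different the two base vectors are, here is a minimal standalone sketch that rebuilds both from the snippets above. It assumes `head_dim = 64` and `theta = 10000.0` (neither value is stated in this card), and the helper names `mel_band_freqs` and `standard_inv_freqs` are hypothetical.

```python
import torch


def mel_band_freqs(head_dim: int, f_corner: float = 200.0, f_max: float = 4000.0) -> torch.Tensor:
    """Hypothetical helper mirroring `_compute_freqs_base`: mel-spaced frequencies in kHz."""
    mel_max = 2595 * torch.log10(torch.tensor(1 + f_max / f_corner))
    mel = torch.linspace(0, mel_max.item(), head_dim // 2)
    # Invert m = 2595 * log10(1 + f / f_corner), then convert Hz -> kHz
    return f_corner * (torch.pow(10, mel / 2595) - 1) / 1000


def standard_inv_freqs(head_dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper copying the baseline snippet (including its head_dim // 2 divisor)."""
    return 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / (head_dim // 2)))


mel = mel_band_freqs(64)
std = standard_inv_freqs(64)
print(f"mel:      {float(mel[0]):.4f} .. {float(mel[-1]):.4f} kHz over {mel.numel()} pairs")
print(f"standard: {float(std[0]):.2e} .. {float(std[-1]):.2e} over {std.numel()} pairs")
```

The mel vector grows from 0 to 4.0 (kHz), while the baseline decays geometrically from 1.0 to a very small value, so most baseline pairs rotate extremely slowly.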
     

<img width="1363" height="732" alt="pitch_spectrogram" src="https://github.com/user-attachments/assets/ceb65e94-7df4-41b7-aa3d-c4aa4c6c0717" />

<img width="233" height="77" alt="legend" src="https://github.com/user-attachments/assets/fad84550-a199-43b3-8471-d011a9fd6f94" />

https://huggingface.co/Sin2pi/asr-model/tensorboard

Questions:

- How can we make attention mechanisms aware of speech-specific properties?
- Can we incorporate acoustic information directly into positional encodings?
- Does pitch conditioning improve speech recognition?

---


To explore the relationship between pitch and rotary embeddings, the model implements three complementary pitch-based enhancements:

1. Pitch-modulated theta: the pitch (f0) is used to modify the theta parameter, dynamically adjusting the rotary frequency.
2. Direct similarity bias: a pitch-based similarity bias is added directly to the attention mechanism.
3. Variable radii in `torch.polar`: the unit-circle radius of 1.0 in the `torch.polar` calculation is replaced with variable radii derived from f0. This creates acoustically weighted positional encodings, so each position in the embedding space reflects the acoustic prominence of the original speech. This approach effectively adds phase and amplitude information without significant computational overhead.
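The card does not show how enhancements 1 and 2 are computed, so the following is only one hypothetical way to express them. It assumes theta is scaled by the mean voiced f0 relative to a 220 Hz reference (220.0 is the constant used in the frequency code further down) and that the similarity bias is a scaled negative absolute pitch difference between frames. The function names and the `scale` parameter are illustrative, not taken from the model.

```python
import torch


def pitch_modulated_theta(theta: float, f0: torch.Tensor, ref_hz: float = 220.0) -> torch.Tensor:
    """Hypothetical: scale theta by the mean voiced f0 relative to a reference pitch."""
    voiced = f0[f0 > 0]                      # ignore unvoiced frames (f0 == 0)
    if voiced.numel() == 0:
        return torch.tensor(theta)           # no pitch: keep theta unchanged
    return theta * voiced.mean() / ref_hz


def pitch_similarity_bias(f0: torch.Tensor, scale: float = 100.0) -> torch.Tensor:
    """Hypothetical [ctx, ctx] bias: 0 for equal pitch, more negative as pitches diverge."""
    return -(f0[:, None] - f0[None, :]).abs() / scale


f0 = torch.tensor([0.0, 220.0, 230.0, 0.0, 180.0])
theta = pitch_modulated_theta(10000.0, f0)   # mean voiced f0 is 210 Hz, so theta shrinks slightly
bias = pitch_similarity_bias(f0)

# The bias is added to the attention logits before the softmax.
torch.manual_seed(0)
q, k = torch.randn(5, 16), torch.randn(5, 16)
scores = q @ k.T / 16 ** 0.5 + bias
attn = scores.softmax(dim=-1)
```

Because the bias is additive in logit space, frames with similar pitch are favored without changing the shape of the attention computation.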




The function `torch.polar` constructs a complex tensor from polar coordinates:

````python
# torch.polar(magnitude, angle) returns:
result = magnitude * (torch.cos(angle) + 1j * torch.sin(angle))
````

So, for each element:
- `magnitude` is the modulus (radius, r)
- `angle` is the phase (theta, in radians)
- The result is: `r * exp(i * theta) = r * (cos(theta) + i * sin(theta))`

Reference: [PyTorch Documentation - torch.polar](https://pytorch.org/docs/stable/generated/torch.polar.html)
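A quick numerical check of that identity, using arbitrary example magnitudes and angles: `torch.polar` matches the hand-built complex tensor, and `.abs()` / `.angle()` recover the inputs.

```python
import torch

magnitude = torch.tensor([1.0, 2.0, 0.5])
angle = torch.tensor([0.0, torch.pi / 2, torch.pi / 4])

z = torch.polar(magnitude, angle)
# Same result built by hand: r * cos(theta) + i * r * sin(theta)
manual = torch.complex(magnitude * torch.cos(angle), magnitude * torch.sin(angle))

print(torch.allclose(z, manual))  # True
print(z.abs())                    # recovers the magnitudes (up to float error)
print(z.angle())                  # recovers the angles (up to float error)
```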




<img width="1370" height="576" alt="123123" src="https://github.com/user-attachments/assets/17031084-48aa-46db-8b12-c025417f3074" />




```python


# Modified freq calculation (mel-spaced, 200 Hz corner, up to 4000 Hz):
pos = torch.arange(ctx, device=device, dtype=dtype)
freqs = (self.theta / 220.0) * 200 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
freqs = pos[:, None] * freqs  # outer product: [ctx, head_dim // 2] rotation angles

# Standard:
pos = torch.arange(ctx, dtype=torch.float32, device=device)
freqs = 1.0 / (self.theta ** (torch.arange(0, self.head_dim, 2, device=device, dtype=dtype) / (self.head_dim // 2)))
freqs = pos[:, None] * freqs

```

```python
# Alternative bands: each line replaces the mel `freqs = ...` assignment in the modified calculation above.

# 200Hz - 4000Hz (covers 95% of speech content)
freqs = (self.theta / 220.0) * 200 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000

# 150Hz - 6000Hz (covers speech + some emotion/intonation)
freqs = (self.theta / 220.0) * 150 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 6000/150)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000

# 80Hz - 2000Hz (focus on fundamental frequencies + first few harmonics)
freqs = (self.theta / 220.0) * 80 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 2000/80)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000

# original mel scale (700 Hz break frequency, up to 8000 Hz)
freqs = (self.theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
```
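Written out, each band variant evaluates the same mel-spaced curve. With $N$ = `head_dim // 2` rotary pairs, a corner frequency $f_c$ (the first constant in each line: 200, 150, 80, or 700 Hz) and an upper limit $f_{\max}$ (4000, 6000, 2000, or 8000 Hz), entry $k = 0, \dots, N-1$ is:

$$
m_k = \frac{k}{N-1}\, 2595 \log_{10}\!\left(1 + \frac{f_{\max}}{f_c}\right),
\qquad
\omega_k = \frac{\theta}{220} \cdot \frac{f_c \left(10^{m_k / 2595} - 1\right)}{1000}
$$

Plugging in the endpoints gives $\omega_0 = 0$ and $\omega_{N-1} = \frac{\theta}{220} \cdot \frac{f_{\max}}{1000}$, so every band starts at 0 Hz; the first constant sets the curvature of the spacing rather than the lowest frequency (700 Hz is the break frequency of the conventional mel scale). For example, assuming $\theta = 10000$, the 200 Hz / 4000 Hz variant tops out at $\frac{10000}{220} \cdot 4 \approx 181.8$.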


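- Standard RoPE: frequencies form a geometric progression decaying from 1 (arbitrary geometric spacing).
- This RoPE: mel-spaced frequencies up to a speech-band limit such as 2000, 4000, 6000, or 8000 Hz (perceptually meaningful).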
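Once `freqs` has shape `[ctx, head_dim // 2]`, it is turned into rotations and multiplied into pairs of query/key channels. The sketch below shows that standard RoPE application step with unit radius, using the modified (mel) table from the code above. It assumes `ctx = 8`, `head_dim = 64`, and `theta = 10000.0`, and `apply_rotary` is a hypothetical helper name, not this model's code.

```python
import torch


def apply_rotary(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of x ([ctx, head_dim]) by the angles in freqs ([ctx, head_dim // 2])."""
    rot = torch.polar(torch.ones_like(freqs), freqs)                      # unit-circle rotations
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (x0, x1) -> x0 + i*x1
    return torch.view_as_real(x_c * rot).flatten(-2).type_as(x)


ctx, head_dim, theta = 8, 64, 10000.0
mel_max = (2595 * torch.log10(torch.tensor(1 + 4000 / 200))).item()
base = (theta / 220.0) * 200 * (torch.pow(10, torch.linspace(0, mel_max, head_dim // 2) / 2595) - 1) / 1000
pos = torch.arange(ctx, dtype=torch.float32)
freqs = pos[:, None] * base                                               # [ctx, head_dim // 2] angles

torch.manual_seed(0)
q = torch.randn(ctx, head_dim)
q_rot = apply_rotary(q, freqs)
```

With a unit radius the rotation leaves every query's norm unchanged; the variable radii of enhancement 3 deliberately scale it instead.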

----
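```python
def _apply_radii(self, freqs, f0, ctx):
    # Variable radii: use f0 (Hz, 0 for unvoiced frames) as the magnitude of each frame's rotations
    if self.radii and f0 is not None:
        radius = f0.to(device, dtype)
        # [ctx] -> [ctx, 1] broadcasts one radius across all rotary pairs of a frame
        return torch.polar(radius.unsqueeze(-1), freqs), radius
    else:
        # Standard RoPE: unit-circle rotations
        return torch.polar(torch.ones_like(freqs), freqs), None
```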
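A standalone sketch of what `_apply_radii` produces. `apply_radii` is a hypothetical module-free copy (no `self`, `device`, or `ctx`); the five f0 values are borrowed from the radius log further down, and the angle table is random toy data.

```python
from typing import Optional, Tuple

import torch


def apply_radii(freqs: torch.Tensor, f0: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical module-free copy of `_apply_radii`."""
    if f0 is not None:
        radius = f0.to(freqs.dtype)
        # One radius per frame, shared by all rotary pairs of that frame
        return torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs), radius
    return torch.polar(torch.ones_like(freqs), freqs), None


f0 = torch.tensor([0.0, 283.7590, 260.6043, 0.0, 220.4299])  # Hz; zeros are unvoiced frames
torch.manual_seed(0)
freqs = torch.rand(5, 4) * 2 * torch.pi                        # toy angle table [ctx, pairs]

z, radius = apply_radii(freqs, f0)
z_unit, _ = apply_radii(freqs, None)
print(z.abs())  # row i is f0[i] repeated; unvoiced rows are exactly 0
```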


```python

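def accumulate_phase(self, f0, t_frame, phi0=0.0):
    # Integrate instantaneous frequency f0 (Hz) over frames of length t_frame (s) into a wrapped phase
    omega = 2 * torch.pi * f0                 # angular frequency (rad/s)
    dphi = omega * t_frame                    # phase advance per frame (rad)
    phi = torch.cumsum(dphi, dim=0) + phi0    # running phase
    phi = torch.remainder(phi, 2 * torch.pi)  # wrap into [0, 2*pi)
    return phi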
```
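`accumulate_phase` integrates frequency into phase: phi[n] = phi0 + sum over k <= n of 2π · f0[k] · t_frame, wrapped into [0, 2π). A module-free copy with a toy constant contour, assuming a 10 ms frame hop (`t_frame = 0.01` s, not stated in the card):

```python
import torch


def accumulate_phase(f0: torch.Tensor, t_frame: float, phi0: float = 0.0) -> torch.Tensor:
    omega = 2 * torch.pi * f0                  # angular frequency (rad/s)
    dphi = omega * t_frame                     # phase advance per frame (rad)
    phi = torch.cumsum(dphi, dim=0) + phi0     # running phase
    return torch.remainder(phi, 2 * torch.pi)  # wrap into [0, 2*pi)


f0 = torch.full((4,), 25.0, dtype=torch.float64)  # 25 Hz * 10 ms = a quarter turn per frame
phi = accumulate_phase(f0, t_frame=0.01)
print(phi)  # approximately [pi/2, pi, 3*pi/2, 0], where the last value may land just below 2*pi
```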

A closer look at what's going on. Here is a slice of the actual radius values for one step:
      
      [encoder] [Radius] torch.Size([454]) 92.32 [Theta] 10092.01 [f0] torch.Size([454]) [Freqs] torch.Size([454, 64]) 2.17+1.17j [ctx] 454
      
     [encoder] [Radius] tensor([  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           283.7590, 260.6043, 257.3410, 261.8319, 249.1852, 257.2541, 263.8165,
           272.2421, 277.6960, 286.9628, 303.8460, 305.1561, 319.5129, 330.6942,
           362.0888, 355.8571, 352.8432, 336.9354, 313.0566, 319.9086, 303.4355,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 220.4299,
           254.7619, 239.6506, 228.8830, 227.3063, 225.6784, 225.7169, 211.7767,
           223.6572, 223.4174, 222.4496, 225.1645, 228.7840, 231.8760, 228.9148,
           230.6227,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000, 205.7097, 202.8816, 182.0329, 181.6536, 180.2186,
           177.8911, 176.0775, 171.3846, 173.9602, 170.4824, 171.5723, 172.0810,
           174.3897, 177.3261, 188.3212, 188.9799, 186.7493, 221.3487,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 304.4458,
           263.3122, 251.7635, 211.8467, 207.5651, 195.3680, 184.0717, 206.3800,
           197.8661,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000, 197.6399, 195.0042, 190.7016, 187.0234, 183.5980,
           183.6842, 185.0038, 185.5778, 187.4167, 185.5085, 183.4160,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000, 190.3175, 209.7377, 206.2731, 211.9862, 219.2756, 214.3068,
           202.6881, 192.1823, 210.3404, 235.5456, 230.7845, 234.5441, 234.9773,
           241.1199, 241.9640, 237.0773, 231.6952, 238.0375, 257.9242, 264.4094,
           265.3747, 251.0286, 245.7093,   0.0000, 274.9167, 273.4767, 271.6227,
           256.5457, 245.8942, 251.3361, 240.1572, 228.9316,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000, 202.6190, 217.7865, 212.3347, 208.2926,
           209.9206, 209.3961, 210.3909, 211.6021, 202.8511, 205.1674, 211.7455,
           217.9954,   0.0000, 264.9778, 229.7112, 200.8905, 182.4680, 179.4812,
           175.4307, 172.7844, 173.6305, 172.1901, 170.5743, 167.2979, 166.7781,
           166.7783, 170.8816, 173.0406, 176.2869, 181.9142, 212.7904, 170.4449,
           173.1710, 168.3079, 154.1663,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000, 211.3942, 202.3412, 203.6764, 198.4441,
           186.2687, 209.0010, 209.5012, 214.6487, 203.8741, 195.8432, 180.9673,
             0.0000,   0.0000,   0.0000, 197.7340, 198.9476, 204.5347, 209.5858,
           204.5406, 195.1048, 198.1545, 199.8559, 207.3548, 217.9402, 217.2366,
           216.4711, 212.4731, 217.5183, 218.0658, 208.7833,   0.0000, 243.7485,
           215.1998, 235.4733, 215.3242, 215.1489, 212.6266, 203.9319, 191.8531,
           197.2219, 202.7850,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 224.1991, 167.9602,
           190.8241, 178.5659, 175.4639, 172.6353, 173.5884, 173.2250,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 209.6374, 196.2949,
           216.4672, 236.3051, 195.2339, 241.1573,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000, 195.8783, 145.3826,   0.0000,   0.0000,   0.0000,   0.0000,
             0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000],
          device='cuda:0')
          
What the radius values tell us:

1. Speech structure is clear
   - Zeros: silence/unvoiced segments (no pitch)
   - Non-zero values: voiced speech segments with pitch
   - Pattern: 0.0000 → 283.7590 → 0.0000 → 220.4299

2. Pitch range is realistic
   - Range: ~145-365 Hz
   - Typical speech: 80-400 Hz for most speakers, so the model's values fall inside the normal range

3. Temporal dynamics
   - Clusters: pitch values cluster together, as in natural speech
   - Transitions: smooth changes between neighboring values
   - Silence gaps: natural pauses in speech

Overall:

- Silence detection: 0.0000 = no pitch (silence/unvoiced)
- Pitch extraction: values such as 283.7590 are actual f0 values (Hz)
- Speech segmentation: clear boundaries between voiced and unvoiced frames
- Realistic values: 145-365 Hz is a normal speech range
- Proper structure: matches natural speech patterns
- Variable radius: working as intended
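These observations (zeros for unvoiced frames, clustered voiced runs, a plausible pitch range) can be checked directly on a radius tensor. A minimal sketch on a shortened, hypothetical sequence built from values in the log above; `voiced_runs` is an illustrative helper, not part of the model.

```python
from typing import List, Tuple

import torch


def voiced_runs(radius: torch.Tensor) -> List[Tuple[int, int]]:
    """(start, end) frame indices, end exclusive, of consecutive non-zero radius values."""
    voiced = (radius > 0).long()
    pad = torch.zeros(1, dtype=torch.long)
    # +1 where a voiced run starts, -1 just after it ends (padding closes runs at the edges)
    edges = torch.diff(torch.cat([pad, voiced, pad]))
    starts = torch.nonzero(edges == 1).flatten().tolist()
    ends = torch.nonzero(edges == -1).flatten().tolist()
    return list(zip(starts, ends))


radius = torch.tensor([0.0, 0.0, 283.7590, 260.6043, 257.3410, 0.0, 220.4299, 254.7619, 0.0])

runs = voiced_runs(radius)
voiced_vals = radius[radius > 0]
print(runs)  # [(2, 5), (6, 8)]
print(f"voiced fraction: {voiced_vals.numel() / radius.numel():.2f}")  # 0.56
print(f"pitch range: {float(voiced_vals.min()):.1f}-{float(voiced_vals.max()):.1f} Hz")  # 220.4-283.8 Hz
```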

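The complex frequency result:

      [Freqs] torch.Size([454, 64]) 2.17+1.17j

- Magnitude: sqrt(2.17² + 1.17²) ≈ 2.47
- Phase: atan2(1.17, 2.17) ≈ 0.49 radians
- Variable radius: each frame has a different magnitude
  - Silence frames: radius ≈ 0 → freqs ≈ 0
  - Voiced frames: radius ≈ 200-300 → every complex entry in that frame has magnitude ≈ 200-300, because `torch.polar` uses the radius as the magnitude. The logged 2.17+1.17j is a single scalar printed for the whole [454, 64] tensor (the reduction the logger uses is not shown), so it should not be read as a per-frame magnitude.
- Variable attention: important frames get more attention
  - Silence: no acoustic prominence → low radius
  - Speech: high acoustic prominence → high radius
  - Transitions: natural pitch changes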
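The magnitude and phase quoted above can be reproduced with the standard library. The second half is a hedged illustration, assuming the logged summary is a mean over complex entries: entries with radius 250 but spread-out phases keep magnitude 250 individually, yet their average nearly cancels.

```python
import cmath

import torch

# Magnitude and phase of the logged summary value 2.17+1.17j
z = complex(2.17, 1.17)
print(f"magnitude: {abs(z):.2f}")          # 2.47
print(f"phase: {cmath.phase(z):.2f} rad")  # 0.49

# Each torch.polar entry keeps its radius as its magnitude...
radius = torch.full((4,), 250.0)
angles = torch.tensor([0.0, 0.5, 1.0, 1.5]) * torch.pi
entries = torch.polar(radius, angles)
print(entries.abs())         # each entry is approximately 250
# ...but averaging entries with spread-out phases cancels most of that magnitude.
print(entries.mean().abs())  # close to 0
```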






----