Update moonshine to support batch decoding
Browse files- modeling_moonshine.py +4 -3
modeling_moonshine.py
CHANGED
@@ -398,7 +398,7 @@ class AudioPreprocessor(nn.Module):
|
|
398 |
assert (
|
399 |
src.shape[-1] >= 1023
|
400 |
), f"src shape[-1] {src.shape[-1]} should be at least 1023"
|
401 |
-
src = src.
|
402 |
return self.audio_preprocess(src)
|
403 |
|
404 |
|
@@ -435,7 +435,8 @@ class MoonshineModelTorch(nn.Module):
|
|
435 |
sot_token = 1
|
436 |
eot_token = 2
|
437 |
|
438 |
-
|
|
|
439 |
|
440 |
vals = self.decoder_initial(x=seq, enc_src=enc)
|
441 |
logits = vals[0]
|
@@ -448,7 +449,7 @@ class MoonshineModelTorch(nn.Module):
|
|
448 |
seq = torch.cat((seq, sample), dim=-1)
|
449 |
|
450 |
seq_len = int(src.shape[-1] * 6.5 / 16000)
|
451 |
-
while
|
452 |
vals = self.decoder(
|
453 |
seq,
|
454 |
*k_cache,
|
|
|
398 |
assert (
|
399 |
src.shape[-1] >= 1023
|
400 |
), f"src shape[-1] {src.shape[-1]} should be at least 1023"
|
401 |
+
src = src.reshape((-1, 1, src.shape[-1]))
|
402 |
return self.audio_preprocess(src)
|
403 |
|
404 |
|
|
|
435 |
sot_token = 1
|
436 |
eot_token = 2
|
437 |
|
438 |
+
sot_array = [[sot_token] for _ in range(enc.shape[0])]
|
439 |
+
seq = torch.as_tensor(sot_array).to(src.device)
|
440 |
|
441 |
vals = self.decoder_initial(x=seq, enc_src=enc)
|
442 |
logits = vals[0]
|
|
|
449 |
seq = torch.cat((seq, sample), dim=-1)
|
450 |
|
451 |
seq_len = int(src.shape[-1] * 6.5 / 16000)
|
452 |
+
while any([eot_token not in sub_seq for sub_seq in seq]) and seq.shape[-1] <= seq_len:
|
453 |
vals = self.decoder(
|
454 |
seq,
|
455 |
*k_cache,
|