njeffrie committed on
Commit
2ef02e9
1 Parent(s): a6fc06f

Update moonshine to support batch decoding

Browse files
Files changed (1) hide show
  1. modeling_moonshine.py +4 -3
modeling_moonshine.py CHANGED
@@ -398,7 +398,7 @@ class AudioPreprocessor(nn.Module):
398
  assert (
399
  src.shape[-1] >= 1023
400
  ), f"src shape[-1] {src.shape[-1]} should be at least 1023"
401
- src = src.unsqueeze(-2)
402
  return self.audio_preprocess(src)
403
 
404
 
@@ -435,7 +435,8 @@ class MoonshineModelTorch(nn.Module):
435
  sot_token = 1
436
  eot_token = 2
437
 
438
- seq = torch.as_tensor([[sot_token]]).to(src.device)
 
439
 
440
  vals = self.decoder_initial(x=seq, enc_src=enc)
441
  logits = vals[0]
@@ -448,7 +449,7 @@ class MoonshineModelTorch(nn.Module):
448
  seq = torch.cat((seq, sample), dim=-1)
449
 
450
  seq_len = int(src.shape[-1] * 6.5 / 16000)
451
- while sample != eot_token and len(seq.flatten()) <= seq_len:
452
  vals = self.decoder(
453
  seq,
454
  *k_cache,
 
398
  assert (
399
  src.shape[-1] >= 1023
400
  ), f"src shape[-1] {src.shape[-1]} should be at least 1023"
401
+ src = src.reshape((-1, 1, src.shape[-1]))
402
  return self.audio_preprocess(src)
403
 
404
 
 
435
  sot_token = 1
436
  eot_token = 2
437
 
438
+ sot_array = [[sot_token] for _ in range(enc.shape[0])]
439
+ seq = torch.as_tensor(sot_array).to(src.device)
440
 
441
  vals = self.decoder_initial(x=seq, enc_src=enc)
442
  logits = vals[0]
 
449
  seq = torch.cat((seq, sample), dim=-1)
450
 
451
  seq_len = int(src.shape[-1] * 6.5 / 16000)
452
+ while any([eot_token not in sub_seq for sub_seq in seq]) and seq.shape[-1] <= seq_len:
453
  vals = self.decoder(
454
  seq,
455
  *k_cache,