clr commited on
Commit
477b0e7
·
1 Parent(s): 2739231

Upload ctcalign.py

Browse files
Files changed (1) hide show
  1. ctcalign.py +3 -3
ctcalign.py CHANGED
@@ -26,10 +26,10 @@ def f2s(fr):
26
  class CTCAligner:
27
 
28
  def __init__(self, model_path,model_word_separator, model_blank_token):
29
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  torch.random.manual_seed(0)
31
 
32
- self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
33
  self.processor = Wav2Vec2Processor.from_pretrained(model_path)
34
 
35
  # build labels dict from a processor where it is not directly accessible
@@ -56,7 +56,7 @@ class CTCAligner:
56
  def get_frame_probs(wav,aligner):
57
  with torch.inference_mode(): # similar to with torch.no_grad():
58
  input_values = aligner.processor(wav,sampling_rate=16000).input_values[0]
59
- input_values = torch.tensor(input_values, device=aligner.device).unsqueeze(0)
60
  emits = aligner.model(input_values).logits
61
  emits = torch.log_softmax(emits, dim=-1)
62
  return emits[0].cpu().detach()
 
26
  class CTCAligner:
27
 
28
  def __init__(self, model_path,model_word_separator, model_blank_token):
29
+ #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  torch.random.manual_seed(0)
31
 
32
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_path)#.to(self.device)
33
  self.processor = Wav2Vec2Processor.from_pretrained(model_path)
34
 
35
  # build labels dict from a processor where it is not directly accessible
 
56
  def get_frame_probs(wav,aligner):
57
  with torch.inference_mode(): # similar to with torch.no_grad():
58
  input_values = aligner.processor(wav,sampling_rate=16000).input_values[0]
59
+ input_values = torch.tensor(input_values).unsqueeze(0)#, device=aligner.device).unsqueeze(0)
60
  emits = aligner.model(input_values).logits
61
  emits = torch.log_softmax(emits, dim=-1)
62
  return emits[0].cpu().detach()