Upload ctcalign.py
Browse files
- ctcalign.py +3 -3
ctcalign.py
CHANGED
@@ -26,10 +26,10 @@ def f2s(fr):
|
|
26 |
class CTCAligner:
|
27 |
|
28 |
def __init__(self, model_path,model_word_separator, model_blank_token):
|
29 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
torch.random.manual_seed(0)
|
31 |
|
32 |
-
self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
|
33 |
self.processor = Wav2Vec2Processor.from_pretrained(model_path)
|
34 |
|
35 |
# build labels dict from a processor where it is not directly accessible
|
@@ -56,7 +56,7 @@ class CTCAligner:
|
|
56 |
def get_frame_probs(wav,aligner):
|
57 |
with torch.inference_mode(): # similar to with torch.no_grad():
|
58 |
input_values = aligner.processor(wav,sampling_rate=16000).input_values[0]
|
59 |
-
input_values = torch.tensor(input_values, device=aligner.device).unsqueeze(0)
|
60 |
emits = aligner.model(input_values).logits
|
61 |
emits = torch.log_softmax(emits, dim=-1)
|
62 |
return emits[0].cpu().detach()
|
|
|
26 |
class CTCAligner:
|
27 |
|
28 |
def __init__(self, model_path,model_word_separator, model_blank_token):
|
29 |
+
#self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
torch.random.manual_seed(0)
|
31 |
|
32 |
+
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)#.to(self.device)
|
33 |
self.processor = Wav2Vec2Processor.from_pretrained(model_path)
|
34 |
|
35 |
# build labels dict from a processor where it is not directly accessible
|
|
|
56 |
def get_frame_probs(wav,aligner):
|
57 |
with torch.inference_mode(): # similar to with torch.no_grad():
|
58 |
input_values = aligner.processor(wav,sampling_rate=16000).input_values[0]
|
59 |
+
input_values = torch.tensor(input_values).unsqueeze(0)#, device=aligner.device).unsqueeze(0)
|
60 |
emits = aligner.model(input_values).logits
|
61 |
emits = torch.log_softmax(emits, dim=-1)
|
62 |
return emits[0].cpu().detach()
|