osanseviero
committed on
Commit
·
5913155
1
Parent(s):
46d31bb
Update pipeline.py
Browse files - pipeline.py +5 -3
pipeline.py
CHANGED
@@ -9,9 +9,9 @@ class PreTrainedPipeline():
|
|
9 |
"""
|
10 |
Initialize model
|
11 |
"""
|
12 |
-
processor = Wav2Vec2Processor.from_pretrained(path)
|
13 |
-
model = Wav2Vec2ForCTC.from_pretrained(path)
|
14 |
-
vocab_list = list(processor.tokenizer.get_vocab().keys())
|
15 |
|
16 |
# convert ctc blank character representation
|
17 |
vocab_list[0] = ""
|
@@ -39,6 +39,8 @@ class PreTrainedPipeline():
|
|
39 |
A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
|
40 |
the detected text from the input audio.
|
41 |
"""
|
|
|
|
|
42 |
return {
|
43 |
"text": self.decoder.decode(logits)
|
44 |
}
|
|
|
9 |
"""
|
10 |
Initialize model
|
11 |
"""
|
12 |
+
self.processor = Wav2Vec2Processor.from_pretrained(path)
|
13 |
+
self.model = Wav2Vec2ForCTC.from_pretrained(path)
|
14 |
+
vocab_list = list(self.processor.tokenizer.get_vocab().keys())
|
15 |
|
16 |
# convert ctc blank character representation
|
17 |
vocab_list[0] = ""
|
|
|
39 |
A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
|
40 |
the detected text from the input audio.
|
41 |
"""
|
42 |
+
input_values = self.processor(arr, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
|
43 |
+
logits = self.model(input_values).logits.cpu().detach().numpy()[0]
|
44 |
return {
|
45 |
"text": self.decoder.decode(logits)
|
46 |
}
|