File size: 3,324 Bytes
568e264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.paraformer.search import (gen_timestamps_from_peak,
                                     paraformer_greedy_search)
from wenet.text.paraformer_tokenizer import ParaformerTokenizer
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu


class Paraformer:
    """Offline Paraformer speech recognizer backed by a TorchScript export.

    ``model_dir`` must contain ``final.zip`` (the scripted model, loaded with
    ``torch.jit.load``) and ``units.txt`` (the symbol table handed to
    ``ParaformerTokenizer``).
    """

    def __init__(self, model_dir: str, resample_rate: int = 16000) -> None:
        """Load the scripted model and tokenizer from ``model_dir``.

        Args:
            model_dir: directory holding ``final.zip`` and ``units.txt``.
            resample_rate: sample rate (Hz) that input audio is resampled to
                before fbank extraction.
        """
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        # Load the weights onto the CPU so they agree with ``self.device``
        # below even if the model was exported from a GPU; callers such as
        # ``load_model`` move the model to another device afterwards.
        self.model = torch.jit.load(model_path, map_location="cpu")
        self.resample_rate = resample_rate
        self.device = torch.device("cpu")
        self.tokenizer = ParaformerTokenizer(symbol_table=units_path)

    # Inference only: without this, every call records an autograd graph
    # over the model parameters, wasting memory and time.
    @torch.no_grad()
    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
        """Recognize a single audio file with greedy Paraformer search.

        Args:
            audio_file: path to an audio file readable by ``torchaudio.load``.
            tokens_info: when True, also return per-token timing and
                confidence under the ``'tokens'`` key.

        Returns:
            dict with ``'confidence'`` and ``'text'``; when ``tokens_info`` is
            True it also has ``'tokens'``, a list of dicts with ``'token'``,
            ``'start'``/``'end'`` (rounded to 3 decimals) and ``'confidence'``
            (rounded to 2 decimals).
        """
        # normalize=False keeps integer-PCM sample scale instead of [-1, 1]
        # floats — presumably matching how the model's features were
        # computed at training time; TODO confirm.
        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
        waveform = waveform.to(torch.float).to(self.device)
        if sample_rate != self.resample_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
        feats = kaldi.fbank(waveform,
                            num_mel_bins=80,
                            frame_length=25,
                            frame_shift=10,
                            energy_floor=0.0,
                            sample_frequency=self.resample_rate,
                            window_type="hamming")
        # fbank yields (frames, mel_bins); add a batch axis of size 1.
        feats = feats.unsqueeze(0)
        feats_lens = torch.tensor([feats.size(1)],
                                  dtype=torch.int64,
                                  device=feats.device)

        decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
            feats, feats_lens)
        cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
        # Single-utterance batch: keep the only hypothesis.
        res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
        result = {
            'confidence': res.confidence,
            'text': self.tokenizer.detokenize(res.tokens)[0],
        }
        if tokens_info:
            # NOTE(review): frame_rate is tied to the exported model's CIF
            # frame resolution; confirm before reusing with other exports.
            times = gen_timestamps_from_peak(res.times,
                                             num_frames=tp_alphas.size(1),
                                             frame_rate=0.02)
            result['tokens'] = [{
                'token': self.tokenizer.char_dict[token_id],
                'start': round(times[i][0], 3),
                'end': round(times[i][1], 3),
                'confidence': round(res.tokens_confidence[i], 2),
            } for i, token_id in enumerate(res.tokens)]

        return result

    def align(self, audio_file: str, label: str) -> dict:
        """Forced alignment; not available for Paraformer models.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Align is currently not supported")


def load_model(model_dir: str = None,
               gpu: int = -1,
               device: str = "cpu") -> Paraformer:
    """Build a :class:`Paraformer` and place it on the requested device.

    Args:
        model_dir: local model directory; when None it is resolved via
            ``Hub.get_model_by_lang('paraformer')``.
        gpu: legacy GPU index. A non-negative value selects ``cuda:<gpu>``
            and overrides ``device``; a negative value (default -1) defers
            to ``device``.
        device: torch device string used when ``gpu`` is negative.

    Returns:
        A Paraformer whose ``device`` attribute and model agree.
    """
    if model_dir is None:
        model_dir = Hub.get_model_by_lang('paraformer')
    if gpu >= 0:
        # Keep the legacy ``gpu`` argument working and honour the requested
        # index instead of always using the current default CUDA device.
        device = f"cuda:{gpu}"
    paraformer = Paraformer(model_dir)
    paraformer.device = torch.device(device)
    paraformer.model.to(paraformer.device)
    return paraformer