File size: 1,969 Bytes
613af8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import torch
from whisper import load_model
import os
from openvino.tools import mo
from openvino.frontend import FrontEndManager
from openvino.runtime import serialize
import shutil

def convert_encoder(hparams, encoder, mname):
    """Export ``encoder`` to ONNX, convert it with OpenVINO and serialize it.

    The ONNX file is written to a scratch ``onnx_encoder`` directory located
    beside this script, loaded through the OpenVINO ONNX frontend, converted,
    and serialized as ``ggml-<mname>-encoder-openvino.xml`` in the script's
    directory. The scratch directory is removed afterwards, including when
    any export or conversion step raises.

    Args:
        hparams: Model dimensions object; only ``n_mels`` is read here.
        encoder: Module to export. It is switched to eval mode in place.
        mname: Model name embedded in the output file name.

    Raises:
        Whatever ``torch.onnx.export``, the OpenVINO frontend, or
        ``serialize`` raise; the scratch directory is still cleaned up.
    """
    encoder.eval()

    # All-zero example input handed to the ONNX exporter.
    mel = torch.zeros((1, hparams.n_mels, 3000))

    script_dir = os.path.dirname(__file__)
    onnx_folder = os.path.join(script_dir, "onnx_encoder")

    # Directory that holds the ONNX model and any other collateral saved
    # during the ONNX export procedure.
    os.makedirs(onnx_folder, exist_ok=True)

    try:
        onnx_path = os.path.join(onnx_folder, "whisper_encoder.onnx")

        # Export the PyTorch model to ONNX
        torch.onnx.export(
            encoder,
            mel,
            onnx_path,
            input_names=["mel"],
            output_names=["output_features"],
        )

        # Convert ONNX to OpenVINO IR format using the frontend
        fem = FrontEndManager()
        onnx_fe = fem.load_by_framework("onnx")
        onnx_model = onnx_fe.load(onnx_path)
        ov_model = onnx_fe.convert(onnx_model)

        # Serialize the OpenVINO model next to this script
        xml_path = os.path.join(script_dir, f"ggml-{mname}-encoder-openvino.xml")
        serialize(ov_model, xml_path=xml_path)
    finally:
        # Cleanup runs on failure too, so a broken export does not leave the
        # intermediate ONNX files behind.
        if os.path.isdir(onnx_folder):
            shutil.rmtree(onnx_folder)


if __name__ == "__main__":
    # Model names this script accepts for conversion.
    supported_models = (
        "tiny",
        "tiny.en",
        "base",
        "base.en",
        "small",
        "small.en",
        "medium",
        "medium.en",
        "large-v1",
        "large-v2",
        "large-v3",
    )

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)",
    )
    options = cli.parse_args()

    # Reject names outside the supported list before loading anything.
    if options.model not in supported_models:
        raise ValueError("Invalid model name")

    # Load the requested model and move it to the CPU.
    model = load_model(options.model).cpu()

    # Convert the encoder, passing the model dimensions and the model name
    # used to build the output file name.
    convert_encoder(model.dims, model.encoder, options.model)