kevinwang676 commited on
Commit
e3bb1ba
·
verified ·
1 Parent(s): 53c32e3

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +89 -0
utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import torchaudio
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ logging.basicConfig(level=logging.DEBUG,
21
+ format='%(asctime)s %(levelname)s %(message)s')
22
+
23
+
24
+ def read_lists(list_file):
25
+ lists = []
26
+ with open(list_file, 'r', encoding='utf8') as fin:
27
+ for line in fin:
28
+ lists.append(line.strip())
29
+ return lists
30
+
31
+
32
+ def read_json_lists(list_file):
33
+ lists = read_lists(list_file)
34
+ results = {}
35
+ for fn in lists:
36
+ with open(fn, 'r', encoding='utf8') as fin:
37
+ results.update(json.load(fin))
38
+ return results
39
+
40
+
41
+ def load_wav(wav, target_sr):
42
+ speech, sample_rate = torchaudio.load(wav, backend='soundfile')
43
+ speech = speech.mean(dim=0, keepdim=True)
44
+ if sample_rate != target_sr:
45
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
46
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
47
+ return speech
48
+
49
+
50
+ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
51
+ import tensorrt as trt
52
+ _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
53
+ _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
54
+ _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
55
+ input_names = ["x", "mask", "mu", "t", "spks", "cond"]
56
+
57
+ logging.info("Converting onnx to trt...")
58
+ network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
59
+ logger = trt.Logger(trt.Logger.INFO)
60
+ builder = trt.Builder(logger)
61
+ network = builder.create_network(network_flags)
62
+ parser = trt.OnnxParser(network, logger)
63
+ config = builder.create_builder_config()
64
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
65
+ if fp16:
66
+ config.set_flag(trt.BuilderFlag.FP16)
67
+ profile = builder.create_optimization_profile()
68
+ # load onnx model
69
+ with open(onnx_model, "rb") as f:
70
+ if not parser.parse(f.read()):
71
+ for error in range(parser.num_errors):
72
+ print(parser.get_error(error))
73
+ raise ValueError('failed to parse {}'.format(onnx_model))
74
+ # set input shapes
75
+ for i in range(len(input_names)):
76
+ profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
77
+ tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
78
+ # set input and output data type
79
+ for i in range(network.num_inputs):
80
+ input_tensor = network.get_input(i)
81
+ input_tensor.dtype = tensor_dtype
82
+ for i in range(network.num_outputs):
83
+ output_tensor = network.get_output(i)
84
+ output_tensor.dtype = tensor_dtype
85
+ config.add_optimization_profile(profile)
86
+ engine_bytes = builder.build_serialized_network(network, config)
87
+ # save trt engine
88
+ with open(trt_model, "wb") as f:
89
+ f.write(engine_bytes)