# Copyright (c) 2022, Horizon Inc. Xingchen Song (sxc19@tsinghua.org.cn) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """NOTE(xcsong): Currently, we only support 1. specific conformer encoder architecture, see: encoder: conformer encoder_conf: activation_type: **must be** relu attention_heads: 2 or 4 or 8 or any number divisible by output_size causal: **must be** true cnn_module_kernel: 1 ~ 7 cnn_module_norm: **must be** batch_norm input_layer: **must be** conv2d8 linear_units: 1 ~ 2048 normalize_before: **must be** true num_blocks: 1 ~ 12 output_size: 1 ~ 512 pos_enc_layer_type: **must be** no_pos selfattention_layer_type: **must be** selfattn use_cnn_module: **must be** true use_dynamic_chunk: **must be** true use_dynamic_left_chunk: **must be** true 2. specific decoding method: ctc_greedy_search """ from __future__ import print_function import os import sys import copy import math import yaml import logging from typing import Tuple import torch import numpy as np from wenet.transformer.embedding import NoPositionalEncoding from wenet.utils.init_model import init_model from wenet.bin.export_onnx_cpu import (get_args, to_numpy, print_input_output_info) try: import onnx import onnxruntime except ImportError: print('Please install onnx and onnxruntime!') sys.exit(1) logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) class BPULayerNorm(torch.nn.Module): """Refactor torch.nn.LayerNorm to meet 4-D dataflow.""" def __init__(self, module, chunk_size=8, run_on_bpu=False): super().__init__() original = copy.deepcopy(module) self.hidden = module.weight.size(0) self.chunk_size = chunk_size self.run_on_bpu = run_on_bpu if self.run_on_bpu: self.weight = torch.nn.Parameter( module.weight.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size)) self.bias = torch.nn.Parameter( module.bias.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size)) self.negtive = torch.nn.Parameter( torch.ones((1, self.hidden, 1, chunk_size)) * -1.0) self.eps = torch.nn.Parameter( torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps) self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) self.mean_conv_1.weight = torch.nn.Parameter( torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)) self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) self.mean_conv_2.weight = torch.nn.Parameter( torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)) else: self.norm = module self.check_equal(original) def check_equal(self, module): random_data = torch.randn(1, self.chunk_size, self.hidden) orig_out = module(random_data) new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2)) np.testing.assert_allclose(to_numpy(orig_out), to_numpy( new_out.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.run_on_bpu: u = self.mean_conv_1(x) # (1, h, 1, c) numerator = x + u * self.negtive # (1, h, 1, c) s = torch.pow(numerator, 2) # (1, h, 1, c) s = self.mean_conv_2(s) # (1, h, 1, c) denominator = torch.sqrt(s + self.eps) # (1, h, 1, c) x = torch.div(numerator, denominator) # (1, h, 1, c) x = x * self.weight + self.bias else: x = x.squeeze(2).transpose(1, 2).contiguous() x = self.norm(x) x = x.transpose(1, 2).contiguous().unsqueeze(2) return x class BPUIdentity(torch.nn.Module): """Refactor torch.nn.Identity(). For inserting BPU node whose input == output. """ def __init__(self, channels): super().__init__() self.channels = channels self.identity_conv = torch.nn.Conv2d(channels, channels, 1, groups=channels, bias=False) torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels) self.check_equal() def check_equal(self): random_data = torch.randn(1, self.channels, 1, 10) result = self.forward(random_data) np.testing.assert_allclose(to_numpy(random_data), to_numpy(result), rtol=1e-02, atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Identity with 4-D dataflow, input == output. Args: x (torch.Tensor): (batch, in_channel, 1, time) Returns: (torch.Tensor): (batch, in_channel, 1, time). """ return self.identity_conv(x) class BPULinear(torch.nn.Module): """Refactor torch.nn.Linear or pointwise_conv""" def __init__(self, module, is_pointwise_conv=False): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) self.idim = module.weight.size(1) self.odim = module.weight.size(0) self.is_pointwise_conv = is_pointwise_conv # Modify weight & bias self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1) if is_pointwise_conv: # (odim, idim, kernel=1) -> (odim, idim, 1, 1) self.linear.weight = torch.nn.Parameter( module.weight.unsqueeze(-1)) else: # (odim, idim) -> (odim, idim, 1, 1) self.linear.weight = torch.nn.Parameter( module.weight.unsqueeze(2).unsqueeze(3)) self.linear.bias = module.bias self.check_equal(original) def check_equal(self, module): random_data = torch.randn(1, 8, self.idim) if self.is_pointwise_conv: random_data = random_data.transpose(1, 2) original_result = module(random_data) if self.is_pointwise_conv: random_data = random_data.transpose(1, 2) original_result = original_result.transpose(1, 2) random_data = random_data.transpose(1, 2).unsqueeze(2) new_result = self.forward(random_data) np.testing.assert_allclose(to_numpy(original_result), to_numpy( new_result.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Linear with 4-D dataflow. Args: x (torch.Tensor): (batch, in_channel, 1, time) Returns: (torch.Tensor): (batch, out_channel, 1, time). """ return self.linear(x) class BPUGlobalCMVN(torch.nn.Module): """Refactor wenet/transformer/cmvn.py::GlobalCMVN""" def __init__(self, module): super().__init__() # Unchanged submodules and attributes self.norm_var = module.norm_var # NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1) self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0) self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0) def forward(self, x: torch.Tensor) -> torch.Tensor: """CMVN with 4-D dataflow. Args: x (torch.Tensor): (batch, 1, mel_dim, time) Returns: (torch.Tensor): normalized feature with same shape. """ x = x - self.mean if self.norm_var: x = x * self.istd return x class BPUConv2dSubsampling8(torch.nn.Module): """Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8 NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding """ def __init__(self, module): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) self.right_context = module.right_context self.subsampling_rate = module.subsampling_rate assert isinstance(module.pos_enc, NoPositionalEncoding) # 1. Modify self.conv # NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim) # to (1, 1, mel_dim, frames) for more efficient computation. self.conv = module.conv for idx in [0, 2, 4]: self.conv[idx].weight = torch.nn.Parameter( module.conv[idx].weight.transpose(2, 3)) # 2. Modify self.linear # NOTE(xcsong): Split final projection to meet the requirment of # maximum kernel_size (7 for XJ3) self.linear = torch.nn.ModuleList() odim = module.linear.weight.size(0) # 512, in this case freq = module.linear.weight.size(1) // odim # 4608 // 512 == 9 self.odim, self.freq = odim, freq weight = module.linear.weight.reshape( odim, odim, freq, 1) # (odim, odim * freq) -> (odim, odim, freq, 1) self.split_size = [] num_split = (freq - 1) // 7 + 1 # XJ3 requires kernel_size <= 7 slice_begin = 0 for idx in range(num_split): kernel_size = min(freq, (idx + 1) * 7) - idx * 7 conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1), (kernel_size, 1)) conv_ele.weight = torch.nn.Parameter( weight[:, :, slice_begin:slice_begin + kernel_size, :]) conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias)) self.linear.append(conv_ele) self.split_size.append(kernel_size) slice_begin += kernel_size self.linear[0].bias = torch.nn.Parameter(module.linear.bias) self.check_equal(original) def check_equal(self, module): random_data = torch.randn(1, 67, 80) mask = torch.zeros(1, 1, 67) original_result, _, _ = module(random_data, mask) # (1, 8, 512) random_data = random_data.transpose(1, 2).unsqueeze(0) # (1, 1, 80, 67) new_result = self.forward(random_data) # (1, 512, 1, 8) np.testing.assert_allclose(to_numpy(original_result), to_numpy( new_result.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x with 4-D dataflow. Args: x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time). Returns: torch.Tensor: Subsampled tensor (#batch, odim, 1, time'), where time' = time // 8. """ x = self.conv(x) # (1, odim, freq, time') x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3)) x = torch.split(x, self.split_size, dim=2) for idx, (x_part, layer) in enumerate(zip(x, self.linear)): x_out += layer(x_part) return x_out class BPUMultiHeadedAttention(torch.nn.Module): """Refactor wenet/transformer/attention.py::MultiHeadedAttention NOTE(xcsong): Only support attention_class == MultiHeadedAttention, we do not consider RelPositionMultiHeadedAttention currently. """ def __init__(self, module, chunk_size, left_chunks): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) self.d_k = module.d_k self.h = module.h n_feat = self.d_k * self.h self.chunk_size = chunk_size self.left_chunks = left_chunks self.time = chunk_size * (left_chunks + 1) self.activation = torch.nn.Softmax(dim=-1) # 1. Modify self.linear_x self.linear_q = BPULinear(module.linear_q) self.linear_k = BPULinear(module.linear_k) self.linear_v = BPULinear(module.linear_v) self.linear_out = BPULinear(module.linear_out) # 2. denom self.register_buffer( "denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k))) self.check_equal(original) def check_equal(self, module): random_data = torch.randn(1, self.chunk_size, self.d_k * self.h) mask = torch.ones((1, self.h, self.chunk_size, self.time), dtype=torch.bool) cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks, self.d_k * 2) original_out, original_cache = module(random_data, random_data, random_data, mask[:, 0, :, :], torch.empty(0), cache) random_data = random_data.transpose(1, 2).unsqueeze(2) cache = cache.reshape(1, self.h, self.d_k * 2, self.chunk_size * self.left_chunks) new_out, new_cache = self.forward(random_data, random_data, random_data, mask, cache) np.testing.assert_allclose(to_numpy(original_out), to_numpy( new_out.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) np.testing.assert_allclose(to_numpy(original_cache), to_numpy(new_cache.transpose(2, 3)), rtol=1e-02, atol=1e-03) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute scaled dot product attention. Args: q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size). k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size). v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size). mask (torch.Tensor): Mask tensor, (#batch, head, chunk_size, cache_t + chunk_size). cache (torch.Tensor): Cache tensor (1, head, d_k * 2, cache_t), where `cache_t == chunk_size * left_chunks`. Returns: torch.Tensor: Output tensor (#batch, size, 1, chunk_size). torch.Tensor: Cache tensor (1, head, d_k * 2, cache_t + chunk_size) where `cache_t == chunk_size * left_chunks` """ # 1. Forward QKV q = self.linear_q(q) # (1, d, 1, c) d == size, c == chunk_size k = self.linear_k(k) # (1, d, 1, c) v = self.linear_v(v) # (1, d, 1, c) q = q.view(1, self.h, self.d_k, self.chunk_size) k = k.view(1, self.h, self.d_k, self.chunk_size) v = v.view(1, self.h, self.d_k, self.chunk_size) q = q.transpose(2, 3) # (batch, head, time1, d_k) k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2) k = torch.cat((k_cache, k), dim=3) v = torch.cat((v_cache, v), dim=3) new_cache = torch.cat((k, v), dim=2) # 2. (Q^T)K scores = torch.matmul(q, k) * self.denom # (#b, n_head, time1, time2) # 3. Forward attention mask = mask.eq(0) scores = scores.masked_fill(mask, -float('inf')) attn = self.activation(scores).masked_fill(mask, 0.0) attn = attn.transpose(2, 3) x = torch.matmul(v, attn) x = x.view(1, self.d_k * self.h, 1, self.chunk_size) x_out = self.linear_out(x) return x_out, new_cache class BPUConvolution(torch.nn.Module): """Refactor wenet/transformer/convolution.py::ConvolutionModule NOTE(xcsong): Only suport use_layer_norm == False """ def __init__(self, module): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) self.lorder = module.lorder self.use_layer_norm = False self.activation = module.activation channels = module.pointwise_conv1.weight.size(1) self.channels = channels kernel_size = module.depthwise_conv.weight.size(2) assert module.use_layer_norm is False # 1. Modify self.pointwise_conv1 self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True) # 2. Modify self.depthwise_conv self.depthwise_conv = torch.nn.Conv2d(channels, channels, (1, kernel_size), stride=1, groups=channels) self.depthwise_conv.weight = torch.nn.Parameter( module.depthwise_conv.weight.unsqueeze(-2)) self.depthwise_conv.bias = torch.nn.Parameter( module.depthwise_conv.bias) # 3. Modify self.norm, Only support batchnorm2d self.norm = torch.nn.BatchNorm2d(channels) self.norm.training = False self.norm.num_features = module.norm.num_features self.norm.eps = module.norm.eps self.norm.momentum = module.norm.momentum self.norm.weight = torch.nn.Parameter(module.norm.weight) self.norm.bias = torch.nn.Parameter(module.norm.bias) self.norm.running_mean = module.norm.running_mean self.norm.running_var = module.norm.running_var # 4. Modify self.pointwise_conv2 self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True) # 5. Identity conv, for running `concat` on BPU self.identity = BPUIdentity(channels) self.check_equal(original) def check_equal(self, module): random_data = torch.randn(1, 8, self.channels) cache = torch.zeros((1, self.channels, self.lorder)) original_out, original_cache = module(random_data, cache=cache) random_data = random_data.transpose(1, 2).unsqueeze(2) cache = cache.unsqueeze(2) new_out, new_cache = self.forward(random_data, cache) np.testing.assert_allclose(to_numpy(original_out), to_numpy( new_out.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) np.testing.assert_allclose(to_numpy(original_cache), to_numpy(new_cache.squeeze(2)), rtol=1e-02, atol=1e-03) def forward(self, x: torch.Tensor, cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Compute convolution module. Args: x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size). cache (torch.Tensor): left context cache, it is only used in causal convolution (#batch, channels, 1, cache_t). Returns: torch.Tensor: Output tensor (#batch, channels, 1, chunk_size). torch.Tensor: Cache tensor (#batch, channels, 1, cache_t). """ # Concat cache x = torch.cat((self.identity(cache), self.identity(x)), dim=3) new_cache = x[:, :, :, -self.lorder:] # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channel, 1, dim) x = torch.nn.functional.glu(x, dim=1) # (b, channel, 1, dim) # Depthwise Conv x = self.depthwise_conv(x) x = self.activation(self.norm(x)) x = self.pointwise_conv2(x) return x, new_cache class BPUFFN(torch.nn.Module): """Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward """ def __init__(self, module): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) self.activation = module.activation # 1. Modify self.w_x self.w_1 = BPULinear(module.w_1) self.w_2 = BPULinear(module.w_2) self.check_equal(original) def check_equal(self, module): random_data = torch.randn(1, 8, self.w_1.idim) original_out = module(random_data) random_data = random_data.transpose(1, 2).unsqueeze(2) new_out = self.forward(random_data) np.testing.assert_allclose(to_numpy(original_out), to_numpy( new_out.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function. Args: xs: input tensor (B, D, 1, L) Returns: output tensor, (B, D, 1, L) """ return self.w_2(self.activation(self.w_1(x))) class BPUConformerEncoderLayer(torch.nn.Module): """Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer """ def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) self.size = module.size assert module.normalize_before is True assert module.concat_after is False # 1. Modify submodules self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron) self.self_attn = BPUMultiHeadedAttention(module.self_attn, chunk_size, left_chunks) self.conv_module = BPUConvolution(module.conv_module) self.feed_forward = BPUFFN(module.feed_forward) # 2. Modify norms self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu) self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, ln_run_on_bpu) self.norm_ff_macron = BPULayerNorm(module.norm_ff_macaron, chunk_size, ln_run_on_bpu) self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size, ln_run_on_bpu) self.norm_final = BPULayerNorm(module.norm_final, chunk_size, ln_run_on_bpu) # 3. 4-D ff_scale self.register_buffer("ff_scale", torch.full((1, self.size, 1, 1), module.ff_scale)) self.check_equal(original) def check_equal(self, module): time1 = self.self_attn.chunk_size time2 = self.self_attn.time h, d_k = self.self_attn.h, self.self_attn.d_k random_x = torch.randn(1, time1, self.size) att_mask = torch.ones(1, h, time1, time2) att_cache = torch.zeros(1, h, time2 - time1, d_k * 2) cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder) original_x, _, original_att_cache, original_cnn_cache = module( random_x, att_mask[:, 0, :, :], torch.empty(0), att_cache=att_cache, cnn_cache=cnn_cache) random_x = random_x.transpose(1, 2).unsqueeze(2) att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1) cnn_cache = cnn_cache.unsqueeze(2) new_x, new_att_cache, new_cnn_cache = self.forward( random_x, att_mask, att_cache, cnn_cache) np.testing.assert_allclose(to_numpy(original_att_cache), to_numpy(new_att_cache.transpose(2, 3)), rtol=1e-02, atol=1e-03) np.testing.assert_allclose(to_numpy(original_x), to_numpy(new_x.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) np.testing.assert_allclose(to_numpy(original_cnn_cache), to_numpy(new_cnn_cache.squeeze(2)), rtol=1e-02, atol=1e-03) def forward( self, x: torch.Tensor, att_mask: torch.Tensor, att_cache: torch.Tensor, cnn_cache: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute encoded features. Args: x (torch.Tensor): (#batch, size, 1, chunk_size) att_mask (torch.Tensor): Mask tensor for the input (#batch, head, chunk_size, cache_t1 + chunk_size), att_cache (torch.Tensor): Cache tensor of the KEY & VALUE (#batch=1, head, d_k * 2, cache_t1), head * d_k == size. cnn_cache (torch.Tensor): Convolution cache in conformer layer (#batch=1, size, 1, cache_t2) Returns: torch.Tensor: Output tensor (#batch, size, 1, chunk_size). torch.Tensor: att_cache tensor, (1, head, d_k * 2, cache_t1 + chunk_size). torch.Tensor: cnn_cahce tensor (#batch, size, 1, cache_t2). """ # 1. ffn_macaron residual = x x = self.norm_ff_macron(x) x = residual + self.ff_scale * self.feed_forward_macaron(x) # 2. attention residual = x x = self.norm_mha(x) x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache) x = residual + x_att # 3. convolution residual = x x = self.norm_conv(x) x, new_cnn_cache = self.conv_module(x, cnn_cache) x = residual + x # 4. ffn residual = x x = self.norm_ff(x) x = residual + self.ff_scale * self.feed_forward(x) # 5. final post-norm x = self.norm_final(x) return x, new_att_cache, new_cnn_cache class BPUConformerEncoder(torch.nn.Module): """Refactor wenet/transformer/encoder.py::ConformerEncoder """ def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) output_size = module.output_size() self._output_size = module.output_size() self.after_norm = module.after_norm self.chunk_size = chunk_size self.left_chunks = left_chunks self.head = module.encoders[0].self_attn.h self.layers = len(module.encoders) # 1. Modify submodules self.global_cmvn = BPUGlobalCMVN(module.global_cmvn) self.embed = BPUConv2dSubsampling8(module.embed) self.encoders = torch.nn.ModuleList() for layer in module.encoders: self.encoders.append( BPUConformerEncoderLayer(layer, chunk_size, left_chunks, ln_run_on_bpu)) # 2. Auxiliary conv self.identity_cnncache = BPUIdentity(output_size) self.check_equal(original) def check_equal(self, module): time1 = self.encoders[0].self_attn.chunk_size time2 = self.encoders[0].self_attn.time layers = self.layers h, d_k = self.head, self.encoders[0].self_attn.d_k decoding_window = (self.chunk_size - 1) * \ module.embed.subsampling_rate + \ module.embed.right_context + 1 lorder = self.encoders[0].conv_module.lorder random_x = torch.randn(1, decoding_window, 80) att_mask = torch.ones(1, h, time1, time2) att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2) cnn_cache = torch.zeros(layers, 1, self._output_size, lorder) orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk( random_x, 0, time2 - time1, att_mask=att_mask[:, 0, :, :], att_cache=att_cache, cnn_cache=cnn_cache) random_x = random_x.unsqueeze(0) att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1) cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder) new_x, new_att_cache, new_cnn_cache = self.forward( random_x, att_cache, cnn_cache, att_mask) caches = torch.split(new_att_cache, h, dim=1) caches = [c.transpose(2, 3) for c in caches] np.testing.assert_allclose(to_numpy(orig_att_cache), to_numpy(torch.cat(caches, dim=0)), rtol=1e-02, atol=1e-03) np.testing.assert_allclose(to_numpy(orig_x), to_numpy(new_x.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) np.testing.assert_allclose( to_numpy(orig_cnn_cache), to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)), rtol=1e-02, atol=1e-03) def forward( self, xs: torch.Tensor, att_cache: torch.Tensor, cnn_cache: torch.Tensor, att_mask: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward just one chunk Args: xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim), where `time == (chunk_size - 1) * subsample_rate + \ subsample.right_context + 1` att_cache (torch.Tensor): cache tensor for KEY & VALUE in transformer/conformer attention, with shape (1, head * elayers, d_k * 2, cache_t1), where `head * d_k == hidden-dim` and `cache_t1 == chunk_size * left_chunks`. cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, (1, hidden-dim, elayers, cache_t2), where `cache_t2 == cnn.lorder - 1` att_mask (torch.Tensor): Mask tensor for the input (#batch, head, chunk_size, cache_t1 + chunk_size), Returns: torch.Tensor: output of current input xs, with shape (b=1, hidden-dim, 1, chunk_size). torch.Tensor: new attention cache required for next chunk, with same shape as the original att_cache. torch.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache. """ # xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time) xs = xs.transpose(2, 3) xs = self.global_cmvn(xs) # xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size) xs = self.embed(xs) att_cache = torch.split(att_cache, self.head, dim=1) cnn_cache = self.identity_cnncache(cnn_cache) cnn_cache = torch.split(cnn_cache, 1, dim=2) r_att_cache = [] r_cnn_cache = [] for i, layer in enumerate(self.encoders): xs, new_att_cache, new_cnn_cache = layer(xs, att_mask, att_cache=att_cache[i], cnn_cache=cnn_cache[i]) r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:]) r_cnn_cache.append(new_cnn_cache) r_att_cache = torch.cat(r_att_cache, dim=1) r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2)) xs = xs.squeeze(2).transpose(1, 2).contiguous() xs = self.after_norm(xs) # NOTE(xcsong): 4D in, 4D out to meet the requirment of CTC input. xs = xs.transpose(1, 2).contiguous().unsqueeze(2) # (B, C, 1, T) return (xs, r_att_cache, r_cnn_cache) class BPUCTC(torch.nn.Module): """Refactor wenet/transformer/ctc.py::CTC """ def __init__(self, module): super().__init__() # Unchanged submodules and attributes original = copy.deepcopy(module) self.idim = module.ctc_lo.weight.size(1) num_class = module.ctc_lo.weight.size(0) # 1. Modify self.ctc_lo, Split final projection to meet the # requirment of maximum in/out channels (2048 for XJ3) self.ctc_lo = torch.nn.ModuleList() self.split_size = [] num_split = (num_class - 1) // 2048 + 1 for idx in range(num_split): out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048 conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1) self.ctc_lo.append(conv_ele) self.split_size.append(out_channel) orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0) orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0) for i, (w, b) in enumerate(zip(orig_weight, orig_bias)): w = w.unsqueeze(2).unsqueeze(3) self.ctc_lo[i].weight = torch.nn.Parameter(w) self.ctc_lo[i].bias = torch.nn.Parameter(b) self.check_equal(original) def check_equal(self, module): random_data = torch.randn(1, 100, self.idim) original_result = module.ctc_lo(random_data) random_data = random_data.transpose(1, 2).unsqueeze(2) new_result = self.forward(random_data) np.testing.assert_allclose(to_numpy(original_result), to_numpy( new_result.squeeze(2).transpose(1, 2)), rtol=1e-02, atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """frame activations, without softmax. Args: Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size) Returns: torch.Tensor: (B, num_class, 1, chunk_size) """ out = [] for i, layer in enumerate(self.ctc_lo): out.append(layer(x)) out = torch.cat(out, dim=1) return out def export_encoder(asr_model, args): logger.info("Stage-1: export encoder") decode_window, mel_dim = args.decoding_window, args.feature_size encoder = BPUConformerEncoder(asr_model.encoder, args.chunk_size, args.num_decoding_left_chunks, args.ln_run_on_bpu) encoder.eval() encoder_outpath = os.path.join(args.output_dir, 'encoder.onnx') logger.info("Stage-1.1: prepare inputs for encoder") chunk = torch.randn((1, 1, decode_window, mel_dim)) required_cache_size = encoder.chunk_size * encoder.left_chunks kv_time = required_cache_size + encoder.chunk_size hidden, layers = encoder._output_size, len(encoder.encoders) head = encoder.encoders[0].self_attn.h d_k = hidden // head lorder = encoder.encoders[0].conv_module.lorder att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size) att_mask = torch.ones((1, head, encoder.chunk_size, kv_time)) att_mask[:, :, :, :required_cache_size] = 0 cnn_cache = torch.zeros((1, hidden, layers, lorder)) inputs = (chunk, att_cache, cnn_cache, att_mask) logger.info("chunk.size(): {} att_cache.size(): {} " "cnn_cache.size(): {} att_mask.size(): {}".format( list(chunk.size()), list(att_cache.size()), list(cnn_cache.size()), list(att_mask.size()))) logger.info("Stage-1.2: torch.onnx.export") # NOTE(xcsong): Below attributes will be used in # onnx2horizonbin.py::generate_config() attributes = {} attributes['input_name'] = "chunk;att_cache;cnn_cache;att_mask" attributes['output_name'] = "output;r_att_cache;r_cnn_cache" attributes['input_type'] = "featuremap;featuremap;featuremap;featuremap" attributes['norm_type'] = \ "no_preprocess;no_preprocess;no_preprocess;no_preprocess" attributes['input_layout_train'] = "NCHW;NCHW;NCHW;NCHW" attributes['input_layout_rt'] = "NCHW;NCHW;NCHW;NCHW" attributes['input_shape'] = \ "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format( chunk.size(0), chunk.size(1), chunk.size(2), chunk.size(3), att_cache.size(0), att_cache.size(1), att_cache.size(2), att_cache.size(3), cnn_cache.size(0), cnn_cache.size(1), cnn_cache.size(2), cnn_cache.size(3), att_mask.size(0), att_mask.size(1), att_mask.size(2), att_mask.size(3) ) torch.onnx.export( # NOTE(xcsong): only support opset==11 encoder, inputs, encoder_outpath, opset_version=11, export_params=True, do_constant_folding=True, input_names=attributes['input_name'].split(';'), output_names=attributes['output_name'].split(';'), dynamic_axes=None, verbose=False) onnx_encoder = onnx.load(encoder_outpath) for k in vars(args): meta = onnx_encoder.metadata_props.add() meta.key, meta.value = str(k), str(getattr(args, k)) for k in attributes: meta = onnx_encoder.metadata_props.add() meta.key, meta.value = str(k), str(attributes[k]) onnx.checker.check_model(onnx_encoder) onnx.helper.printable_graph(onnx_encoder.graph) onnx.save(onnx_encoder, encoder_outpath) print_input_output_info(onnx_encoder, "onnx_encoder") logger.info('Export onnx_encoder, done! see {}'.format(encoder_outpath)) logger.info("Stage-1.3: check onnx_encoder and torch_encoder") torch_output = [] torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask) torch_att_cache = copy.deepcopy(att_cache) torch_cnn_cache = copy.deepcopy(cnn_cache) for i in range(10): logger.info("torch chunk-{}: {}, att_cache: {}, cnn_cache: {}" ", att_mask: {}".format(i, list(torch_chunk.size()), list(torch_att_cache.size()), list(torch_cnn_cache.size()), list(torch_att_mask.size()))) torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1 out, torch_att_cache, torch_cnn_cache = encoder( torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask) torch_output.append(out) torch_output = torch.cat(torch_output, dim=-1) onnx_output = [] onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask) onnx_att_cache = to_numpy(att_cache) onnx_cnn_cache = to_numpy(cnn_cache) ort_session = onnxruntime.InferenceSession(encoder_outpath) input_names = [node.name for node in onnx_encoder.graph.input] for i in range(10): logger.info("onnx chunk-{}: {}, att_cache: {}, cnn_cache: {}," " att_mask: {}".format(i, onnx_chunk.shape, onnx_att_cache.shape, onnx_cnn_cache.shape, onnx_att_mask.shape)) onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1 ort_inputs = { 'chunk': onnx_chunk, 'att_cache': onnx_att_cache, 'cnn_cache': onnx_cnn_cache, 'att_mask': onnx_att_mask, } ort_outs = ort_session.run(None, ort_inputs) onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] onnx_output.append(ort_outs[0]) onnx_output = np.concatenate(onnx_output, axis=-1) np.testing.assert_allclose(to_numpy(torch_output), onnx_output, rtol=1e-03, atol=1e-04) meta = ort_session.get_modelmeta() logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) logger.info("Check onnx_encoder, pass!") return encoder, ort_session def export_ctc(asr_model, args): logger.info("Stage-2: export ctc") ctc = BPUCTC(asr_model.ctc).eval() ctc_outpath = os.path.join(args.output_dir, 'ctc.onnx') logger.info("Stage-2.1: prepare inputs for ctc") hidden = torch.randn((1, args.output_size, 1, args.chunk_size)) logger.info("Stage-2.2: torch.onnx.export") # NOTE(xcsong): Below attributes will be used in # onnx2horizonbin.py::generate_config() attributes = {} attributes['input_name'], attributes['input_type'] = "hidden", "featuremap" attributes['norm_type'] = "no_preprocess" attributes['input_layout_train'] = "NCHW" attributes['input_layout_rt'] = "NCHW" attributes['input_shape'] = "{}x{}x{}x{}".format( hidden.size(0), hidden.size(1), hidden.size(2), hidden.size(3), ) torch.onnx.export(ctc, hidden, ctc_outpath, opset_version=11, export_params=True, do_constant_folding=True, input_names=['hidden'], output_names=['probs'], dynamic_axes=None, verbose=False) onnx_ctc = onnx.load(ctc_outpath) for k in vars(args): meta = onnx_ctc.metadata_props.add() meta.key, meta.value = str(k), str(getattr(args, k)) for k in attributes: meta = onnx_ctc.metadata_props.add() meta.key, meta.value = str(k), str(attributes[k]) onnx.checker.check_model(onnx_ctc) onnx.helper.printable_graph(onnx_ctc.graph) onnx.save(onnx_ctc, ctc_outpath) print_input_output_info(onnx_ctc, "onnx_ctc") logger.info('Export onnx_ctc, done! see {}'.format(ctc_outpath)) logger.info("Stage-2.3: check onnx_ctc and torch_ctc") torch_output = ctc(hidden) ort_session = onnxruntime.InferenceSession(ctc_outpath) onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)}) np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0], rtol=1e-03, atol=1e-04) meta = ort_session.get_modelmeta() logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) logger.info("Check onnx_ctc, pass!") return ctc, ort_session def export_decoder(asr_model, args): logger.info("Currently, Decoder is not supported.") if __name__ == '__main__': torch.manual_seed(777) args = get_args() args.ln_run_on_bpu = False # NOTE(xcsong): XJ3 BPU only support static shapes assert args.chunk_size > 0 assert args.num_decoding_left_chunks > 0 os.system("mkdir -p " + args.output_dir) os.environ['CUDA_VISIBLE_DEVICES'] = '-1' with open(args.config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) model, configs = init_model(args, configs) model.eval() print(model) args.feature_size = configs['input_dim'] args.output_size = model.encoder.output_size() args.decoding_window = (args.chunk_size - 1) * \ model.encoder.embed.subsampling_rate + \ model.encoder.embed.right_context + 1 export_encoder(model, args) export_ctc(model, args) export_decoder(model, args)