# Copyright (c) 2022, Horizon Inc. Xingchen Song ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NOTE(xcsong): Currently, we only support
1. specific conformer encoder architecture, see:
    encoder: conformer
    encoder_conf:
        activation_type: **must be** relu
        attention_heads: 2 or 4 or 8 or any number that divides output_size
        causal: **must be** true
        cnn_module_kernel: 1 ~ 7
        cnn_module_norm: **must be** batch_norm
        input_layer: **must be** conv2d8
        linear_units: 1 ~ 2048
        normalize_before: **must be** true
        num_blocks: 1 ~ 12
        output_size: 1 ~ 512
        pos_enc_layer_type: **must be** no_pos
        selfattention_layer_type: **must be** selfattn
        use_cnn_module: **must be** true
        use_dynamic_chunk: **must be** true
        use_dynamic_left_chunk: **must be** true
2. specific decoding method: ctc_greedy_search
"""
from __future__ import print_function

import os
import sys
import copy
import math
import yaml
import logging
from typing import Tuple

import torch
import numpy as np

from wenet.transformer.embedding import NoPositionalEncoding
from wenet.utils.init_model import init_model
from wenet.bin.export_onnx_cpu import (get_args, to_numpy,
                                       print_input_output_info)

try:
    import onnx
    import onnxruntime
except ImportError:
    print('Please install onnx and onnxruntime!')
    sys.exit(1)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


class BPULayerNorm(torch.nn.Module):
    """Refactor torch.nn.LayerNorm to meet 4-D dataflow."""

    def __init__(self, module, chunk_size=8, run_on_bpu=False):
        super().__init__()
        original = copy.deepcopy(module)
        self.hidden = module.weight.size(0)
        self.chunk_size = chunk_size
        self.run_on_bpu = run_on_bpu
        if self.run_on_bpu:
            self.weight = torch.nn.Parameter(
                module.weight.reshape(1, self.hidden, 1,
                                      1).repeat(1, 1, 1, chunk_size))
            self.bias = torch.nn.Parameter(
                module.bias.reshape(1, self.hidden, 1,
                                    1).repeat(1, 1, 1, chunk_size))
            self.negative = torch.nn.Parameter(
                torch.ones((1, self.hidden, 1, chunk_size)) * -1.0)
            self.eps = torch.nn.Parameter(
                torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps)
            # NOTE(editor): out_channels must equal self.hidden to match the
            #   (hidden, hidden, 1, 1) weights assigned below.
            self.mean_conv_1 = torch.nn.Conv2d(self.hidden,
                                               self.hidden,
                                               1,
                                               bias=False)
            self.mean_conv_1.weight = torch.nn.Parameter(
                torch.ones(self.hidden, self.hidden, 1, 1) /
                (1.0 * self.hidden))
            self.mean_conv_2 = torch.nn.Conv2d(self.hidden,
                                               self.hidden,
                                               1,
                                               bias=False)
            self.mean_conv_2.weight = torch.nn.Parameter(
                torch.ones(self.hidden, self.hidden, 1, 1) /
                (1.0 * self.hidden))
        else:
            self.norm = module
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, self.chunk_size, self.hidden)
        orig_out = module(random_data)
        new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2))
        np.testing.assert_allclose(to_numpy(orig_out),
                                   to_numpy(
                                       new_out.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.run_on_bpu:
            u = self.mean_conv_1(x)  # (1, h, 1, c)
            numerator = x + u * self.negative  # (1, h, 1, c)
            s = torch.pow(numerator, 2)  # (1, h, 1, c)
            s = self.mean_conv_2(s)  # (1, h, 1, c)
            denominator = torch.sqrt(s + self.eps)  # (1, h, 1, c)
            x = torch.div(numerator, denominator)  # (1, h, 1, c)
            x = x * self.weight + self.bias
        else:
            x = x.squeeze(2).transpose(1, 2).contiguous()
            x = self.norm(x)
            x = x.transpose(1, 2).contiguous().unsqueeze(2)
        return x
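

# NOTE(editor): Editor-added sketch (not part of the original file, never
#   called): it verifies the mean-via-conv building block that BPULayerNorm
#   relies on, i.e. a 1x1 conv with all-ones / C weights computes a
#   per-position mean over channels. `_sketch_channel_mean_via_conv` is a
#   hypothetical name introduced here.
def _sketch_channel_mean_via_conv(channels: int = 8, time: int = 4) -> None:
    conv = torch.nn.Conv2d(channels, channels, 1, bias=False)
    conv.weight = torch.nn.Parameter(
        torch.ones(channels, channels, 1, 1) / channels)
    x = torch.randn(1, channels, 1, time)
    # Every output channel equals the mean over input channels.
    expected = x.mean(dim=1, keepdim=True).repeat(1, channels, 1, 1)
    np.testing.assert_allclose(to_numpy(conv(x)),
                               to_numpy(expected),
                               rtol=1e-05,
                               atol=1e-06)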


class BPUIdentity(torch.nn.Module):
    """Refactor torch.nn.Identity().

    For inserting BPU node whose input == output.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.identity_conv = torch.nn.Conv2d(channels,
                                             channels,
                                             1,
                                             groups=channels,
                                             bias=False)
        torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels)
        self.check_equal()

    def check_equal(self):
        random_data = torch.randn(1, self.channels, 1, 10)
        result = self.forward(random_data)
        np.testing.assert_allclose(to_numpy(random_data),
                                   to_numpy(result),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Identity with 4-D dataflow, input == output.

        Args:
            x (torch.Tensor): (batch, in_channel, 1, time)
        Returns:
            (torch.Tensor): (batch, in_channel, 1, time).
        """
        return self.identity_conv(x)


class BPULinear(torch.nn.Module):
    """Refactor torch.nn.Linear or pointwise_conv"""

    def __init__(self, module, is_pointwise_conv=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.idim = module.weight.size(1)
        self.odim = module.weight.size(0)
        self.is_pointwise_conv = is_pointwise_conv
        # Modify weight & bias
        self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1)
        if is_pointwise_conv:
            # (odim, idim, kernel=1) -> (odim, idim, 1, 1)
            self.linear.weight = torch.nn.Parameter(
                module.weight.unsqueeze(-1))
        else:
            # (odim, idim) -> (odim, idim, 1, 1)
            self.linear.weight = torch.nn.Parameter(
                module.weight.unsqueeze(2).unsqueeze(3))
        self.linear.bias = module.bias
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.idim)
        if self.is_pointwise_conv:
            random_data = random_data.transpose(1, 2)
        original_result = module(random_data)
        if self.is_pointwise_conv:
            random_data = random_data.transpose(1, 2)
            original_result = original_result.transpose(1, 2)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_result = self.forward(random_data)
        np.testing.assert_allclose(to_numpy(original_result),
                                   to_numpy(
                                       new_result.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Linear with 4-D dataflow.

        Args:
            x (torch.Tensor): (batch, in_channel, 1, time)
        Returns:
            (torch.Tensor): (batch, out_channel, 1, time).
        """
        return self.linear(x)


class BPUGlobalCMVN(torch.nn.Module):
    """Refactor wenet/transformer/cmvn.py::GlobalCMVN"""

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        self.norm_var = module.norm_var
        # NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1)
        self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
        self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """CMVN with 4-D dataflow.

        Args:
            x (torch.Tensor): (batch, 1, mel_dim, time)
        Returns:
            (torch.Tensor): normalized feature with same shape.
        """
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x


class BPUConv2dSubsampling8(torch.nn.Module):
    """Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8

    NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.right_context = module.right_context
        self.subsampling_rate = module.subsampling_rate
        assert isinstance(module.pos_enc, NoPositionalEncoding)
        # 1. Modify self.conv
        # NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim)
        #   to (1, 1, mel_dim, frames) for more efficient computation.
        self.conv = module.conv
        for idx in [0, 2, 4]:
            self.conv[idx].weight = torch.nn.Parameter(
                module.conv[idx].weight.transpose(2, 3))
        # 2. Modify self.linear
        # NOTE(xcsong): Split final projection to meet the requirement of
        #   maximum kernel_size (7 for XJ3)
        self.linear = torch.nn.ModuleList()
        odim = module.linear.weight.size(0)  # 512, in this case
        freq = module.linear.weight.size(1) // odim  # 4608 // 512 == 9
        self.odim, self.freq = odim, freq
        weight = module.linear.weight.reshape(
            odim, odim, freq,
            1)  # (odim, odim * freq) -> (odim, odim, freq, 1)
        self.split_size = []
        num_split = (freq - 1) // 7 + 1  # XJ3 requires kernel_size <= 7
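        # NOTE(editor): e.g. with freq == 9 this gives num_split == 2 and the
        #   loop below builds kernels of size [7, 2], whose partial
        #   projections sum to the original (odim, odim * freq) linear layer.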
        slice_begin = 0
        for idx in range(num_split):
            kernel_size = min(freq, (idx + 1) * 7) - idx * 7
            conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1),
                                       (kernel_size, 1))
            conv_ele.weight = torch.nn.Parameter(
                weight[:, :, slice_begin:slice_begin + kernel_size, :])
            conv_ele.bias = torch.nn.Parameter(
                torch.zeros_like(conv_ele.bias))
            self.linear.append(conv_ele)
            self.split_size.append(kernel_size)
            slice_begin += kernel_size
        self.linear[0].bias = torch.nn.Parameter(module.linear.bias)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 67, 80)
        mask = torch.zeros(1, 1, 67)
        original_result, _, _ = module(random_data, mask)  # (1, 8, 512)
        random_data = random_data.transpose(1,
                                            2).unsqueeze(0)  # (1, 1, 80, 67)
        new_result = self.forward(random_data)  # (1, 512, 1, 8)
        np.testing.assert_allclose(to_numpy(original_result),
                                   to_numpy(
                                       new_result.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x with 4-D dataflow.

        Args:
            x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time).
        Returns:
            torch.Tensor: Subsampled tensor (#batch, odim, 1, time'),
                where time' = time // 8.
        """
        x = self.conv(x)  # (1, odim, freq, time')
        x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3))
        x = torch.split(x, self.split_size, dim=2)
        for x_part, layer in zip(x, self.linear):
            x_out += layer(x_part)
        return x_out


class BPUMultiHeadedAttention(torch.nn.Module):
    """Refactor wenet/transformer/attention.py::MultiHeadedAttention

    NOTE(xcsong): Only support attention_class == MultiHeadedAttention,
        we do not consider RelPositionMultiHeadedAttention currently.
    """

    def __init__(self, module, chunk_size, left_chunks):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.d_k = module.d_k
        self.h = module.h
        n_feat = self.d_k * self.h
        self.chunk_size = chunk_size
        self.left_chunks = left_chunks
        self.time = chunk_size * (left_chunks + 1)
        self.activation = torch.nn.Softmax(dim=-1)
        # 1. Modify self.linear_x
        self.linear_q = BPULinear(module.linear_q)
        self.linear_k = BPULinear(module.linear_k)
        self.linear_v = BPULinear(module.linear_v)
        self.linear_out = BPULinear(module.linear_out)
        # 2. denom
        self.register_buffer(
            "denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k)))
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, self.chunk_size, self.d_k * self.h)
        mask = torch.ones((1, self.h, self.chunk_size, self.time),
                          dtype=torch.bool)
        cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks,
                            self.d_k * 2)
        original_out, original_cache = module(random_data, random_data,
                                              random_data, mask[:, 0, :, :],
                                              torch.empty(0), cache)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        cache = cache.reshape(1, self.h, self.d_k * 2,
                              self.chunk_size * self.left_chunks)
        new_out, new_cache = self.forward(random_data, random_data,
                                          random_data, mask, cache)
        np.testing.assert_allclose(to_numpy(original_out),
                                   to_numpy(
                                       new_out.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_cache),
                                   to_numpy(new_cache.transpose(2, 3)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor,
        cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size).
            k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size).
            v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size).
            mask (torch.Tensor): Mask tensor,
                (#batch, head, chunk_size, cache_t + chunk_size).
            cache (torch.Tensor): Cache tensor
                (1, head, d_k * 2, cache_t),
                where `cache_t == chunk_size * left_chunks`.
        Returns:
            torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
            torch.Tensor: Cache tensor
                (1, head, d_k * 2, cache_t + chunk_size)
                where `cache_t == chunk_size * left_chunks`
        """
        # 1. Forward QKV
        q = self.linear_q(q)  # (1, d, 1, c)  d == size, c == chunk_size
        k = self.linear_k(k)  # (1, d, 1, c)
        v = self.linear_v(v)  # (1, d, 1, c)
        q = q.view(1, self.h, self.d_k, self.chunk_size)
        k = k.view(1, self.h, self.d_k, self.chunk_size)
        v = v.view(1, self.h, self.d_k, self.chunk_size)
        q = q.transpose(2, 3)  # (batch, head, time1, d_k)
        k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2)
        k = torch.cat((k_cache, k), dim=3)
        v = torch.cat((v_cache, v), dim=3)
        new_cache = torch.cat((k, v), dim=2)
        # 2. (Q^T)K
        scores = torch.matmul(q, k) * self.denom  # (#b, n_head, time1, time2)
        # 3. Forward attention
        mask = mask.eq(0)
        scores = scores.masked_fill(mask, -float('inf'))
        attn = self.activation(scores).masked_fill(mask, 0.0)
        attn = attn.transpose(2, 3)
        x = torch.matmul(v, attn)
        x = x.view(1, self.d_k * self.h, 1, self.chunk_size)
        x_out = self.linear_out(x)
        return x_out, new_cache


class BPUConvolution(torch.nn.Module):
    """Refactor wenet/transformer/convolution.py::ConvolutionModule

    NOTE(xcsong): Only support use_layer_norm == False
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.lorder = module.lorder
        self.use_layer_norm = False
        self.activation = module.activation
        channels = module.pointwise_conv1.weight.size(1)
        self.channels = channels
        kernel_size = module.depthwise_conv.weight.size(2)
        assert module.use_layer_norm is False
        # 1. Modify self.pointwise_conv1
        self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True)
        # 2. Modify self.depthwise_conv
        self.depthwise_conv = torch.nn.Conv2d(channels,
                                              channels, (1, kernel_size),
                                              stride=1,
                                              groups=channels)
        self.depthwise_conv.weight = torch.nn.Parameter(
            module.depthwise_conv.weight.unsqueeze(-2))
        self.depthwise_conv.bias = torch.nn.Parameter(
            module.depthwise_conv.bias)
        # 3. Modify self.norm, Only support batchnorm2d
        self.norm = torch.nn.BatchNorm2d(channels)
        self.norm.training = False
        self.norm.num_features = module.norm.num_features
        self.norm.eps = module.norm.eps
        self.norm.momentum = module.norm.momentum
        self.norm.weight = torch.nn.Parameter(module.norm.weight)
        self.norm.bias = torch.nn.Parameter(module.norm.bias)
        self.norm.running_mean = module.norm.running_mean
        self.norm.running_var = module.norm.running_var
        # 4. Modify self.pointwise_conv2
        self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True)
        # 5. Identity conv, for running `concat` on BPU
        self.identity = BPUIdentity(channels)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.channels)
        cache = torch.zeros((1, self.channels, self.lorder))
        original_out, original_cache = module(random_data, cache=cache)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        cache = cache.unsqueeze(2)
        new_out, new_cache = self.forward(random_data, cache)
        np.testing.assert_allclose(to_numpy(original_out),
                                   to_numpy(
                                       new_out.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_cache),
                                   to_numpy(new_cache.squeeze(2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor,
                cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size).
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, 1, cache_t).
        Returns:
            torch.Tensor: Output tensor (#batch, channels, 1, chunk_size).
            torch.Tensor: Cache tensor (#batch, channels, 1, cache_t).
        """
        # Concat cache
        x = torch.cat((self.identity(cache), self.identity(x)), dim=3)
        new_cache = x[:, :, :, -self.lorder:]
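        # NOTE(editor): With cache_t == lorder == kernel_size - 1 frames of
        #   left context, the depthwise conv below consumes
        #   cache_t + chunk_size inputs and emits exactly chunk_size outputs,
        #   i.e. the convolution stays causal without any padding.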
        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, 1, dim)
        x = torch.nn.functional.glu(x, dim=1)  # (b, channel, 1, dim)
        # Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)
        return x, new_cache


class BPUFFN(torch.nn.Module):
    """Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.activation = module.activation
        # 1. Modify self.w_x
        self.w_1 = BPULinear(module.w_1)
        self.w_2 = BPULinear(module.w_2)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.w_1.idim)
        original_out = module(random_data)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_out = self.forward(random_data)
        np.testing.assert_allclose(to_numpy(original_out),
                                   to_numpy(
                                       new_out.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, D, 1, L)
        Returns:
            output tensor, (B, D, 1, L)
        """
        return self.w_2(self.activation(self.w_1(x)))


class BPUConformerEncoderLayer(torch.nn.Module):
    """Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer
    """

    def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.size = module.size
        assert module.normalize_before is True
        assert module.concat_after is False
        # 1. Modify submodules
        self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron)
        self.self_attn = BPUMultiHeadedAttention(module.self_attn, chunk_size,
                                                 left_chunks)
        self.conv_module = BPUConvolution(module.conv_module)
        self.feed_forward = BPUFFN(module.feed_forward)
        # 2. Modify norms
        self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu)
        self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size,
                                     ln_run_on_bpu)
        self.norm_ff_macaron = BPULayerNorm(module.norm_ff_macaron,
                                            chunk_size, ln_run_on_bpu)
        self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size,
                                      ln_run_on_bpu)
        self.norm_final = BPULayerNorm(module.norm_final, chunk_size,
                                       ln_run_on_bpu)
        # 3. 4-D ff_scale
        self.register_buffer("ff_scale",
                             torch.full((1, self.size, 1, 1),
                                        module.ff_scale))
        self.check_equal(original)

    def check_equal(self, module):
        time1 = self.self_attn.chunk_size
        time2 = self.self_attn.time
        h, d_k = self.self_attn.h, self.self_attn.d_k
        random_x = torch.randn(1, time1, self.size)
        att_mask = torch.ones(1, h, time1, time2)
        att_cache = torch.zeros(1, h, time2 - time1, d_k * 2)
        cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder)
        original_x, _, original_att_cache, original_cnn_cache = module(
            random_x,
            att_mask[:, 0, :, :],
            torch.empty(0),
            att_cache=att_cache,
            cnn_cache=cnn_cache)
        random_x = random_x.transpose(1, 2).unsqueeze(2)
        att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1)
        cnn_cache = cnn_cache.unsqueeze(2)
        new_x, new_att_cache, new_cnn_cache = self.forward(
            random_x, att_mask, att_cache, cnn_cache)
        np.testing.assert_allclose(to_numpy(original_att_cache),
                                   to_numpy(new_att_cache.transpose(2, 3)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_x),
                                   to_numpy(new_x.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_cnn_cache),
                                   to_numpy(new_cnn_cache.squeeze(2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(
        self, x: torch.Tensor, att_mask: torch.Tensor,
        att_cache: torch.Tensor, cnn_cache: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, size, 1, chunk_size)
            att_mask (torch.Tensor): Mask tensor for the input
                (#batch, head, chunk_size, cache_t1 + chunk_size),
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, d_k * 2, cache_t1), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, 1, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
            torch.Tensor: att_cache tensor,
                (1, head, d_k * 2, cache_t1 + chunk_size).
            torch.Tensor: cnn_cache tensor (#batch, size, 1, cache_t2).
        """
        # 1. ffn_macaron
        residual = x
        x = self.norm_ff_macaron(x)
        x = residual + self.ff_scale * self.feed_forward_macaron(x)
        # 2. attention
        residual = x
        x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache)
        x = residual + x_att
        # 3. convolution
        residual = x
        x = self.norm_conv(x)
        x, new_cnn_cache = self.conv_module(x, cnn_cache)
        x = residual + x
        # 4. ffn
        residual = x
        x = self.norm_ff(x)
        x = residual + self.ff_scale * self.feed_forward(x)
        # 5. final post-norm
        x = self.norm_final(x)
        return x, new_att_cache, new_cnn_cache


class BPUConformerEncoder(torch.nn.Module):
    """Refactor wenet/transformer/encoder.py::ConformerEncoder
    """

    def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        output_size = module.output_size()
        self._output_size = module.output_size()
        self.after_norm = module.after_norm
        self.chunk_size = chunk_size
        self.left_chunks = left_chunks
        self.head = module.encoders[0].self_attn.h
        self.layers = len(module.encoders)
        # 1. Modify submodules
        self.global_cmvn = BPUGlobalCMVN(module.global_cmvn)
        self.embed = BPUConv2dSubsampling8(module.embed)
        self.encoders = torch.nn.ModuleList()
        for layer in module.encoders:
            self.encoders.append(
                BPUConformerEncoderLayer(layer, chunk_size, left_chunks,
                                         ln_run_on_bpu))
        # 2. Auxiliary conv
        self.identity_cnncache = BPUIdentity(output_size)
        self.check_equal(original)

    def check_equal(self, module):
        time1 = self.encoders[0].self_attn.chunk_size
        time2 = self.encoders[0].self_attn.time
        layers = self.layers
        h, d_k = self.head, self.encoders[0].self_attn.d_k
        decoding_window = (self.chunk_size - 1) * \
            module.embed.subsampling_rate + \
            module.embed.right_context + 1
        lorder = self.encoders[0].conv_module.lorder
        random_x = torch.randn(1, decoding_window, 80)
        att_mask = torch.ones(1, h, time1, time2)
        att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2)
        cnn_cache = torch.zeros(layers, 1, self._output_size, lorder)
        orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk(
            random_x,
            0,
            time2 - time1,
            att_mask=att_mask[:, 0, :, :],
            att_cache=att_cache,
            cnn_cache=cnn_cache)
        random_x = random_x.unsqueeze(0)
        att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1)
        cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder)
        new_x, new_att_cache, new_cnn_cache = self.forward(
            random_x, att_cache, cnn_cache, att_mask)
        caches = torch.split(new_att_cache, h, dim=1)
        caches = [c.transpose(2, 3) for c in caches]
        np.testing.assert_allclose(to_numpy(orig_att_cache),
                                   to_numpy(torch.cat(caches, dim=0)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(orig_x),
                                   to_numpy(new_x.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(
            to_numpy(orig_cnn_cache),
            to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03)

    def forward(
        self, xs: torch.Tensor, att_cache: torch.Tensor,
        cnn_cache: torch.Tensor, att_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward just one chunk.

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (1, head * elayers, d_k * 2, cache_t1), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (1, hidden-dim, elayers, cache_t2), where
                `cache_t2 == cnn.lorder`
            att_mask (torch.Tensor): Mask tensor for the input
                (#batch, head, chunk_size, cache_t1 + chunk_size),
        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, hidden-dim, 1, chunk_size).
            torch.Tensor: new attention cache required for next chunk, with
                same shape as the original att_cache.
            torch.Tensor: new conformer cnn cache required for next chunk,
                with same shape as the original cnn_cache.
        """
        # xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time)
        xs = xs.transpose(2, 3)
        xs = self.global_cmvn(xs)
        # xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size)
        xs = self.embed(xs)
        att_cache = torch.split(att_cache, self.head, dim=1)
        cnn_cache = self.identity_cnncache(cnn_cache)
        cnn_cache = torch.split(cnn_cache, 1, dim=2)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            xs, new_att_cache, new_cnn_cache = layer(xs,
                                                     att_mask,
                                                     att_cache=att_cache[i],
                                                     cnn_cache=cnn_cache[i])
            # NOTE(editor): new_att_cache covers cache_t1 + chunk_size steps;
            #   dropping the oldest chunk_size keeps the window at cache_t1.
            r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:])
            r_cnn_cache.append(new_cnn_cache)
        r_att_cache = torch.cat(r_att_cache, dim=1)
        r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2))
        xs = xs.squeeze(2).transpose(1, 2).contiguous()
        xs = self.after_norm(xs)
        # NOTE(xcsong): 4D in, 4D out to meet the requirement of CTC input.
        xs = xs.transpose(1, 2).contiguous().unsqueeze(2)  # (B, C, 1, T)
        return (xs, r_att_cache, r_cnn_cache)


class BPUCTC(torch.nn.Module):
    """Refactor wenet/transformer/ctc.py::CTC
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.idim = module.ctc_lo.weight.size(1)
        num_class = module.ctc_lo.weight.size(0)
        # 1. Modify self.ctc_lo, Split final projection to meet the
        #   requirement of maximum in/out channels (2048 for XJ3)
        self.ctc_lo = torch.nn.ModuleList()
        self.split_size = []
        num_split = (num_class - 1) // 2048 + 1
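        # NOTE(editor): e.g. a hypothetical vocabulary of num_class == 5000
        #   gives num_split == 3 and out_channel splits of [2048, 2048, 904];
        #   the per-split logits are concatenated back along dim=1 in forward.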
        for idx in range(num_split):
            out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048
            conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1)
            self.ctc_lo.append(conv_ele)
            self.split_size.append(out_channel)
        orig_weight = torch.split(module.ctc_lo.weight, self.split_size,
                                  dim=0)
        orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0)
        for i, (w, b) in enumerate(zip(orig_weight, orig_bias)):
            w = w.unsqueeze(2).unsqueeze(3)
            self.ctc_lo[i].weight = torch.nn.Parameter(w)
            self.ctc_lo[i].bias = torch.nn.Parameter(b)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 100, self.idim)
        original_result = module.ctc_lo(random_data)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_result = self.forward(random_data)
        np.testing.assert_allclose(to_numpy(original_result),
                                   to_numpy(
                                       new_result.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """frame activations, without softmax.

        Args:
            Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size)
        Returns:
            torch.Tensor: (B, num_class, 1, chunk_size)
        """
        out = []
        for layer in self.ctc_lo:
            out.append(layer(x))
        out = torch.cat(out, dim=1)
        return out


def export_encoder(asr_model, args):
    logger.info("Stage-1: export encoder")
    decode_window, mel_dim = args.decoding_window, args.feature_size
    encoder = BPUConformerEncoder(asr_model.encoder, args.chunk_size,
                                  args.num_decoding_left_chunks,
                                  args.ln_run_on_bpu)
    encoder.eval()
    encoder_outpath = os.path.join(args.output_dir, 'encoder.onnx')

    logger.info("Stage-1.1: prepare inputs for encoder")
    chunk = torch.randn((1, 1, decode_window, mel_dim))
    required_cache_size = encoder.chunk_size * encoder.left_chunks
    kv_time = required_cache_size + encoder.chunk_size
    hidden, layers = encoder._output_size, len(encoder.encoders)
    head = encoder.encoders[0].self_attn.h
    d_k = hidden // head
    lorder = encoder.encoders[0].conv_module.lorder
    att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size)
    att_mask = torch.ones((1, head, encoder.chunk_size, kv_time))
    att_mask[:, :, :, :required_cache_size] = 0
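    # NOTE(editor): At step 0 no left context exists yet, so the cache part
    #   of the mask starts out zeroed; the consistency loops below unmask
    #   chunk_size more positions per simulated chunk.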
    cnn_cache = torch.zeros((1, hidden, layers, lorder))
    inputs = (chunk, att_cache, cnn_cache, att_mask)
    logger.info("chunk.size(): {} att_cache.size(): {} "
                "cnn_cache.size(): {} att_mask.size(): {}".format(
                    list(chunk.size()), list(att_cache.size()),
                    list(cnn_cache.size()), list(att_mask.size())))

    logger.info("Stage-1.2: torch.onnx.export")
    # NOTE(xcsong): Below attributes will be used in
    #   onnx2horizonbin.py::generate_config()
    attributes = {}
    attributes['input_name'] = "chunk;att_cache;cnn_cache;att_mask"
    attributes['output_name'] = "output;r_att_cache;r_cnn_cache"
    attributes['input_type'] = "featuremap;featuremap;featuremap;featuremap"
    attributes['norm_type'] = \
        "no_preprocess;no_preprocess;no_preprocess;no_preprocess"
    attributes['input_layout_train'] = "NCHW;NCHW;NCHW;NCHW"
    attributes['input_layout_rt'] = "NCHW;NCHW;NCHW;NCHW"
    attributes['input_shape'] = \
        "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format(
            chunk.size(0), chunk.size(1), chunk.size(2), chunk.size(3),
            att_cache.size(0), att_cache.size(1), att_cache.size(2),
            att_cache.size(3), cnn_cache.size(0), cnn_cache.size(1),
            cnn_cache.size(2), cnn_cache.size(3), att_mask.size(0),
            att_mask.size(1), att_mask.size(2), att_mask.size(3)
        )
    torch.onnx.export(  # NOTE(xcsong): only support opset==11
        encoder,
        inputs,
        encoder_outpath,
        opset_version=11,
        export_params=True,
        do_constant_folding=True,
        input_names=attributes['input_name'].split(';'),
        output_names=attributes['output_name'].split(';'),
        dynamic_axes=None,
        verbose=False)
    onnx_encoder = onnx.load(encoder_outpath)
    for k in vars(args):
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(getattr(args, k))
    for k in attributes:
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(attributes[k])
    onnx.checker.check_model(onnx_encoder)
    onnx.helper.printable_graph(onnx_encoder.graph)
    onnx.save(onnx_encoder, encoder_outpath)
    print_input_output_info(onnx_encoder, "onnx_encoder")
    logger.info('Export onnx_encoder, done! see {}'.format(encoder_outpath))

    logger.info("Stage-1.3: check onnx_encoder and torch_encoder")
    torch_output = []
    torch_chunk = copy.deepcopy(chunk)
    torch_att_mask = copy.deepcopy(att_mask)
    torch_att_cache = copy.deepcopy(att_cache)
    torch_cnn_cache = copy.deepcopy(cnn_cache)
    for i in range(10):
        logger.info("torch chunk-{}: {}, att_cache: {}, cnn_cache: {}"
                    ", att_mask: {}".format(i, list(torch_chunk.size()),
                                            list(torch_att_cache.size()),
                                            list(torch_cnn_cache.size()),
                                            list(torch_att_mask.size())))
        torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
        out, torch_att_cache, torch_cnn_cache = encoder(
            torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask)
        torch_output.append(out)
    torch_output = torch.cat(torch_output, dim=-1)

    onnx_output = []
    onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask)
    onnx_att_cache = to_numpy(att_cache)
    onnx_cnn_cache = to_numpy(cnn_cache)
    ort_session = onnxruntime.InferenceSession(encoder_outpath)
    input_names = [node.name for node in onnx_encoder.graph.input]
    for i in range(10):
        logger.info("onnx chunk-{}: {}, att_cache: {}, cnn_cache: {},"
                    " att_mask: {}".format(i, onnx_chunk.shape,
                                           onnx_att_cache.shape,
                                           onnx_cnn_cache.shape,
                                           onnx_att_mask.shape))
        onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
        ort_inputs = {
            'chunk': onnx_chunk,
            'att_cache': onnx_att_cache,
            'cnn_cache': onnx_cnn_cache,
            'att_mask': onnx_att_mask,
        }
        ort_outs = ort_session.run(None, ort_inputs)
        onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
        onnx_output.append(ort_outs[0])
    onnx_output = np.concatenate(onnx_output, axis=-1)
    np.testing.assert_allclose(to_numpy(torch_output),
                               onnx_output,
                               rtol=1e-03,
                               atol=1e-04)
    meta = ort_session.get_modelmeta()
    logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
    logger.info("Check onnx_encoder, pass!")
    return encoder, ort_session


def export_ctc(asr_model, args):
    logger.info("Stage-2: export ctc")
    ctc = BPUCTC(asr_model.ctc).eval()
    ctc_outpath = os.path.join(args.output_dir, 'ctc.onnx')

    logger.info("Stage-2.1: prepare inputs for ctc")
    hidden = torch.randn((1, args.output_size, 1, args.chunk_size))

    logger.info("Stage-2.2: torch.onnx.export")
    # NOTE(xcsong): Below attributes will be used in
    #   onnx2horizonbin.py::generate_config()
    attributes = {}
    attributes['input_name'] = "hidden"
    attributes['input_type'] = "featuremap"
    attributes['norm_type'] = "no_preprocess"
    attributes['input_layout_train'] = "NCHW"
    attributes['input_layout_rt'] = "NCHW"
    attributes['input_shape'] = "{}x{}x{}x{}".format(
        hidden.size(0),
        hidden.size(1),
        hidden.size(2),
        hidden.size(3),
    )
    torch.onnx.export(ctc,
                      hidden,
                      ctc_outpath,
                      opset_version=11,
                      export_params=True,
                      do_constant_folding=True,
                      input_names=['hidden'],
                      output_names=['probs'],
                      dynamic_axes=None,
                      verbose=False)
    onnx_ctc = onnx.load(ctc_outpath)
    for k in vars(args):
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(getattr(args, k))
    for k in attributes:
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(attributes[k])
    onnx.checker.check_model(onnx_ctc)
    onnx.helper.printable_graph(onnx_ctc.graph)
    onnx.save(onnx_ctc, ctc_outpath)
    print_input_output_info(onnx_ctc, "onnx_ctc")
    logger.info('Export onnx_ctc, done! see {}'.format(ctc_outpath))

    logger.info("Stage-2.3: check onnx_ctc and torch_ctc")
    torch_output = ctc(hidden)
    ort_session = onnxruntime.InferenceSession(ctc_outpath)
    onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})
    np.testing.assert_allclose(to_numpy(torch_output),
                               onnx_output[0],
                               rtol=1e-03,
                               atol=1e-04)
    meta = ort_session.get_modelmeta()
    logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
    logger.info("Check onnx_ctc, pass!")
    return ctc, ort_session


def export_decoder(asr_model, args):
    logger.info("Currently, Decoder is not supported.")


if __name__ == '__main__':
    torch.manual_seed(777)
    args = get_args()
    args.ln_run_on_bpu = False
    # NOTE(xcsong): XJ3 BPU only supports static shapes
    assert args.chunk_size > 0
    assert args.num_decoding_left_chunks > 0
    os.makedirs(args.output_dir, exist_ok=True)
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model, configs = init_model(args, configs)
    model.eval()
    print(model)

    args.feature_size = configs['input_dim']
    args.output_size = model.encoder.output_size()
    args.decoding_window = (args.chunk_size - 1) * \
        model.encoder.embed.subsampling_rate + \
        model.encoder.embed.right_context + 1

    export_encoder(model, args)
    export_ctc(model, args)
    export_decoder(model, args)
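
# NOTE(editor): A typical invocation might look like the following; the flag
#   names are assumed to come from wenet/bin/export_onnx_cpu.py::get_args and
#   the paths are placeholders:
#
#       python wenet/bin/export_onnx_bpu.py \
#           --config exp/train.yaml \
#           --checkpoint exp/final.pt \
#           --output_dir exp/onnx \
#           --chunk_size 8 \
#           --num_decoding_left_chunks 2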