---
language: en
tags:
  - icefall
  - k2
  - transducer
  - aishell
  - ASR
  - stateless transducer
  - PyTorch
license: apache-2.0
datasets:
  - aishell
  - aidatatang_200zh
metrics:
  - WER
---

## Introduction

This repo contains a pre-trained model trained using https://github.com/k2-fsa/icefall/pull/219.

It is trained on the AISHELL dataset using the modified transducer loss from optimized_transducer, with aidatatang_200zh as extra training data.

## How to clone this repo

```bash
sudo apt-get install git-lfs
git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01

cd icefall-aishell-transducer-stateless-modified-2-2022-03-01
git lfs pull
```

Caution: You have to run `git lfs pull`. Otherwise, you will be SAD later.

The model in this repo was trained using the commit TODO.

You can use

```bash
git clone https://github.com/k2-fsa/icefall
cd icefall
git checkout TODO
```

to download icefall.

You can find the model information by visiting https://github.com/k2-fsa/icefall/blob/TODO/egs/aishell/ASR/transducer_stateless_modified-2/train.py#L232.

In short, the encoder is a Conformer model with 8 heads, 12 encoder layers, 512-dim attention, 2048-dim feedforward; the decoder contains a 512-dim embedding layer and a Conv1d with kernel size 2.

The decoder architecture is modified from the one in Rnn-Transducer with Stateless Prediction Network: a Conv1d layer is placed right after the input embedding layer.
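For concreteness, here is a minimal PyTorch sketch of such a stateless decoder. This is a hypothetical illustration, not the actual icefall code; the class and argument names are made up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StatelessDecoder(nn.Module):
    # Embedding + Conv1d; there is no recurrent state to carry around.
    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.context_size = context_size
        # Kernel size 2 means each prediction sees only the last two labels.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U), the label history
        emb = self.embedding(y).permute(0, 2, 1)      # (N, embed_dim, U)
        emb = F.pad(emb, (self.context_size - 1, 0))  # causal left padding
        return self.conv(emb).permute(0, 2, 1)        # (N, U, embed_dim)
```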


## Description

This repo provides a pre-trained Conformer transducer model for the AISHELL dataset, trained using icefall. There are no RNNs in the decoder: it is stateless and contains only an embedding layer and a Conv1d.

The commands for training are:

```bash
cd egs/aishell/ASR
./prepare.sh --stop-stage 6
./prepare_aidatatang_200zh.sh

export CUDA_VISIBLE_DEVICES="0,1,2"

./transducer_stateless_modified-2/train.py \
  --world-size 3 \
  --num-epochs 90 \
  --start-epoch 0 \
  --exp-dir transducer_stateless_modified-2/exp-2 \
  --max-duration 250 \
  --lr-factor 2.0 \
  --context-size 2 \
  --modified-transducer-prob 0.25 \
  --datatang-prob 0.2
```
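Here `--datatang-prob 0.2` means that roughly 20% of training batches are drawn from aidatatang_200zh and the rest from AISHELL (analogously, `--modified-transducer-prob 0.25` is a per-batch probability of applying the modified transducer loss). Below is a minimal sketch of that kind of probabilistic batch mixing, as an illustration only; the actual sampling logic lives in train.py.

```python
import random

def mixed_batches(aishell_batches, datatang_batches, datatang_prob=0.2):
    """Yield training batches, drawing one from aidatatang_200zh with
    probability datatang_prob and from AISHELL otherwise."""
    while True:
        if random.random() < datatang_prob:
            yield next(datatang_batches)
        else:
            yield next(aishell_batches)
```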

The tensorboard training log can be found at https://tensorboard.dev/experiment/oG72ZlWaSGua6fXkcGRRjA/

The commands for decoding are:

```bash
# greedy search
for epoch in 89; do
  for avg in 38; do
    ./transducer_stateless_modified-2/decode.py \
      --epoch $epoch \
      --avg $avg \
      --exp-dir transducer_stateless_modified-2/exp-2 \
      --max-duration 100 \
      --context-size 2 \
      --decoding-method greedy_search \
      --max-sym-per-frame 1
  done
done

# modified beam search
for epoch in 89; do
  for avg in 38; do
    ./transducer_stateless_modified-2/decode.py \
      --epoch $epoch \
      --avg $avg \
      --exp-dir transducer_stateless_modified-2/exp-2 \
      --max-duration 100 \
      --context-size 2 \
      --decoding-method modified_beam_search \
      --beam-size 4
  done
done
```
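With `--decoding-method greedy_search` and `--max-sym-per-frame 1`, at most one symbol is emitted per encoder frame. Here is a minimal sketch of frame-synchronous greedy search under those assumptions; the `decoder` and `joiner` interfaces are hypothetical stand-ins, not the actual decode.py code.

```python
import torch

@torch.no_grad()
def greedy_search(encoder_out, decoder, joiner, blank_id=0,
                  context_size=2, max_sym_per_frame=1):
    # encoder_out: (T, C), the encoder output for a single utterance.
    hyp = [blank_id] * context_size            # blank-padded label history
    for t in range(encoder_out.size(0)):
        for _ in range(max_sym_per_frame):
            context = torch.tensor([hyp[-context_size:]])  # (1, context_size)
            dec_out = decoder(context)[:, -1, :]           # (1, C)
            logits = joiner(encoder_out[t], dec_out.squeeze(0))
            y = int(logits.argmax())
            if y == blank_id:
                break                          # advance to the next frame
            hyp.append(y)
    return hyp[context_size:]
```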

You can find the decoding logs for the above commands in this repo (in the folder `log`).

The WER for the test dataset is:

| decoding method      | test (WER %) | comment                                                         |
|----------------------|--------------|-----------------------------------------------------------------|
| greedy search        | 4.94         | --epoch 89, --avg 38, --max-duration 100, --max-sym-per-frame 1 |
| modified beam search | 4.68         | --epoch 89, --avg 38, --max-duration 100, --beam-size 4         |

## File description

- `log`: this directory contains the decoding logs and decoding results
- `test_wavs`: this directory contains wave files for testing the pre-trained model
- `data`: this directory contains files generated by `prepare.sh`
- `exp`: this directory contains only one file: `pretrained.pt`

`exp/pretrained.pt` is generated by the following command:

```bash
epoch=89
avg=38

./transducer_stateless_modified-2/export.py \
  --exp-dir ./transducer_stateless_modified-2/exp-2 \
  --lang-dir ./data/lang_char \
  --epoch $epoch \
  --avg $avg
```
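Passing `--epoch 89 --avg 38` averages the checkpoints of the last 38 epochs (epoch-52.pt through epoch-89.pt) before exporting. Here is a minimal sketch of that kind of checkpoint averaging; the `"model"` key in the checkpoint dict is an assumption, not necessarily the actual layout.

```python
import torch

def average_checkpoints(paths):
    # Average the parameters of several epoch checkpoints.
    avg = None
    for path in paths:
        # Assumption: each checkpoint stores its weights under a "model" key.
        state = torch.load(path, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}

# --epoch 89 --avg 38 covers epochs 52 through 89:
state_dict = average_checkpoints(
    [f"transducer_stateless_modified-2/exp-2/epoch-{i}.pt" for i in range(52, 90)]
)
```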

HINT: To use `pretrained.pt` to compute the WER for the test dataset, just do the following:

```bash
cp icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
  /path/to/icefall/egs/aishell/ASR/transducer_stateless_modified-2/exp/epoch-999.pt
```

and pass `--epoch 999 --avg 1` to `./transducer_stateless_modified-2/decode.py`.