Uploading images and scripts from git
- README.md +245 -18
- beam_search_utils.py +5 -6
- hyper_optim.py +149 -0
- images/bsd_equation.png +0 -0
- images/bsd_example_pic.png +0 -0
- images/overall_dataflow.png +0 -0
- images/two_realms.png +0 -0
- images/word_level_spk_prob.png +0 -0
- requirements.txt +1 -0
- run_optuna_hyper_optim.sh +68 -0
- run_speaker_tagging_beam_search.sh +4 -2
- speaker_tagging_beamsearch.py +53 -25
README.md
CHANGED
@@ -1,25 +1,129 @@
- SLT 2024 Challenge: Post-ASR-Speaker-Tagging
- #
- SLT 2024 Challenge GenSEC Track 2: Post-ASR-Speaker-Tagging
@@ -58,26 +162,62 @@ Clone the dataset from Hugging Face server.
- find
- find
- Now you are ready to launch the script.
- ```
- DIAR_OUT_DOWNLOAD=$WORKSPACE
@@ -94,11 +234,8 @@ echo "UNIQ MEMO:" $UNIQ_MEMO
- rm $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
- rm $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
- rm $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json
@@ -111,7 +248,6 @@ python $BASEPATH/speaker_tagging_beamsearch.py \
- out_dir=$DIAR_OUT_DOWNLOAD
@@ -120,7 +256,7 @@ We use [MeetEval](https://github.com/fgnt/meeteval) software to evaluate `cpWER`
- ```

Updated file content:
# llm_speaker_tagging

SLT 2024 Challenge: Track-2 Post-ASR-Speaker-Tagging
Baseline and Instructions for Track-2

# GenSEC Challenge Track-2 Introduction

SLT 2024 Challenge GenSEC Track 2: Post-ASR-Speaker-Tagging

- Track-2 is a challenge track that aims to correct the speaker tagging of ASR-generated transcripts that were tagged by a speaker diarization system.
- Traditional speaker diarization systems cannot take lexical cues into account, which leads to errors that disrupt the context of human conversations.
- In the provided dataset, we refer to these erroneous transcripts as `err_source_text` (error source text). Here is an example.

- Erroneous Original Transcript `err_source_text`:
```json
[
{"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about well i"},
{"session_id":"session_gen1sec2", "start_time":13.32, "end_time":17.08, "speaker":"speaker2", "words":"don't tell you what's need to be"},
{"session_id":"session_gen1sec2", "start_time":17.11, "end_time":17.98, "speaker":"speaker1", "words":"discussed"},
{"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker2", "words":"because that's something you should figure out"},
{"session_id":"session_gen1sec2", "start_time":20.10, "end_time":21.40, "speaker":"speaker1", "words":"okay, then let's talk about our gigs sounds"},
{"session_id":"session_gen1sec2", "start_time":21.65, "end_time":23.92, "speaker":"speaker2", "words":"good do you have any specific ideas"},
]
```
Note that the words `well i`, `discussed`, and `sounds` are tagged with the wrong speakers.

- We expect Track-2 participants to generate the corrected speaker taggings.
- Corrected Transcript Example (hypothesis):
```json
[
{"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"what should we talk about"},
{"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"well i don't tell you what's need to be discussed"},
{"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"because that's something you should figure out"},
{"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"okay then let's talk about our gigs"},
{"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"sounds good do you have any specific ideas"}
]
```
- Note that `start_time` and `end_time` cannot be estimated, so the timestamps are all assigned as `0.0`.
- Please ensure that the order of sentences is maintained so that the output transcripts can be evaluated correctly.
- **Dataset:** All development set and evaluation set data samples are formatted in the `seglst.json` format, which is a list of dictionaries with the keys shown above:
```python
{
    "session_id": str,
    "start_time": float,
    "end_time": float,
    "speaker": str,
    "words": str,
}
```
|
52 |
+
|
53 |
+
1. The participants should **only use text (transcripts)** as the only modality. We do not provide any speech (audio) signal for the transcripts.
|
54 |
+
2. The participants are allowed to correct the words (e.g. `spk1:hi are wow` to `spk1:how are you`) without changing the speaker labels. That is, this involves Track-1 in a way.
|
55 |
+
3. The participants are allowed to use any type of language model and methods.
|
56 |
+
- It does not need to be instruct (chat-based) large language models such as GPTs, LLaMa.
|
57 |
+
- No restrictions on the parameter size of the LLM.
|
58 |
+
- The participants can use prompt tuning, model alignment or any type of fine-tuning methods.
|
59 |
+
- The participants are also allowed to use beam search decoding techniques with LLMs.
|
60 |
+
4. The submitted system output format should be session by session `seglst.json` format and evaluated by `cpwer` metric.
|
61 |
+
5. The participants will submit two json files:
|
62 |
+
|
63 |
+
(1) `err_dev.hyp.seglst.json`
|
64 |
+
(2) `err_eval.hyp.seglst.json`
|
65 |
+
|
66 |
+
for both dev and eval set, respectively.
|
67 |
+
6. In each `err_dev.hyp.seglst.json` `err_eval.hyp.seglst.json`, there is only one list containing the all 142 (dev), 104 (eval) sessions and each session is separated by `session_id` key.
|
68 |
+
|
69 |
+
- Example of the final submission form `err_dev.hyp.seglst.json` and `err_eval.hyp.seglst.json`:
|
70 |
+
```json
|
71 |
+
[
|
72 |
+
{"session_id":"session_abc123ab", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"well it is what it is"},
|
73 |
+
{"session_id":"session_abc123ab", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"yeah so be it"},
|
74 |
+
{"session_id":"session_xyz456cd", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"wow you are late again"},
|
75 |
+
{"session_id":"session_xyz456cd", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"sorry traffic jam"},
|
76 |
+
{"session_id":"session_xyz456cd", "start_time":0.0, "end_time":0.0, "speaker":"speaker3", "words":"hey how was last night"}
|
77 |
+
]
|
78 |
+
```
|
79 |
+
|
80 |
+
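Here is a minimal merging sketch (not part of the provided scripts) for producing such a single-list submission file, assuming your system writes one `*.seglst.json` file per session into a directory of your choice:

```python
import glob
import json
import os

def merge_seglst(session_dir, output_path):
    """Concatenate per-session seglst.json files into one submission list."""
    merged = []
    # Sorting by file name keeps the merge deterministic; within each session
    # the original segment order is preserved.
    for path in sorted(glob.glob(os.path.join(session_dir, "*.seglst.json"))):
        with open(path, "r") as f:
            merged.extend(json.load(f))  # each file holds a list of segments
    with open(output_path, "w") as f:
        json.dump(merged, f, indent=4)

# Hypothetical usage:
# merge_seglst("my_system_outputs/dev", "err_dev.hyp.seglst.json")
```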
## Baseline System Introduction: Contextual Beam Search Decoding

The baseline system is based on the system proposed in [Enhancing Speaker Diarization with Large Language Models: A Contextual Beam Search Approach](https://arxiv.org/pdf/2309.05248) (we refer to this method as Contextual Beam Search, CBS). Note that the Track-2 GenSEC challenge only allows the text modality, so this method injects placeholder probabilities represented by `peak_prob`.

The proposed CBS method brings the beam search technique used for ASR language models to speaker diarization.

<img src="images/two_realms.png" width="720" alt="Two Realms"/>

In the CBS method, the following three probability values are needed:

**P(E|S)**: the speaker diarization posterior probability (the probability of the acoustic observation E given speaker S)
**P(W)**: the probability of the next word W
**P(S|W)**: the conditional probability of the speaker S given the next word W

<img src="images/bsd_equation.png" width="360" alt="BSD Equation"/>

Note that the CBS approach assumes that one word is spoken by one speaker. In this baseline system, a placeholder speaker probability `peak_prob` is added since we do not have access to an acoustic-only speaker diarization system.

<img src="images/word_level_spk_prob.png" width="720" alt="Word Level Speaker Probability"/>

The following diagram explains how beam search decoding works with speaker diarization and ASR.

<img src="images/bsd_example_pic.png" width="880" alt="Example of beam search decoding with scores"/>

The overall data flow is shown below. Note that we use a fixed value for the speaker probabilities.

<img src="images/overall_dataflow.png" width="720" alt="Overall Dataflow"/>
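To make the combination of these probabilities concrete, here is a simplified, schematic sketch of one beam-search step. The exact scoring formula is the one in the equation above (implemented in pydiardecode), so the weighting below is only an illustration: `alpha` and `beta` play the same role as the `ALPHA`/`BETA` parameters in the run scripts, and `spk_probs` stands in for the word-level (placeholder) speaker probabilities.

```python
import math

def step_beam(beams, word, spk_probs, lm_score_fn, alpha=0.4, beta=0.04, beam_width=8):
    """Extend each hypothesis by one word, trying every candidate speaker.

    beams:       list of (speaker_sequence, cumulative_log_score) tuples
    word:        the next word of the transcript
    spk_probs:   dict mapping speaker label -> P(speaker | word), e.g. peak_prob-style placeholders
    lm_score_fn: callable(speaker_sequence, word) -> log-probability of the word given the
                 speaker-conditioned history (e.g. backed by a KenLM n-gram model)
    """
    candidates = []
    for spk_seq, score in beams:
        for spk, p_spk in spk_probs.items():
            lm_logp = lm_score_fn(spk_seq + [spk], word)
            # Illustrative weighted combination of speaker and language-model evidence;
            # beta acts as a small per-word bonus.
            new_score = score + math.log(p_spk) + alpha * lm_logp + beta
            candidates.append((spk_seq + [spk], new_score))
    # Keep only the best `beam_width` hypotheses.
    return sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
```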
## Baseline System Installation

Run the following commands at the main level of this repository.

### Conda Environment

The baseline system works in a `conda` environment with Python 3.10.

```
conda create --name llmspk python=3.10
```

### Install requirements

You need to install the following packages

Clone the dataset from the Hugging Face server.

```
git clone https://huggingface.co/datasets/GenSEC-LLM/SLT-Task2-Post-ASR-Speaker-Tagging
```

Inside the cloned folder, you will see the following folder structure.

```bash
.
├── err_source_text
│   ├── dev
│   │   ├── session_014b5cda.seglst.json
│   │   ├── session_02d73d95.seglst.json
│   │   ├── ...
│   │   ├── session_fcd0a550.seglst.json
│   │   └── session_ff16b903.seglst.json
│   └── eval
│       ├── session_0259446c.seglst.json
│       ├── session_0bea34fa.seglst.json
│       ├── ...
│       ├── session_f84edf1f.seglst.json
│       └── session_febfa7aa.seglst.json
├── ref_annotated_text
│   └── dev
│       ├── session_014b5cda.seglst.json
│       ├── session_02d73d95.seglst.json
│       ├── ...
│       ├── session_fcd0a550.seglst.json
│       └── session_ff16b903.seglst.json
```

The file counts are as follows:
- `err_source_text`: dev 142 files, eval 104 files
- `ref_annotated_text`: dev 142 files

Run the following commands to construct the input list files `err_dev.src.list` and `err_dev.ref.list`.
```
find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev -maxdepth 1 -type f -name "*.seglst.json" > err_dev.src.list
find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev -maxdepth 1 -type f -name "*.seglst.json" > err_dev.ref.list
```
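If you prefer Python, a short sketch can produce the same list files (an alternative to the `find` commands above, assuming the dataset was cloned into the current directory); each list file simply contains one absolute `.seglst.json` path per line:

```python
import glob
import os

def write_list_file(pattern, output_path):
    """Write one absolute .seglst.json path per line, as in the .src/.ref list files."""
    paths = sorted(glob.glob(pattern))
    with open(output_path, "w") as f:
        f.write("\n".join(os.path.abspath(p) for p in paths) + "\n")
    print(f"Wrote {len(paths)} paths to {output_path}")

write_list_file("SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/*.seglst.json", "err_dev.src.list")
write_list_file("SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/*.seglst.json", "err_dev.ref.list")
```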
### Launch the baseline script

Now you are ready to launch the baseline script.
Launch the baseline script `run_speaker_tagging_beam_search.sh`

```bash
BASEPATH=${PWD}
DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
ASRDIAR_FILE_NAME=err_dev
OPTUNA_STUDY_NAME=speaker_beam_search_${ASRDIAR_FILE_NAME}
WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
DIAR_OUT_DOWNLOAD=$WORKSPACE/$ASRDIAR_FILE_NAME
mkdir -p $DIAR_OUT_DOWNLOAD

### SLT 2024 Speaker Tagging Setting v1.0.2
ALPHA=0.4
BETA=0.04

TRIAL=telephonic
BATCH_SIZE=11

python $BASEPATH/speaker_tagging_beamsearch.py \
    hyper_params_optim=false \
    port=[5501,5502,5511,5512,5521,5522,5531,5532] \
    arpa_language_model=$DIAR_LM_PATH \
    batch_size=$BATCH_SIZE \
    beam_width=$BEAM_WIDTH \
    word_window=$WORD_WINDOW \
    peak_prob=$PEAK_PROB
```
### Evaluate

We use the [MeetEval](https://github.com/fgnt/meeteval) software to evaluate `cpWER`.
cpWER measures both speaker tagging and word error rate (WER) by testing all the permutations of transcripts and choosing the permutation that gives the lowest error.

```bash
echo "Evaluating the original source transcript."
meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
echo "Source cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst_cpwer.json)

echo "Hypothesis cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst_cpwer.json)
```

The `cpwer` result will be stored in the `./SLT-Task2-Post-ASR-Speaker-Tagging/err_dev.hyp.seglst_cpwer.json` file.

```bash
cat ./SLT-Task2-Post-ASR-Speaker-Tagging/err_dev.hyp.seglst_cpwer.json
```
The result file contains a JSON dictionary; `"error_rate"` is the `cpwer` value we want to minimize.
```json
{
  "error_rate": 0.18784847090516965,
  "errors": 73077,
  "length": 389021,
  "insertions": 13739,
  "deletions": 42173,
  "substitutions": 17165,
  "reference_self_overlap": null,
  "hypothesis_self_overlap": null,
  "missed_speaker": 0,
  "falarm_speaker": 6,
  "scored_speaker": 330,
  "assignment": null
}
```
## Appendix A: LLM example of speaker tagging correction

This is an example of GPT-based speaker tagging correction. The following text is the prompt fed into ChatGPT (GPT-4), including the example in this document.

```markdown
- Track-2 is a challenge track that aims to correct the speaker tagging of the ASR-generated transcripts tagged with a speaker diarization system.
- Since the traditional speaker diarization systems cannot take lexical cues into account, leading to errors that disrupt the context of human conversations.
- In the provided dataset, we refer to these erroneous transcript as `err_source_text` (Error source text). Here is an example.

- Erroneous Original Transcript `err_source_text`:

[
{"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about well i"},
{"session_id":"session_gen1sec2", "start_time":13.32, "end_time":17.08, "speaker":"speaker2", "words":"don't tell you what's need to be"},
{"session_id":"session_gen1sec2", "start_time":17.11, "end_time":17.98, "speaker":"speaker1", "words":"discussed"},
{"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker2", "words":"because that's something you should figure out"},
{"session_id":"session_gen1sec2", "start_time":20.10, "end_time":21.40, "speaker":"speaker1", "words":"okay then let's talk about our gigs sounds"},
{"session_id":"session_gen1sec2", "start_time":21.65, "end_time":23.92, "speaker":"speaker2", "words":"good do you have any specific ideas"},
]

Relocate the words to make the conversation looks and sounds natural.
- Do not add or remove words.
- Do not switch the speaker labels.
- Do not change speaker1 to speaker2 or speaker2 to speaker 1.
- Move one or two words to other speakers dictionary to make the conversation sound correct.
```
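The answers below were obtained through the chat interface, but the same prompt can also be sent programmatically. Here is a minimal sketch assuming the `openai` Python package, an `OPENAI_API_KEY` environment variable, and a hypothetical `speaker_tagging_prompt.md` file holding the prompt shown above:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# The markdown prompt shown above, saved to a file (hypothetical name).
with open("speaker_tagging_prompt.md", "r") as f:
    prompt_text = f.read()

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": prompt_text}],
    temperature=0.0,  # keep the output as deterministic as possible for comparison
)
print(response.choices[0].message.content)
```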
- The answer from GPT-4 (Tested on 05/01/2024).

```markdown
To correct the conversation's flow in the transcript, I'll rearrange the words between the speakers while keeping the speakers consistent with the rules you provided. Here is the corrected version:

json [code]
[
{"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about"},
{"session_id":"session_gen1sec2", "start_time":13.32, "end_time":17.08, "speaker":"speaker2", "words":"well i don't tell you what's need to be"},
{"session_id":"session_gen1sec2", "start_time":17.11, "end_time":17.98, "speaker":"speaker1", "words":"discussed"},
{"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker2", "words":"because that's something you should figure out"},
{"session_id":"session_gen1sec2", "start_time":20.10, "end_time":21.40, "speaker":"speaker1", "words":"okay then let's talk about our gigs"},
{"session_id":"session_gen1sec2", "start_time":21.65, "end_time":23.92, "speaker":"speaker2", "words":"sounds good do you have any specific ideas"}
]

Here are the changes I made:

I moved "well i" from speaker1's first utterance to the beginning of speaker2's first utterance to better connect the dialogue.
"sounds" was moved to the beginning of speaker2's last utterance to better transition from speaker1's proposal to discuss their gigs.
```
In this example, note that the words `well i` and `sounds` are corrected properly, while GPT-4 missed `discussed` and left it assigned to `speaker1`.

- The answer from Llama3 (Tested on 05/01/2024).

```markdown
Here is the relocated transcript to make the conversation sound natural:
[
{"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about"},
{"session_id":"session_gen1sec2", "start_time":13.32, "end_time":15.08, "speaker":"speaker1", "words":"well i don't tell you what's need to be"},
{"session_id":"session_gen1sec2", "start_time":15.11, "end_time":17.08, "speaker":"speaker2", "words":"discussed because that's something you should figure out"},
{"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker1", "words":"okay then let's talk about our gigs"},
{"session_id":"session_gen1sec2", "start_time":21.65, "end_time":22.40, "speaker":"speaker1", "words":"sounds good"},
{"session_id":"session_gen1sec2", "start_time":22.45, "end_time":23.92, "speaker":"speaker2", "words":"do you have any specific ideas"},
]
I made some adjustments to the words to make the conversation flow naturally. Let me know if this meets your requirements!
```
In this Llama3 example, note that the words `well`, `i`, and `sounds` are corrected, but nearly all sentences got swapped speaker labels.
### Reference

@inproceedings{park2024enhancing,
beam_search_utils.py
CHANGED
@@ -8,7 +8,7 @@ import json
 import concurrent.futures
 import kenlm

-__INFO_TAG__ = "[INFO]"
+__INFO_TAG__ = "[BeamSearchUtil INFO]"

 class SpeakerTaggingBeamSearchDecoder:
     def __init__(self, loaded_kenlm_model: kenlm, cfg: dict):
@@ -127,11 +127,10 @@ class SpeakerTaggingBeamSearchDecoder:
         div_trans_info_dict[seq_id]['words'] = w_seq
         return div_trans_info_dict

 def run_mp_beam_search_decoding(
     speaker_beam_search_decoder,
     loaded_kenlm_model,
+    div_trans_info_dict,
     org_trans_info_dict,
     div_mp,
     win_len,
@@ -147,7 +146,7 @@ def run_mp_beam_search_decoding(
     else:
         num_workers = len(port)

-    uniq_id_list = sorted(list(
+    uniq_id_list = sorted(list(div_trans_info_dict.keys() ))
     tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_workers)
     futures = []
@@ -159,7 +158,7 @@ def run_mp_beam_search_decoding(
         else:
             port_num = None
         count += 1
-        uniq_trans_info_dict = {uniq_id:
+        uniq_trans_info_dict = {uniq_id: div_trans_info_dict[uniq_id]}
         futures.append(tp.submit(speaker_beam_search_decoder.beam_search_diarization, uniq_trans_info_dict, port_num=port_num))

     pbar = tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files")
@@ -321,5 +320,5 @@ def write_seglst_jsons(

         print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
         total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
-        with open(f'{diar_out_path}
+        with open(f'{diar_out_path}/{total_output_filename}.seglst.json', 'w') as file:
             json.dump(total_infer_list, file, indent=4) # indent=4 for pretty printing
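The added `div_trans_info_dict` argument lets each worker receive exactly one session. As a simplified illustration of that fan-out pattern (with a placeholder `decode_one_session` standing in for `beam_search_diarization`, so this is not the repository code itself):

```python
import concurrent.futures

def decode_one_session(session_dict):
    # Placeholder for speaker_beam_search_decoder.beam_search_diarization(...).
    return session_dict

def run_parallel(div_trans_info_dict, num_workers=4):
    """Submit one future per session id and merge the per-session results."""
    uniq_id_list = sorted(div_trans_info_dict.keys())
    results = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as tp:
        futures = [
            tp.submit(decode_one_session, {uniq_id: div_trans_info_dict[uniq_id]})
            for uniq_id in uniq_id_list
        ]
        for future in concurrent.futures.as_completed(futures):
            results.update(future.result())
    return results
```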
hyper_optim.py
ADDED
@@ -0,0 +1,149 @@
import optuna
import os
import tempfile
import time
import json
import subprocess
import logging
from beam_search_utils import (
    write_seglst_jsons,
    run_mp_beam_search_decoding,
    convert_nemo_json_to_seglst,
)
from hydra.core.config_store import ConfigStore


def evaluate(cfg, temp_out_dir, workspace_dir, asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict):
    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=temp_out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='ref')
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='src')

    # Construct the file paths
    src_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst.json")
    hyp_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst.json")
    ref_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.ref.seglst.json")

    # Construct the output JSON file path
    output_cpwer_hyp_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer.json")
    output_cpwer_src_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst_cpwer.json")

    # Run meeteval-wer command
    cmd_hyp = [
        "meeteval-wer",
        "cpwer",
        "-h", hyp_seglst_json,
        "-r", ref_seglst_json,
    ]
    subprocess.run(cmd_hyp)

    cmd_src = [
        "meeteval-wer",
        "cpwer",
        "-h", src_seglst_json,
        "-r", ref_seglst_json,
    ]
    subprocess.run(cmd_src)

    # Read the JSON file and print the cpWER
    try:
        with open(output_cpwer_hyp_json_file, "r") as file:
            data_h = json.load(file)
            print("Hypothesis cpWER:", data_h["error_rate"])
        cpwer = data_h["error_rate"]
        logging.info(f"-> HYPOTHESIS cpWER={cpwer:.4f}")
    except FileNotFoundError:
        raise FileNotFoundError(f"Output JSON: {output_cpwer_hyp_json_file}\nfile not found.")

    try:
        with open(output_cpwer_src_json_file, "r") as file:
            data_s = json.load(file)
            print("Source cpWER:", data_s["error_rate"])
        source_cpwer = data_s["error_rate"]
        logging.info(f"-> SOURCE cpWER={source_cpwer:.4f}")
    except FileNotFoundError:
        raise FileNotFoundError(f"Output JSON: {output_cpwer_src_json_file}\nfile not found.")
    return cpwer


def optuna_suggest_params(cfg, trial):
    cfg.alpha = trial.suggest_float("alpha", 0.01, 5.0)
    cfg.beta = trial.suggest_float("beta", 0.001, 2.0)
    cfg.beam_width = trial.suggest_int("beam_width", 4, 64)
    cfg.word_window = trial.suggest_int("word_window", 16, 64)
    cfg.use_ngram = True
    cfg.parallel_chunk_word_len = trial.suggest_int("parallel_chunk_word_len", 50, 300)
    cfg.peak_prob = trial.suggest_float("peak_prob", 0.9, 1.0)
    return cfg


def beamsearch_objective(
    trial,
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    with tempfile.TemporaryDirectory(dir=cfg.temp_out_dir, prefix="GenSEC_") as loca_temp_out_dir:
        start_time2 = time.time()
        cfg = optuna_suggest_params(cfg, trial)
        trans_info_dict = run_mp_beam_search_decoding(speaker_beam_search_decoder,
                                                      loaded_kenlm_model=loaded_kenlm_model,
                                                      div_trans_info_dict=div_trans_info_dict,
                                                      org_trans_info_dict=org_trans_info_dict,
                                                      div_mp=True,
                                                      win_len=cfg.parallel_chunk_word_len,
                                                      word_window=cfg.word_window,
                                                      port=cfg.port,
                                                      use_ngram=cfg.use_ngram,
                                                      )
        hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
        cpwer = evaluate(cfg, loca_temp_out_dir, cfg.workspace_dir, cfg.asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict)
        logging.info(f"Beam Search time taken for trial {trial}: {(time.time() - start_time2)/60:.2f} mins")
        logging.info(f"Trial: {trial.number}")
        logging.info(f"[ cpWER={cpwer:.4f} ]")
        logging.info("-----------------------------------------------")

    return cpwer


def optuna_hyper_optim(
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """
    Optuna hyper-parameter optimization function.

    Parameters:
        cfg (dict): A dictionary containing the configuration parameters.

    """
    worker_function = lambda trial: beamsearch_objective(  # noqa: E731
        trial=trial,
        cfg=cfg,
        speaker_beam_search_decoder=speaker_beam_search_decoder,
        loaded_kenlm_model=loaded_kenlm_model,
        div_trans_info_dict=div_trans_info_dict,
        org_trans_info_dict=org_trans_info_dict,
        source_info_dict=source_info_dict,
        reference_info_dict=reference_info_dict,
    )
    study = optuna.create_study(
        direction="minimize",
        study_name=cfg.optuna_study_name,
        storage=cfg.storage,
        load_if_exists=True,
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Setup the root logger.
    if cfg.output_log_file is not None:
        logger.addHandler(logging.FileHandler(cfg.output_log_file, mode="a"))
    logger.addHandler(logging.StreamHandler())
    optuna.logging.enable_propagation()  # Propagate logs to the root logger.
    study.optimize(worker_function, n_trials=cfg.optuna_n_trials)
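After a study has accumulated trials (see `run_optuna_hyper_optim.sh` below for the study name and SQLite storage settings), the best hyper-parameters can be read back from the storage. A minimal sketch, assuming the default names from that script and that it is run from the repository root:

```python
import optuna

# Names mirror OPTUNA_STUDY_NAME and STORAGE_PATH in run_optuna_hyper_optim.sh.
study = optuna.load_study(
    study_name="speaker_beam_search_err_dev",
    storage="sqlite:///SLT-Task2-Post-ASR-Speaker-Tagging/log_outputs/speaker_beam_search_err_dev.db",
)
print("Best cpWER:", study.best_value)
print("Best parameters:", study.best_params)
```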
images/bsd_equation.png
ADDED
images/bsd_example_pic.png
ADDED
images/overall_dataflow.png
ADDED
images/two_realms.png
ADDED
images/word_level_spk_prob.png
ADDED
requirements.txt
CHANGED
@@ -6,4 +6,5 @@ meeteval
 tqdm
 requests
 simplejson
+optuna
 pydiardecode @ git+https://github.com/tango4j/pydiardecode@main
run_optuna_hyper_optim.sh
ADDED
@@ -0,0 +1,68 @@
### Speaker Tagging Task-2 Parameters
BASEPATH=${PWD}

# OPTUNA TRIALS
OPTUNA_N_TRIALS=999999999

DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
ASRDIAR_FILE_NAME=err_dev
OPTUNA_STUDY_NAME=speaker_beam_search_${ASRDIAR_FILE_NAME}
WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
DIAR_OUT_DOWNLOAD=$WORKSPACE/$ASRDIAR_FILE_NAME
TEMP_OUT_DIR=$WORKSPACE/temp_out_dir
OPTUNA_OUTPUT_LOG_FOLDER=$WORKSPACE/log_outputs
OPTUNA_OUTPUT_LOG_FILE=$OPTUNA_OUTPUT_LOG_FOLDER/${OPTUNA_STUDY_NAME}.log
STORAGE_PATH="sqlite:///$WORKSPACE/log_outputs/${OPTUNA_STUDY_NAME}.db"

mkdir -p $DIAR_OUT_DOWNLOAD
mkdir -p $TEMP_OUT_DIR
mkdir -p $OPTUNA_OUTPUT_LOG_FOLDER


### SLT 2024 Speaker Tagging Setting v1.0.2
ALPHA=0.4
BETA=0.04
PARALLEL_CHUNK_WORD_LEN=100
BEAM_WIDTH=8
WORD_WINDOW=32
PEAK_PROB=0.95
USE_NGRAM=True
LM_METHOD=ngram

# Get the base name of the test_manifest and remove extension
UNIQ_MEMO=$(basename "${INPUT_ERROR_SRC_LIST_PATH}" .json | sed 's/\./_/g')
echo "UNIQ MEMO:" $UNIQ_MEMO
TRIAL=telephonic
BATCH_SIZE=11


rm $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
rm $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
rm $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json


python $BASEPATH/speaker_tagging_beamsearch.py \
    port=[5501,5502,5511,5512,5521,5522,5531,5532] \
    arpa_language_model=$DIAR_LM_PATH \
    batch_size=$BATCH_SIZE \
    groundtruth_ref_list_path=$GROUNDTRUTH_REF_LIST_PATH \
    input_error_src_list_path=$INPUT_ERROR_SRC_LIST_PATH \
    parallel_chunk_word_len=$PARALLEL_CHUNK_WORD_LEN \
    use_ngram=$USE_NGRAM \
    alpha=$ALPHA \
    beta=$BETA \
    beam_width=$BEAM_WIDTH \
    word_window=$WORD_WINDOW \
    peak_prob=$PEAK_PROB \
    out_dir=$DIAR_OUT_DOWNLOAD \
    hyper_params_optim=true \
    optuna_n_trials=$OPTUNA_N_TRIALS \
    workspace_dir=$WORKSPACE \
    asrdiar_file_name=$ASRDIAR_FILE_NAME \
    storage=$STORAGE_PATH \
    optuna_study_name=$OPTUNA_STUDY_NAME \
    temp_out_dir=$TEMP_OUT_DIR \
    output_log_file=$OPTUNA_OUTPUT_LOG_FILE || exit 1
run_speaker_tagging_beam_search.sh
CHANGED
@@ -5,10 +5,11 @@
 BASEPATH=${PWD}
 DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
 ASRDIAR_FILE_NAME=err_dev
+OPTUNA_STUDY_NAME=speaker_beam_search_${ASRDIAR_FILE_NAME}
 WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
 INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
 GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
-DIAR_OUT_DOWNLOAD=$WORKSPACE
+DIAR_OUT_DOWNLOAD=$WORKSPACE/$ASRDIAR_FILE_NAME
 mkdir -p $DIAR_OUT_DOWNLOAD

@@ -35,6 +36,7 @@ rm $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json

 python $BASEPATH/speaker_tagging_beamsearch.py \
+    hyper_params_optim=false \
     port=[5501,5502,5511,5512,5521,5522,5531,5532] \
     arpa_language_model=$DIAR_LM_PATH \
     batch_size=$BATCH_SIZE \
@@ -47,7 +49,7 @@ python $BASEPATH/speaker_tagging_beamsearch.py \
     beam_width=$BEAM_WIDTH \
     word_window=$WORD_WINDOW \
     peak_prob=$PEAK_PROB \
-
+

 echo "Evaluating the original source transcript."
speaker_tagging_beamsearch.py
CHANGED
@@ -11,11 +11,13 @@ from beam_search_utils import (
     convert_nemo_json_to_seglst,
 )
 from hydra.core.config_store import ConfigStore
+from hyper_optim import optuna_hyper_optim
+

-__INFO_TAG__ = "[INFO]"

 @dataclass
 class RealigningLanguageModelParameters:
+    # Beam search parameters
     batch_size: int = 32
     use_mp: bool = True
     input_error_src_list_path: Optional[str] = None
@@ -31,46 +33,72 @@ class RealigningLanguageModelParameters:
     beam_width: int = 16
     out_dir: Optional[str] = None

+    # Optuna parameters
+    hyper_params_optim: bool = False
+    optuna_n_trials: int = 200
+    workspace_dir: Optional[str] = None
+    asrdiar_file_name: Optional[str] = None
+    storage: Optional[str] = "sqlite:///optuna-speaker-beam-search.db"
+    optuna_study_name: Optional[str] = "speaker_beam_search"
+    output_log_file: Optional[str] = None
+    temp_out_dir: Optional[str] = None
+
 cs = ConfigStore.instance()
 cs.store(name="config", node=RealigningLanguageModelParameters)

 @hydra.main(config_name="config", version_base="1.1")
 def main(cfg: RealigningLanguageModelParameters) -> None:
+    __INFO_TAG__ = "[INFO]"
     trans_info_dict = load_input_jsons(input_error_src_list_path=cfg.input_error_src_list_path, peak_prob=float(cfg.peak_prob))
     reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
     source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)
+
+    # Load ARPA language model in advance
     loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)
-
     speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)

     div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(trans_info_dict=trans_info_dict,
                                                                     win_len=cfg.parallel_chunk_word_len,
                                                                     word_window=cfg.word_window,
                                                                     port=cfg.port,)
-            use_ngram=cfg.use_ngram,
-            )
-    hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
-
-    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
-    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
-    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')
-    print(f"{__INFO_TAG__} Parameters used: \
-        \n ALPHA: {cfg.alpha} \
-        \n BETA: {cfg.beta} \
-        \n BEAM WIDTH: {cfg.beam_width} \
-        \n Word Window: {cfg.word_window} \
-        \n Use Ngram: {cfg.use_ngram} \
-        \n Chunk Word Len: {cfg.parallel_chunk_word_len} \
-        \n SpeakerLM Model: {cfg.arpa_language_model}") \
+
+    if cfg.hyper_params_optim:
+        print(f"{__INFO_TAG__} Optimizing hyper-parameters...")
+        cfg = optuna_hyper_optim(cfg=cfg,
+                                 speaker_beam_search_decoder=speaker_beam_search_decoder,
+                                 loaded_kenlm_model=loaded_kenlm_model,
+                                 div_trans_info_dict=div_trans_info_dict,
+                                 org_trans_info_dict=trans_info_dict,
+                                 source_info_dict=source_info_dict,
+                                 reference_info_dict=reference_info_dict,
+                                 )
+
+        __INFO_TAG__ = f"{__INFO_TAG__} Optimized hyper-parameters - "
+    else:
+        trans_info_dict = run_mp_beam_search_decoding(speaker_beam_search_decoder,
+                                                      loaded_kenlm_model=loaded_kenlm_model,
+                                                      div_trans_info_dict=div_trans_info_dict,
+                                                      org_trans_info_dict=trans_info_dict,
+                                                      div_mp=True,
+                                                      win_len=cfg.parallel_chunk_word_len,
+                                                      word_window=cfg.word_window,
+                                                      port=cfg.port,
+                                                      use_ngram=cfg.use_ngram,
+                                                      )
+    hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
+
+    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
+    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
+    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')

+    print(f"{__INFO_TAG__} Parameters used: \
+        \n ALPHA: {cfg.alpha} \
+        \n BETA: {cfg.beta} \
+        \n BEAM WIDTH: {cfg.beam_width} \
+        \n Word Window: {cfg.word_window} \
+        \n Use Ngram: {cfg.use_ngram} \
+        \n Chunk Word Len: {cfg.parallel_chunk_word_len} \
+        \n SpeakerLM Model: {cfg.arpa_language_model}")

 if __name__ == '__main__':
     main()