Taejin committed on
Commit
fc17b57
·
verified ·
1 Parent(s): 5917f0a

Uploading images and scripts


Uploading images and scripts from git

README.md CHANGED
@@ -1,25 +1,129 @@
1
  # llm_speaker_tagging
2
 
3
- SLT 2024 Challenge: Post-ASR-Speaker-Tagging Baseline
 
4
 
5
- # Project Name
6
 
7
- SLT 2024 Challenge GenSEC Track 2: Post-ASR-Speaker-Tagging Baseline
8
 
9
- ## Features
 
 
10
 
11
- - Data download and cleaning
12
- - n-gram + beam search decoder based baselinee system
13
 
14
- ## Installation
15
 
16
  Run the following commands at the main level of this repository.
17
 
18
  ### Conda Environment
19
 
20
  ```
21
  conda create --name llmspk python=3.10
22
  ```
23
  ### Install requirements
24
 
25
  You need to install the following packages
@@ -58,26 +162,62 @@ Clone the dataset from Hugging Face server.
58
  git clone https://huggingface.co/datasets/GenSEC-LLM/SLT-Task2-Post-ASR-Speaker-Tagging
59
  ```
60
 
61
  ```
62
- find . $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev -name *.seglst.json > err_dev.src.list
63
- find . $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev -name *.seglst.json > err_dev.ref.list
64
  ```
65
 
66
  ### Launch the baseline script
67
 
68
- Now you are ready to launch the script.
69
  Launch the baseline script `run_speaker_tagging_beam_search.sh`
70
 
71
- ```
72
  BASEPATH=${PWD}
73
  DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
74
  ASRDIAR_FILE_NAME=err_dev
 
75
  WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
76
  INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
77
  GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
78
- DIAR_OUT_DOWNLOAD=$WORKSPACE/short2_all_seglst_infer
79
  mkdir -p $DIAR_OUT_DOWNLOAD
80
 
 
81
  ### SLT 2024 Speaker Tagging Setting v1.0.2
82
  ALPHA=0.4
83
  BETA=0.04
@@ -94,11 +234,8 @@ echo "UNIQ MEMO:" $UNIQ_MEMO
94
  TRIAL=telephonic
95
  BATCH_SIZE=11
96
 
97
- rm $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
98
- rm $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
99
- rm $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json
100
-
101
  python $BASEPATH/speaker_tagging_beamsearch.py \
 
102
  port=[5501,5502,5511,5512,5521,5522,5531,5532] \
103
  arpa_language_model=$DIAR_LM_PATH \
104
  batch_size=$BATCH_SIZE \
@@ -111,7 +248,6 @@ python $BASEPATH/speaker_tagging_beamsearch.py \
111
  beam_width=$BEAM_WIDTH \
112
  word_window=$WORD_WINDOW \
113
  peak_prob=$PEAK_PROB \
114
- out_dir=$DIAR_OUT_DOWNLOAD
115
  ```
116
 
117
  ### Evaluate
@@ -120,7 +256,7 @@ We use [MeetEval](https://github.com/fgnt/meeteval) software to evaluate `cpWER`
120
cpWER measures both speaker tagging and word error rate (WER) by testing all the permutations of transcripts and choosing the permutation that
121
  gives the lowest error.
122
 
123
- ```
124
  echo "Evaluating the original source transcript."
125
  meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
126
  echo "Source cpWER: " $(jq '.error_rate' "[ $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst_cpwer.json) ]"
@@ -130,6 +266,97 @@ meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json -r $WORKSPAC
130
  echo "Hypothesis cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst_cpwer.json)
131
  ```
132
 
133
  ### Reference
134
 
135
  @inproceedings{park2024enhancing,
 
1
  # llm_speaker_tagging
2
 
3
+ SLT 2024 Challenge: Track-2 Post-ASR-Speaker-Tagging
4
+ Baseline and Instructions for Track-2
5
 
6
+ # GenSEC Challenge Track-2 Introduction
7
 
8
+ SLT 2024 Challenge GenSEC Track 2: Post-ASR-Speaker-Tagging
9
 
10
+ - Track-2 is a challenge track that aims to correct the speaker tagging of ASR-generated transcripts that were tagged by a speaker diarization system.
11
+ - Traditional speaker diarization systems cannot take lexical cues into account, which leads to errors that disrupt the context of human conversations.
12
+ - In the provided dataset, we refer to these erroneous transcripts as `err_source_text` (Error source text). Here is an example.
13
 
14
+ - Erroneous Original Transcript `err_source_text`:
15
+ ```json
16
+ [
17
+ {"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about well i"},
18
+ {"session_id":"session_gen1sec2", "start_time":13.32, "end_time":17.08, "speaker":"speaker2", "words":"don't tell you what's need to be"},
19
+ {"session_id":"session_gen1sec2", "start_time":17.11, "end_time":17.98, "speaker":"speaker1", "words":"discussed"},
20
+ {"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker2", "words":"because that's something you should figure out"},
21
+ {"session_id":"session_gen1sec2", "start_time":20.10, "end_time":21.40, "speaker":"speaker1", "words":"okay, then let's talk about our gigs sounds"},
22
+ {"session_id":"session_gen1sec2", "start_time":21.65, "end_time":23.92, "speaker":"speaker2", "words":"good do you have any specific ideas"},
23
+ ]
24
+ ```
25
+ Note that the words `well`, `i`, `discussed`, and `sounds` are tagged with the wrong speakers.
26
+
27
+ - We expect Track-2 participants to generate corrected speaker tagging.
28
+ - Corrected Transcript Example (hypothesis):
29
+ ```json
30
+ [
31
+ {"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"what should we talk about"},
32
+ {"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"well i don't tell you what's need to be discussed"},
33
+ {"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"because that's something you should figure out"},
34
+ {"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"okay then let's talk about our gigs"},
35
+ {"session_id":"session_gen1sec2", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"sounds good do you have any specific ideas"}
36
+ ]
37
+ ```
38
+ - Note that `start_time` and `end_time` cannot be estimated, so the timestamps are all set to `0.0`.
39
+ - Please ensure that the order of sentences is maintained so that the output transcripts can be evaluated correctly.
40
+ - **Dataset:** All development set and evaluation set data samples are formatted in the `seglst.json` format, which is a list of dictionaries with the following keys (a minimal loading sketch follows the key list below):
41
+ ```python
42
+ {
43
+ "session_id": str,
44
+ "start_time": float,
45
+ "end_time": float,
46
+ "speaker": str,
47
+ "words": str,
48
+ }
49
+ ```
50
+
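Below is a minimal loading sketch (an illustration, not part of the challenge tooling) that reads one `seglst.json` file and checks that every segment carries the keys above; the file path is only an example.

```python
import json

# Example path; every *.seglst.json file in the dataset follows the same schema.
path = "err_source_text/dev/session_014b5cda.seglst.json"

REQUIRED_KEYS = {"session_id", "start_time", "end_time", "speaker", "words"}

with open(path) as f:
    segments = json.load(f)  # a seglst file is a single list of segment dictionaries

for seg in segments:
    missing = REQUIRED_KEYS - seg.keys()
    if missing:
        raise ValueError(f"Segment is missing keys: {missing}")

print(f"Loaded {len(segments)} segments from {path}")
```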
51
+ ## Track-2 Rules and Regulations
52
+
53
+ 1. Participants must use **text (transcripts) as the only modality**. We do not provide any speech (audio) signal for the transcripts.
54
+ 2. Participants are allowed to correct words (e.g., `spk1:hi are wow` to `spk1:how are you`) without changing the speaker labels; in that sense, this track partially involves Track-1.
55
+ 3. Participants are allowed to use any type of language model and method.
56
+ - The model does not need to be an instruction-tuned (chat-based) large language model such as GPT or LLaMA.
57
+ - There are no restrictions on the parameter size of the LLM.
58
+ - Participants can use prompt tuning, model alignment, or any type of fine-tuning method.
59
+ - Participants are also allowed to use beam search decoding techniques with LLMs.
60
+ 4. The submitted system output must be in session-by-session `seglst.json` format and is evaluated with the `cpwer` metric.
61
+ 5. Participants will submit two JSON files:
62
+
63
+ (1) `err_dev.hyp.seglst.json`
64
+ (2) `err_eval.hyp.seglst.json`
65
+
66
+ for the dev and eval sets, respectively.
67
+ 6. Each of `err_dev.hyp.seglst.json` and `err_eval.hyp.seglst.json` contains a single list covering all 142 (dev) or 104 (eval) sessions, with sessions distinguished by the `session_id` key. A small validation sketch follows the submission example below.
68
+
69
+ - Example of the final submission form `err_dev.hyp.seglst.json` and `err_eval.hyp.seglst.json`:
70
+ ```json
71
+ [
72
+ {"session_id":"session_abc123ab", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"well it is what it is"},
73
+ {"session_id":"session_abc123ab", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"yeah so be it"},
74
+ {"session_id":"session_xyz456cd", "start_time":0.0, "end_time":0.0, "speaker":"speaker1", "words":"wow you are late again"},
75
+ {"session_id":"session_xyz456cd", "start_time":0.0, "end_time":0.0, "speaker":"speaker2", "words":"sorry traffic jam"},
76
+ {"session_id":"session_xyz456cd", "start_time":0.0, "end_time":0.0, "speaker":"speaker3", "words":"hey how was last night"}
77
+ ]
78
+ ```
79
+
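A minimal validation sketch for a submission file, assuming the session counts stated above (142 for dev, 104 for eval); the helper name and file paths are illustrative, not part of the repository.

```python
import json

REQUIRED_KEYS = {"session_id", "start_time", "end_time", "speaker", "words"}

def validate_submission(path: str, expected_sessions: int) -> None:
    """Check that a *.hyp.seglst.json submission is one flat list covering the expected sessions."""
    with open(path) as f:
        segments = json.load(f)
    assert isinstance(segments, list), "The submission must be a single JSON list."
    for seg in segments:
        assert REQUIRED_KEYS <= seg.keys(), f"Missing keys in segment: {REQUIRED_KEYS - seg.keys()}"
    session_ids = {seg["session_id"] for seg in segments}
    assert len(session_ids) == expected_sessions, (
        f"Expected {expected_sessions} sessions, found {len(session_ids)}."
    )

validate_submission("err_dev.hyp.seglst.json", expected_sessions=142)
validate_submission("err_eval.hyp.seglst.json", expected_sessions=104)
```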
80
+ ## Baseline System Introduction: Contextual Beam Search Decoding
81
+
82
+ The baseline system is based on the system proposed in [Enhancing Speaker Diarization with Large Language Models: A Contextual Beam Search Approach
83
+ ](https://arxiv.org/pdf/2309.05248) (we refer to this method as Contextual Beam Search (CBS)). Note that the Track-2 GenSEC challenge only allows the text modality, so this method injects placeholder probabilities represented by `peak_prob`.
84
+
85
+ The proposed CBS method brings the beam search technique used in ASR language modeling to speaker diarization.
86
+
87
+
88
+ <img src="images/two_realms.png" width="720" alt="Two Realms"/>
89
+
90
+ In the CBS method, the following three probability values are needed:
91

92
+ **P(E|S)**: the speaker diarization posterior probability of the acoustic observation E given speaker S
93
+ **P(W)**: the probability of the next word W
94
+ **P(S|W)**: the conditional probability of speaker S given the next word W
95
+
96
+
97
+ <img src="images/bsd_equation.png" width="360" alt="BSD Equation"/>
98
+
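The exact scoring rule is the equation shown in the figure above. As a rough schematic only (our assumption, not the repository's implementation), a per-word beam score could combine the three terms in log space, with `alpha` and `beta` playing roles analogous to the ALPHA/BETA knobs in the baseline script:

```python
import math

def schematic_beam_score(p_e_given_s: float, p_w: float, p_s_given_w: float,
                         alpha: float = 0.4, beta: float = 0.04) -> float:
    """Schematic per-word score; the actual weighting follows the CBS paper."""
    return (math.log(p_e_given_s)        # acoustic/diarization term P(E|S)
            + math.log(p_s_given_w)      # lexical speaker term P(S|W)
            + alpha * math.log(p_w)      # language-model term P(W), weighted by alpha
            + beta)                      # per-word bonus, one common role for beta
```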
99
+
100
+ Note that the CBS approach assumes that each word is spoken by one speaker. In this baseline system, a placeholder speaker probability `peak_prob` is injected since we do not have access to an acoustic speaker diarization system.
101
+
102
+ <img src="images/word_level_spk_prob.png" width="720" alt="Word Level Speaker Probability"/>
103
+
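As a minimal sketch (our illustration, not the repository's code), such a word-level placeholder distribution gives the speaker tagged in the source transcript a probability of `peak_prob` and spreads the remaining mass over the other speakers:

```python
def word_speaker_probs(tagged_speaker_idx: int, num_speakers: int, peak_prob: float = 0.95) -> list[float]:
    """Placeholder word-level speaker distribution, peaked at the tagged speaker."""
    rest = (1.0 - peak_prob) / (num_speakers - 1)
    return [peak_prob if idx == tagged_speaker_idx else rest for idx in range(num_speakers)]

# Example: a word tagged as speaker2 in a two-speaker session
print(word_speaker_probs(tagged_speaker_idx=1, num_speakers=2))  # approximately [0.05, 0.95]
```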
104
+ The following diagram explains how beam search decoding works with speaker diarization and ASR.
105
+
106
+ <img src="images/bsd_example_pic.png" width="880" alt="Example of beam search decoding with scores"/>
107
+
108
+ The overall data flow is shown below. Note that we use a fixed value for the speaker probabilities.
109
+
110
+
111
+ <img src="images/overall_dataflow.png" width="720" alt="Overall Dataflow"/>
112
+
113
+
114
+ ## Baseline System Installation
115
 
116
  Run the following commands at the main level of this repository.
117
 
118
  ### Conda Environment
119
 
120
+ The baseline system works in a `conda` environment with Python 3.10.
121
+
122
  ```
123
  conda create --name llmspk python=3.10
124
  ```
125
+
126
+
127
  ### Install requirements
128
 
129
  You need to install the following packages
 
162
  git clone https://huggingface.co/datasets/GenSEC-LLM/SLT-Task2-Post-ASR-Speaker-Tagging
163
  ```
164
 
165
+ Inside the cloned folder, you will see the following folder structure.
166
+
167
+ ```bash
168
+ .
169
+ ├── err_source_text
170
+ │   ├── dev
171
+ │   │   ├── session_014b5cda.seglst.json
172
+ │   │   ├── session_02d73d95.seglst.json
173
+ │.
174
+ │..
175
+ │   │   ├── session_fcd0a550.seglst.json
176
+ │   │   └── session_ff16b903.seglst.json
177
+ │   └── eval
178
+ │   ├── session_0259446c.seglst.json
179
+ │   ├── session_0bea34fa.seglst.json
180
+ │..
181
+ │...
182
+ │   ├── session_f84edf1f.seglst.json
183
+ │   └── session_febfa7aa.seglst.json
184
+ ├── ref_annotated_text
185
+ │   └── dev
186
+ │   ├── session_014b5cda.seglst.json
187
+ │   ├── session_02d73d95.seglst.json
188
+ │.
189
+ │..
190
+ │   ├── session_fcd0a550.seglst.json
191
+ │   └── session_ff16b903.seglst.json
192
+ ```
193
+
194
+ The file counts are as follows:
195
+ - `err_source_text`: dev 142 files, eval 104 files
196
+ - `ref_annotated_text`: dev 142 files
197
+
198
+ Run the following commands to construct the input list files `err_dev.src.list` and `err_dev.ref.list`.
199
  ```
200
+ find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev -maxdepth 1 -type f -name "*.seglst.json" > err_dev.src.list
201
+ find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev -maxdepth 1 -type f -name "*.seglst.json" > err_dev.ref.list
202
  ```
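You can quickly confirm that the list files match the file counts above (142 dev files each); this check is only an illustration, not part of the baseline scripts.

```python
# Count the non-empty lines in each list file; both should report 142 for the dev set.
for list_file in ("err_dev.src.list", "err_dev.ref.list"):
    with open(list_file) as f:
        num_paths = sum(1 for line in f if line.strip())
    print(list_file, num_paths)
```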
203
 
204
  ### Launch the baseline script
205
 
206
+ Now you are ready to launch the baseline script.
207
  Launch the baseline script `run_speaker_tagging_beam_search.sh`
208
 
209
+ ```bash
210
  BASEPATH=${PWD}
211
  DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
212
  ASRDIAR_FILE_NAME=err_dev
213
+ OPTUNA_STUDY_NAME=speaker_beam_search_${ASRDIAR_FILE_NAME}
214
  WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
215
  INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
216
  GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
217
+ DIAR_OUT_DOWNLOAD=$WORKSPACE/$ASRDIAR_FILE_NAME
218
  mkdir -p $DIAR_OUT_DOWNLOAD
219
 
220
+
221
  ### SLT 2024 Speaker Tagging Setting v1.0.2
222
  ALPHA=0.4
223
  BETA=0.04
 
234
  TRIAL=telephonic
235
  BATCH_SIZE=11
236
 
 
237
  python $BASEPATH/speaker_tagging_beamsearch.py \
238
+ hyper_params_optim=false \
239
  port=[5501,5502,5511,5512,5521,5522,5531,5532] \
240
  arpa_language_model=$DIAR_LM_PATH \
241
  batch_size=$BATCH_SIZE \
 
248
  beam_width=$BEAM_WIDTH \
249
  word_window=$WORD_WINDOW \
250
  peak_prob=$PEAK_PROB \
 
251
  ```
252
 
253
  ### Evaluate
 
256
cpWER measures both speaker tagging and word error rate (WER) by testing all the permutations of transcripts and choosing the permutation that
257
gives the lowest error.
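As a toy illustration of the permutation idea (a sketch, not the MeetEval implementation), the snippet below concatenates each speaker's words, tries both speaker mappings for a two-speaker session, and keeps the mapping with the fewest word errors.

```python
import itertools

def word_errors(ref: list[str], hyp: list[str]) -> int:
    """Levenshtein distance between two word sequences."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]

# Toy example: words concatenated per speaker for one session.
ref = {"speaker1": "what should we talk about".split(),
       "speaker2": "well i do not know".split()}
hyp = {"speaker1": "well i do not know".split(),
       "speaker2": "what should we talk about".split()}

# Try every mapping of hypothesis speakers onto reference speakers and keep the best one.
best_errors = min(
    sum(word_errors(ref[r], hyp[h]) for r, h in zip(ref, perm))
    for perm in itertools.permutations(hyp)
)
total_ref_words = sum(len(words) for words in ref.values())
print(f"toy cpWER = {best_errors / total_ref_words:.2f}")  # 0.00: swapping the speakers fixes everything
```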
258
 
259
+ ```bash
260
  echo "Evaluating the original source transcript."
261
  meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
262
  echo "Source cpWER: " $(jq '.error_rate' "[ $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst_cpwer.json) ]"
 
266
  echo "Hypothesis cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst_cpwer.json)
267
  ```
268
 
269
+ The `cpwer` result will be stored in the `./SLT-Task2-Post-ASR-Speaker-Tagging/err_dev.hyp.seglst_cpwer.json` file.
270
+
271
+ ```bash
272
+ cat ./SLT-Task2-Post-ASR-Speaker-Tagging/err_dev.hyp.seglst_cpwer.json
273
+ ```
274
+ The result file contains a JSON dictionary; `"error_rate"` is the `cpwer` value we want to minimize.
275
+ ```json
276
+ {
277
+ "error_rate": 0.18784847090516965,
278
+ "errors": 73077,
279
+ "length": 389021,
280
+ "insertions": 13739,
281
+ "deletions": 42173,
282
+ "substitutions": 17165,
283
+ "reference_self_overlap": null,
284
+ "hypothesis_self_overlap": null,
285
+ "missed_speaker": 0,
286
+ "falarm_speaker": 6,
287
+ "scored_speaker": 330,
288
+ "assignment": null
289
+ }
290
+ ```
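As a quick sanity check on these numbers, `error_rate` is the total error count divided by the reference length, and the error count is the sum of insertions, deletions, and substitutions:

```python
errors = 13739 + 42173 + 17165   # insertions + deletions + substitutions
print(errors)                    # 73077, matching "errors" above
print(errors / 389021)           # 0.18784..., matching "error_rate" above
```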
291
+
292
+
293
+ ## Appendix A: LLM example of speaker tagging correction
294
+
295
+ This is an example of GPT-based speaker tagging correction. The following text is the prompt fed into ChatGPT (GPT-4), including the example from this document.
296
+
297
+ ```markdown
298
+ - Track-2 is a challenge track that aims to correct the speaker tagging of the ASR-generated transcripts tagged with a speaker diarization system.
299
+ - Since the traditional speaker diarization systems cannot take lexical cues into account, leading to errors that disrupt the context of human conversations.
300
+ - In the provided dataset, we refer to these erroneous transcript as `err_source_text` (Error source text). Here is an example.
301
+
302
+ - Erroneous Original Transcript `err_source_text`:
303
+
304
+ [
305
+ {"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about well i"},
306
+ {"session_id":"session_gen1sec2", "start_time":13.32, "end_time":17.08, "speaker":"speaker2", "words":"don't tell you what's need to be"},
307
+ {"session_id":"session_gen1sec2", "start_time":17.11, "end_time":17.98, "speaker":"speaker1", "words":"discussed"},
308
+ {"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker2", "words":"because that's something you should figure out"},
309
+ {"session_id":"session_gen1sec2", "start_time":20.10, "end_time":21.40, "speaker":"speaker1", "words":"okay then let's talk about our gigs sounds"},
310
+ {"session_id":"session_gen1sec2", "start_time":21.65, "end_time":23.92, "speaker":"speaker2", "words":"good do you have any specific ideas"},
311
+ ]
312
+
313
+ Relocate the words to make the conversation looks and sounds natural.
314
+ - Do not add or remove words.
315
+ - Do not switch the speaker labels.
316
+ - Do not change speaker1 to speaker2 or speaker2 to speaker 1.
317
+ - Move one or two words to other speakers dictionary to make the conversation sound correct.
318
+ ```
319
+
320
+ - The answer from GPT-4 (Tested on 05/01/2024).
321
+
322
+ ```markdown
323
+ To correct the conversation's flow in the transcript, I'll rearrange the words between the speakers while keeping the speakers consistent with the rules you provided. Here is the corrected version:
324
+
325
+ json [code]
326
+ [
327
+ {"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about"},
328
+ {"session_id":"session_gen1sec2", "start_time":13.32, "end_time":17.08, "speaker":"speaker2", "words":"well i don't tell you what's need to be"},
329
+ {"session_id":"session_gen1sec2", "start_time":17.11, "end_time":17.98, "speaker":"speaker1", "words":"discussed"},
330
+ {"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker2", "words":"because that's something you should figure out"},
331
+ {"session_id":"session_gen1sec2", "start_time":20.10, "end_time":21.40, "speaker":"speaker1", "words":"okay then let's talk about our gigs"},
332
+ {"session_id":"session_gen1sec2", "start_time":21.65, "end_time":23.92, "speaker":"speaker2", "words":"sounds good do you have any specific ideas"}
333
+ ]
334
+
335
+ Here are the changes I made:
336
+
337
+ I moved "well i" from speaker1's first utterance to the beginning of speaker2's first utterance to better connect the dialogue.
338
+ "sounds" was moved to the beginning of speaker2's last utterance to better transition from speaker1's proposal to discuss their gigs.
339
+ ```
340
+ In this example, note that the words `well`, `i`, and `sounds` are corrected properly, while GPT-4 missed `discussed` and left it assigned to `speaker1`.
341
+
342
+
343
+ - The answer from Llama3 (Tested on 05/01/2024).
344
+
345
+ ```markdown
346
+ Here is the relocated transcript to make the conversation sound natural:
347
+ [
348
+ {"session_id":"session_gen1sec2", "start_time":10.02, "end_time":11.74, "speaker":"speaker1", "words":"what should we talk about"},
349
+ {"session_id":"session_gen1sec2", "start_time":13.32, "end_time":15.08, "speaker":"speaker1", "words":"well i don't tell you what's need to be"},
350
+ {"session_id":"session_gen1sec2", "start_time":15.11, "end_time":17.08, "speaker":"speaker2", "words":"discussed because that's something you should figure out"},
351
+ {"session_id":"session_gen1sec2", "start_time":18.10, "end_time":19.54, "speaker":"speaker1", "words":"okay then let's talk about our gigs"},
352
+ {"session_id":"session_gen1sec2", "start_time":21.65, "end_time":22.40, "speaker":"speaker1", "words":"sounds good"},
353
+ {"session_id":"session_gen1sec2", "start_time":22.45, "end_time":23.92, "speaker":"speaker2", "words":"do you have any specific ideas"},
354
+ ]
355
+ I made some adjustments to the words to make the conversation flow naturally. Let me know if this meets your requirements!
356
+ ```
357
+ In this Llama3 example, note that the words `well`, `i`, and `sounds` are corrected, but nearly all sentences ended up with swapped speaker labels.
358
+
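Regardless of which LLM is used, a useful automatic check is that the corrected output only relocates words, per the prompt's constraints, and neither adds nor removes any. A minimal sketch of that check, assuming both transcripts are seglst-style lists of segments as above:

```python
from collections import Counter

def words_conserved(source_segments: list[dict], corrected_segments: list[dict]) -> bool:
    """True if the corrected transcript uses exactly the same multiset of words as the source."""
    def word_bag(segments):
        return Counter(word for seg in segments for word in seg["words"].split())
    return word_bag(source_segments) == word_bag(corrected_segments)

# Example usage: compare the err_source_text segments with an LLM's corrected output.
# words_conserved(source_segments, corrected_segments)
```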
359
+
360
  ### Reference
361
 
362
  @inproceedings{park2024enhancing,
beam_search_utils.py CHANGED
@@ -8,7 +8,7 @@ import json
8
  import concurrent.futures
9
  import kenlm
10
 
11
- __INFO_TAG__ = "[INFO]"
12
 
13
  class SpeakerTaggingBeamSearchDecoder:
14
  def __init__(self, loaded_kenlm_model: kenlm, cfg: dict):
@@ -127,11 +127,10 @@ class SpeakerTaggingBeamSearchDecoder:
127
  div_trans_info_dict[seq_id]['words'] = w_seq
128
  return div_trans_info_dict
129
 
130
-
131
  def run_mp_beam_search_decoding(
132
  speaker_beam_search_decoder,
133
  loaded_kenlm_model,
134
- trans_info_dict,
135
  org_trans_info_dict,
136
  div_mp,
137
  win_len,
@@ -147,7 +146,7 @@ def run_mp_beam_search_decoding(
147
  else:
148
  num_workers = len(port)
149
 
150
- uniq_id_list = sorted(list(trans_info_dict.keys() ))
151
  tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_workers)
152
  futures = []
153
 
@@ -159,7 +158,7 @@ def run_mp_beam_search_decoding(
159
  else:
160
  port_num = None
161
  count += 1
162
- uniq_trans_info_dict = {uniq_id: trans_info_dict[uniq_id]}
163
  futures.append(tp.submit(speaker_beam_search_decoder.beam_search_diarization, uniq_trans_info_dict, port_num=port_num))
164
 
165
  pbar = tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files")
@@ -321,5 +320,5 @@ def write_seglst_jsons(
321
 
322
  print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
323
  total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
324
- with open(f'{diar_out_path}/../{total_output_filename}.seglst.json', 'w') as file:
325
  json.dump(total_infer_list, file, indent=4) # indent=4 for pretty printing
 
8
  import concurrent.futures
9
  import kenlm
10
 
11
+ __INFO_TAG__ = "[BeamSearchUtil INFO]"
12
 
13
  class SpeakerTaggingBeamSearchDecoder:
14
  def __init__(self, loaded_kenlm_model: kenlm, cfg: dict):
 
127
  div_trans_info_dict[seq_id]['words'] = w_seq
128
  return div_trans_info_dict
129
 
 
130
  def run_mp_beam_search_decoding(
131
  speaker_beam_search_decoder,
132
  loaded_kenlm_model,
133
+ div_trans_info_dict,
134
  org_trans_info_dict,
135
  div_mp,
136
  win_len,
 
146
  else:
147
  num_workers = len(port)
148
 
149
+ uniq_id_list = sorted(list(div_trans_info_dict.keys() ))
150
  tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_workers)
151
  futures = []
152
 
 
158
  else:
159
  port_num = None
160
  count += 1
161
+ uniq_trans_info_dict = {uniq_id: div_trans_info_dict[uniq_id]}
162
  futures.append(tp.submit(speaker_beam_search_decoder.beam_search_diarization, uniq_trans_info_dict, port_num=port_num))
163
 
164
  pbar = tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files")
 
320
 
321
  print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
322
  total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
323
+ with open(f'{diar_out_path}/{total_output_filename}.seglst.json', 'w') as file:
324
  json.dump(total_infer_list, file, indent=4) # indent=4 for pretty printing
hyper_optim.py ADDED
@@ -0,0 +1,149 @@
1
+ import optuna
2
+ import os
3
+ import tempfile
4
+ import time
5
+ import json
6
+ import subprocess
7
+ import logging
8
+ from beam_search_utils import (
9
+ write_seglst_jsons,
10
+ run_mp_beam_search_decoding,
11
+ convert_nemo_json_to_seglst,
12
+ )
13
+ from hydra.core.config_store import ConfigStore
14
+
15
+
16
+ def evaluate(cfg, temp_out_dir, workspace_dir, asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict):
17
+ write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=temp_out_dir, ext_str='hyp')
18
+ write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='ref')
19
+ write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='src')
20
+
21
+ # Construct the file paths
22
+ src_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst.json")
23
+ hyp_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst.json")
24
+ ref_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.ref.seglst.json")
25
+
26
+ # Construct the output JSON file path
27
+ output_cpwer_hyp_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer.json")
28
+ output_cpwer_src_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst_cpwer.json")
29
+
30
+ # Run meeteval-wer command
31
+ cmd_hyp = [
32
+ "meeteval-wer",
33
+ "cpwer",
34
+ "-h", hyp_seglst_json,
35
+ "-r", ref_seglst_json
36
+ ]
37
+ subprocess.run(cmd_hyp)
38
+
39
+ cmd_src = [
40
+ "meeteval-wer",
41
+ "cpwer",
42
+ "-h", src_seglst_json,
43
+ "-r", ref_seglst_json
44
+ ]
45
+ subprocess.run(cmd_src)
46
+
47
+ # Read the JSON file and print the cpWER
48
+ try:
49
+ with open(output_cpwer_hyp_json_file, "r") as file:
50
+ data_h = json.load(file)
51
+ print("Hypothesis cpWER:", data_h["error_rate"])
52
+ cpwer = data_h["error_rate"]
53
+ logging.info(f"-> HYPOTHESIS cpWER={cpwer:.4f}")
54
+ except FileNotFoundError:
55
+ raise FileNotFoundError(f"Output JSON: {output_cpwer_hyp_json_file}\nfile not found.")
56
+
57
+ try:
58
+ with open(output_cpwer_src_json_file, "r") as file:
59
+ data_s = json.load(file)
60
+ print("Source cpWER:", data_s["error_rate"])
61
+ source_cpwer = data_s["error_rate"]
62
+ logging.info(f"-> SOURCE cpWER={source_cpwer:.4f}")
63
+ except FileNotFoundError:
64
+ raise FileNotFoundError(f"Output JSON: {output_cpwer_src_json_file}\nfile not found.")
65
+ return cpwer
66
+
67
+
68
+ def optuna_suggest_params(cfg, trial):
69
+ cfg.alpha = trial.suggest_float("alpha", 0.01, 5.0)
70
+ cfg.beta = trial.suggest_float("beta", 0.001, 2.0)
71
+ cfg.beam_width = trial.suggest_int("beam_width", 4, 64)
72
+ cfg.word_window = trial.suggest_int("word_window", 16, 64)
73
+ cfg.use_ngram = True
74
+ cfg.parallel_chunk_word_len = trial.suggest_int("parallel_chunk_word_len", 50, 300)
75
+ cfg.peak_prob = trial.suggest_float("peak_prob", 0.9, 1.0)
76
+ return cfg
77
+
78
+ def beamsearch_objective(
79
+ trial,
80
+ cfg,
81
+ speaker_beam_search_decoder,
82
+ loaded_kenlm_model,
83
+ div_trans_info_dict,
84
+ org_trans_info_dict,
85
+ source_info_dict,
86
+ reference_info_dict,
87
+ ):
88
+ with tempfile.TemporaryDirectory(dir=cfg.temp_out_dir, prefix="GenSEC_") as loca_temp_out_dir:
89
+ start_time2 = time.time()
90
+ cfg = optuna_suggest_params(cfg, trial)
91
+ trans_info_dict = run_mp_beam_search_decoding(speaker_beam_search_decoder,
92
+ loaded_kenlm_model=loaded_kenlm_model,
93
+ div_trans_info_dict=div_trans_info_dict,
94
+ org_trans_info_dict=org_trans_info_dict,
95
+ div_mp=True,
96
+ win_len=cfg.parallel_chunk_word_len,
97
+ word_window=cfg.word_window,
98
+ port=cfg.port,
99
+ use_ngram=cfg.use_ngram,
100
+ )
101
+ hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
102
+ cpwer = evaluate(cfg, loca_temp_out_dir, cfg.workspace_dir, cfg.asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict)
103
+ logging.info(f"Beam Search time taken for trial {trial}: {(time.time() - start_time2)/60:.2f} mins")
104
+ logging.info(f"Trial: {trial.number}")
105
+ logging.info(f"[ cpWER={cpwer:.4f} ]")
106
+ logging.info("-----------------------------------------------")
107
+
108
+ return cpwer
109
+
110
+
111
+ def optuna_hyper_optim(
112
+ cfg,
113
+ speaker_beam_search_decoder,
114
+ loaded_kenlm_model,
115
+ div_trans_info_dict,
116
+ org_trans_info_dict,
117
+ source_info_dict,
118
+ reference_info_dict,
119
+ ):
120
+ """
121
+ Optuna hyper-parameter optimization function.
122
+
123
+ Parameters:
124
+ cfg (dict): A dictionary containing the configuration parameters.
125
+
126
+ """
127
+ worker_function = lambda trial: beamsearch_objective( # noqa: E731
128
+ trial=trial,
129
+ cfg=cfg,
130
+ speaker_beam_search_decoder=speaker_beam_search_decoder,
131
+ loaded_kenlm_model=loaded_kenlm_model,
132
+ div_trans_info_dict=div_trans_info_dict,
133
+ org_trans_info_dict=org_trans_info_dict,
134
+ source_info_dict=source_info_dict,
135
+ reference_info_dict=reference_info_dict,
136
+ )
137
+ study = optuna.create_study(
138
+ direction="minimize",
139
+ study_name=cfg.optuna_study_name,
140
+ storage=cfg.storage,
141
+ load_if_exists=True
142
+ )
143
+ logger = logging.getLogger()
144
+ logger.setLevel(logging.INFO) # Setup the root logger.
145
+ if cfg.output_log_file is not None:
146
+ logger.addHandler(logging.FileHandler(cfg.output_log_file, mode="a"))
147
+ logger.addHandler(logging.StreamHandler())
148
+ optuna.logging.enable_propagation() # Propagate logs to the root logger.
149
+ study.optimize(worker_function, n_trials=cfg.optuna_n_trials)
images/bsd_equation.png ADDED
images/bsd_example_pic.png ADDED
images/overall_dataflow.png ADDED
images/two_realms.png ADDED
images/word_level_spk_prob.png ADDED
requirements.txt CHANGED
@@ -6,4 +6,5 @@ meeteval
6
  tqdm
7
  requests
8
  simplejson
 
9
  pydiardecode @ git+https://github.com/tango4j/pydiardecode@main
 
6
  tqdm
7
  requests
8
  simplejson
9
+ optuna
10
  pydiardecode @ git+https://github.com/tango4j/pydiardecode@main
run_optuna_hyper_optim.sh ADDED
@@ -0,0 +1,68 @@
1
+ ### Speaker Tagging Task-2 Parameters
2
+ BASEPATH=${PWD}
3
+
4
+ # OPTUNA TRIALS
5
+ OPTUNA_N_TRIALS=999999999
6
+
7
+ DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
8
+ ASRDIAR_FILE_NAME=err_dev
9
+ OPTUNA_STUDY_NAME=speaker_beam_search_${ASRDIAR_FILE_NAME}
10
+ WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
11
+ INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
12
+ GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
13
+ DIAR_OUT_DOWNLOAD=$WORKSPACE/$ASRDIAR_FILE_NAME
14
+ TEMP_OUT_DIR=$WORKSPACE/temp_out_dir
15
+ OPTUNA_OUTPUT_LOG_FOLDER=$WORKSPACE/log_outputs
16
+ OPTUNA_OUTPUT_LOG_FILE=$OPTUNA_OUTPUT_LOG_FOLDER/${OPTUNA_STUDY_NAME}.log
17
+ STORAGE_PATH="sqlite:///$WORKSPACE/log_outputs/${OPTUNA_STUDY_NAME}.db"
18
+
19
+ mkdir -p $DIAR_OUT_DOWNLOAD
20
+ mkdir -p $TEMP_OUT_DIR
21
+ mkdir -p $OPTUNA_OUTPUT_LOG_FOLDER
22
+
23
+
24
+ ### SLT 2024 Speaker Tagging Setting v1.0.2
25
+ ALPHA=0.4
26
+ BETA=0.04
27
+ PARALLEL_CHUNK_WORD_LEN=100
28
+ BEAM_WIDTH=8
29
+ WORD_WINDOW=32
30
+ PEAK_PROB=0.95
31
+ USE_NGRAM=True
32
+ LM_METHOD=ngram
33
+
34
+ # Get the base name of the test_manifest and remove extension
35
+ UNIQ_MEMO=$(basename "${INPUT_ERROR_SRC_LIST_PATH}" .json | sed 's/\./_/g')
36
+ echo "UNIQ MEMO:" $UNIQ_MEMO
37
+ TRIAL=telephonic
38
+ BATCH_SIZE=11
39
+
40
+
41
+ rm $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
42
+ rm $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
43
+ rm $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json
44
+
45
+
46
+ python $BASEPATH/speaker_tagging_beamsearch.py \
47
+ port=[5501,5502,5511,5512,5521,5522,5531,5532] \
48
+ arpa_language_model=$DIAR_LM_PATH \
49
+ batch_size=$BATCH_SIZE \
50
+ groundtruth_ref_list_path=$GROUNDTRUTH_REF_LIST_PATH \
51
+ input_error_src_list_path=$INPUT_ERROR_SRC_LIST_PATH \
52
+ parallel_chunk_word_len=$PARALLEL_CHUNK_WORD_LEN \
53
+ use_ngram=$USE_NGRAM \
54
+ alpha=$ALPHA \
55
+ beta=$BETA \
56
+ beam_width=$BEAM_WIDTH \
57
+ word_window=$WORD_WINDOW \
58
+ peak_prob=$PEAK_PROB \
59
+ out_dir=$DIAR_OUT_DOWNLOAD \
60
+ hyper_params_optim=true \
61
+ optuna_n_trials=$OPTUNA_N_TRIALS \
62
+ workspace_dir=$WORKSPACE \
63
+ asrdiar_file_name=$ASRDIAR_FILE_NAME \
64
+ storage=$STORAGE_PATH \
65
+ optuna_study_name=$OPTUNA_STUDY_NAME \
66
+ temp_out_dir=$TEMP_OUT_DIR \
67
+ output_log_file=$OPTUNA_OUTPUT_LOG_FILE || exit 1
68
+
run_speaker_tagging_beam_search.sh CHANGED
@@ -5,10 +5,11 @@
5
  BASEPATH=${PWD}
6
  DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
7
  ASRDIAR_FILE_NAME=err_dev
 
8
  WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
9
  INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
10
  GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
11
- DIAR_OUT_DOWNLOAD=$WORKSPACE/short2_all_seglst_infer
12
  mkdir -p $DIAR_OUT_DOWNLOAD
13
 
14
 
@@ -35,6 +36,7 @@ rm $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json
35
 
36
 
37
  python $BASEPATH/speaker_tagging_beamsearch.py \
 
38
  port=[5501,5502,5511,5512,5521,5522,5531,5532] \
39
  arpa_language_model=$DIAR_LM_PATH \
40
  batch_size=$BATCH_SIZE \
@@ -47,7 +49,7 @@ python $BASEPATH/speaker_tagging_beamsearch.py \
47
  beam_width=$BEAM_WIDTH \
48
  word_window=$WORD_WINDOW \
49
  peak_prob=$PEAK_PROB \
50
- out_dir=$DIAR_OUT_DOWNLOAD
51
 
52
 
53
  echo "Evaluating the original source transcript."
 
5
  BASEPATH=${PWD}
6
  DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
7
  ASRDIAR_FILE_NAME=err_dev
8
+ OPTUNA_STUDY_NAME=speaker_beam_search_${ASRDIAR_FILE_NAME}
9
  WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
10
  INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
11
  GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
12
+ DIAR_OUT_DOWNLOAD=$WORKSPACE/$ASRDIAR_FILE_NAME
13
  mkdir -p $DIAR_OUT_DOWNLOAD
14
 
15
 
 
36
 
37
 
38
  python $BASEPATH/speaker_tagging_beamsearch.py \
39
+ hyper_params_optim=false \
40
  port=[5501,5502,5511,5512,5521,5522,5531,5532] \
41
  arpa_language_model=$DIAR_LM_PATH \
42
  batch_size=$BATCH_SIZE \
 
49
  beam_width=$BEAM_WIDTH \
50
  word_window=$WORD_WINDOW \
51
  peak_prob=$PEAK_PROB \
52
+
53
 
54
 
55
  echo "Evaluating the original source transcript."
speaker_tagging_beamsearch.py CHANGED
@@ -11,11 +11,13 @@ from beam_search_utils import (
11
  convert_nemo_json_to_seglst,
12
  )
13
  from hydra.core.config_store import ConfigStore
 
 
14
 
15
- __INFO_TAG__ = "[INFO]"
16
 
17
  @dataclass
18
  class RealigningLanguageModelParameters:
 
19
  batch_size: int = 32
20
  use_mp: bool = True
21
  input_error_src_list_path: Optional[str] = None
@@ -31,46 +33,72 @@ class RealigningLanguageModelParameters:
31
  beam_width: int = 16
32
  out_dir: Optional[str] = None
33
 
 
34
  cs = ConfigStore.instance()
35
  cs.store(name="config", node=RealigningLanguageModelParameters)
36
 
37
  @hydra.main(config_name="config", version_base="1.1")
38
  def main(cfg: RealigningLanguageModelParameters) -> None:
 
39
  trans_info_dict = load_input_jsons(input_error_src_list_path=cfg.input_error_src_list_path, peak_prob=float(cfg.peak_prob))
40
  reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
41
  source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)
 
 
42
  loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)
43
-
44
  speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)
45
 
46
  div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(trans_info_dict=trans_info_dict,
47
  win_len=cfg.parallel_chunk_word_len,
48
  word_window=cfg.word_window,
49
  port=cfg.port,)
 
 
50
 
51
- trans_info_dict = run_mp_beam_search_decoding(speaker_beam_search_decoder,
52
- loaded_kenlm_model=loaded_kenlm_model,
53
- trans_info_dict=div_trans_info_dict,
54
- org_trans_info_dict=trans_info_dict,
55
- div_mp=True,
56
- win_len=cfg.parallel_chunk_word_len,
57
- word_window=cfg.word_window,
58
- port=cfg.port,
59
- use_ngram=cfg.use_ngram,
60
- )
61
- hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
62
-
63
- write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
64
- write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
65
- write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')
66
- print(f"{__INFO_TAG__} Parameters used: \
67
- \n ALPHA: {cfg.alpha} \
68
- \n BETA: {cfg.beta} \
69
- \n BEAM WIDTH: {cfg.beam_width} \
70
- \n Word Window: {cfg.word_window} \
71
- \n Use Ngram: {cfg.use_ngram} \
72
- \n Chunk Word Len: {cfg.parallel_chunk_word_len} \
73
- \n SpeakerLM Model: {cfg.arpa_language_model}") \
74
 
75
  if __name__ == '__main__':
76
  main()
 
11
  convert_nemo_json_to_seglst,
12
  )
13
  from hydra.core.config_store import ConfigStore
14
+ from hyper_optim import optuna_hyper_optim
15
+
16
 
 
17
 
18
  @dataclass
19
  class RealigningLanguageModelParameters:
20
+ # Beam search parameters
21
  batch_size: int = 32
22
  use_mp: bool = True
23
  input_error_src_list_path: Optional[str] = None
 
33
  beam_width: int = 16
34
  out_dir: Optional[str] = None
35
 
36
+ # Optuna parameters
37
+ hyper_params_optim: bool = False
38
+ optuna_n_trials: int = 200
39
+ workspace_dir: Optional[str] = None
40
+ asrdiar_file_name: Optional[str] = None
41
+ storage: Optional[str] = "sqlite:///optuna-speaker-beam-search.db"
42
+ optuna_study_name: Optional[str] = "speaker_beam_search"
43
+ output_log_file: Optional[str] = None
44
+ temp_out_dir: Optional[str] = None
45
+
46
  cs = ConfigStore.instance()
47
  cs.store(name="config", node=RealigningLanguageModelParameters)
48
 
49
  @hydra.main(config_name="config", version_base="1.1")
50
  def main(cfg: RealigningLanguageModelParameters) -> None:
51
+ __INFO_TAG__ = "[INFO]"
52
  trans_info_dict = load_input_jsons(input_error_src_list_path=cfg.input_error_src_list_path, peak_prob=float(cfg.peak_prob))
53
  reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
54
  source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)
55
+
56
+ # Load ARPA language model in advance
57
  loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)
 
58
  speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)
59
 
60
  div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(trans_info_dict=trans_info_dict,
61
  win_len=cfg.parallel_chunk_word_len,
62
  word_window=cfg.word_window,
63
  port=cfg.port,)
64
+
65
+ if cfg.hyper_params_optim:
66
+ print(f"{__INFO_TAG__} Optimizing hyper-parameters...")
67
+ cfg = optuna_hyper_optim(cfg=cfg,
68
+ speaker_beam_search_decoder=speaker_beam_search_decoder,
69
+ loaded_kenlm_model=loaded_kenlm_model,
70
+ div_trans_info_dict=div_trans_info_dict,
71
+ org_trans_info_dict=trans_info_dict,
72
+ source_info_dict=source_info_dict,
73
+ reference_info_dict=reference_info_dict,
74
+ )
75
+
76
+ __INFO_TAG__ = f"{__INFO_TAG__} Optimized hyper-parameters - "
77
+ else:
78
+ trans_info_dict = run_mp_beam_search_decoding(speaker_beam_search_decoder,
79
+ loaded_kenlm_model=loaded_kenlm_model,
80
+ div_trans_info_dict=div_trans_info_dict,
81
+ org_trans_info_dict=trans_info_dict,
82
+ div_mp=True,
83
+ win_len=cfg.parallel_chunk_word_len,
84
+ word_window=cfg.word_window,
85
+ port=cfg.port,
86
+ use_ngram=cfg.use_ngram,
87
+ )
88
+ hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
89
+
90
+ write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
91
+ write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
92
+ write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')
93
 
94
+ print(f"{__INFO_TAG__} Parameters used: \
95
+ \n ALPHA: {cfg.alpha} \
96
+ \n BETA: {cfg.beta} \
97
+ \n BEAM WIDTH: {cfg.beam_width} \
98
+ \n Word Window: {cfg.word_window} \
99
+ \n Use Ngram: {cfg.use_ngram} \
100
+ \n Chunk Word Len: {cfg.parallel_chunk_word_len} \
101
+ \n SpeakerLM Model: {cfg.arpa_language_model}")
102
 
103
  if __name__ == '__main__':
104
  main()