|
import os |
|
from collections import Counter |
|
from multiprocessing import get_context |
|
|
|
from .fairseq_dictionary import Dictionary as fairseq_Dictionary |
|
|
|
|
|
class Dictionary(fairseq_Dictionary):
    """Dictionary inherited from FairSeq, with helpers to count transcript words."""

    @staticmethod
    def _add_transcripts_to_dictionary_single_worker(
        transcripts, eos_word, worker_id=0, num_workers=1
    ):
        """Count the words in one worker's share of ``transcripts``.

        The transcripts are split into ``num_workers`` contiguous shares whose
        sizes differ by at most one, and this call processes the share with
        index ``worker_id``. Taken together, the shares cover every
        transcript exactly once, even when ``len(transcripts)`` is not a
        multiple of ``num_workers``.

        Args:
            transcripts: sequence of strings supporting ``len()`` and slicing;
                each string is tokenized with ``str.split()``.
            eos_word: symbol counted once for every processed transcript.
            worker_id: zero-based index of the share to process.
            num_workers: total number of shares.

        Returns:
            A ``collections.Counter`` mapping each word (and ``eos_word``) to
            its number of occurrences in the processed share.
        """
        counter = Counter()
        size = len(transcripts)
        # Balanced integer boundaries: share k is [k*size//n, (k+1)*size//n).
        # Unlike a fixed ``size // num_workers`` chunk size, this never drops
        # the trailing ``size % num_workers`` transcripts.
        offset = worker_id * size // num_workers
        end = (worker_id + 1) * size // num_workers
        for line in transcripts[offset:end]:
            counter.update(line.split())
            counter[eos_word] += 1
        return counter

    @staticmethod
    def add_transcripts_to_dictionary(transcripts, dict, num_workers):
        """Count the words in ``transcripts`` and add them to ``dict``.

        Args:
            transcripts: sequence of transcript strings.
            dict: dictionary object to update; it must provide an ``eos_word``
                attribute and an ``add_symbol(word, count)`` method. (The name
                shadows the builtin but is kept for interface compatibility.)
            num_workers: number of worker processes. Values greater than 1
                count in parallel using a "spawn" multiprocessing pool;
                otherwise counting happens in the current process.
        """

        def merge_result(counter):
            # Sorted so that symbols are added in a deterministic order.
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        if num_workers > 1:
            # The context manager guarantees the pool is torn down even if
            # scheduling or joining raises.
            with get_context("spawn").Pool(processes=num_workers) as pool:
                results = [
                    pool.apply_async(
                        Dictionary._add_transcripts_to_dictionary_single_worker,
                        (transcripts, dict.eos_word, worker_id, num_workers),
                    )
                    for worker_id in range(num_workers)
                ]
                pool.close()
                pool.join()
            # All tasks have finished by now; ``get()`` re-raises any
            # exception that occurred inside a worker.
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_transcripts_to_dictionary_single_worker(
                    transcripts, dict.eos_word
                )
            )
|
|