|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import soundfile as sf
import numpy as np
from scipy import signal
from dataclasses import dataclass
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.random.manual_seed(0)
|
|
def get_processor_labels(processor, word_sep="|", max_labels=100):
    # Map each vocabulary label to its id. Ids are visited in reverse so that,
    # when several ids decode to the same string (e.g. out-of-range ids falling
    # back to the unknown token), the lowest id wins. The word separator
    # decodes to an empty string, so it is restored explicitly via `or`.
    ixs = sorted(range(max_labels), reverse=True)
    return {processor.tokenizer.decode(n) or word_sep: n for n in ixs}
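# Illustrative only (ids and vocabulary are model-specific): the result maps
# label strings to ids, e.g. {'a': 7, 'b': 24, '|': 4, '[PAD]': 37, ...}.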
|
|
# Icelandic acoustic model and label set
is_MODEL_PATH = "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
is_model_blank_token = '[PAD]'
is_model_word_separator = '|'

is_model = Wav2Vec2ForCTC.from_pretrained(is_MODEL_PATH).to(device)
is_processor = Wav2Vec2Processor.from_pretrained(is_MODEL_PATH)
is_labels_dict = get_processor_labels(is_processor, is_model_word_separator)
is_inverse_dict = {v: k for k, v in is_labels_dict.items()}
is_all_labels = tuple(is_labels_dict.keys())
is_blank_id = is_labels_dict[is_model_blank_token]
|
|
# Faroese acoustic model and label set
fo_MODEL_PATH = "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-faroese-100h"
fo_model_blank_token = '[PAD]'
fo_model_word_separator = '|'

fo_model = Wav2Vec2ForCTC.from_pretrained(fo_MODEL_PATH).to(device)
fo_processor = Wav2Vec2Processor.from_pretrained(fo_MODEL_PATH)
fo_labels_dict = get_processor_labels(fo_processor, fo_model_word_separator)
fo_inverse_dict = {v: k for k, v in fo_labels_dict.items()}
fo_all_labels = tuple(fo_labels_dict.keys())
fo_blank_id = fo_labels_dict[fo_model_blank_token]
|
|
# Norwegian (Bokmål) acoustic model and label set
no_MODEL_PATH = "NbAiLab/nb-wav2vec2-1b-bokmaal"
no_model_blank_token = '[PAD]'
no_model_word_separator = '|'

no_model = Wav2Vec2ForCTC.from_pretrained(no_MODEL_PATH).to(device)
no_processor = Wav2Vec2Processor.from_pretrained(no_MODEL_PATH)
no_labels_dict = get_processor_labels(no_processor, no_model_word_separator)
no_inverse_dict = {v: k for k, v in no_labels_dict.items()}
no_all_labels = tuple(no_labels_dict.keys())
no_blank_id = no_labels_dict[no_model_blank_token]
|
|
|
d = {
    "Icelandic": {
        'model': is_model, 'processor': is_processor,
        'inverse_dict': is_inverse_dict, 'labels_dict': is_labels_dict,
        'all_labels': is_all_labels, 'blank_id': is_blank_id,
        'model_blank_token': is_model_blank_token,
        'model_word_separator': is_model_word_separator,
    },
    "Faroese": {
        'model': fo_model, 'processor': fo_processor,
        'inverse_dict': fo_inverse_dict, 'labels_dict': fo_labels_dict,
        'all_labels': fo_all_labels, 'blank_id': fo_blank_id,
        'model_blank_token': fo_model_blank_token,
        'model_word_separator': fo_model_word_separator,
    },
    "Norwegian": {
        'model': no_model, 'processor': no_processor,
        'inverse_dict': no_inverse_dict, 'labels_dict': no_labels_dict,
        'all_labels': no_all_labels, 'blank_id': no_blank_id,
        'model_blank_token': no_model_blank_token,
        'model_word_separator': no_model_word_separator,
    },
}
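
# The three per-language blocks above are identical in shape; a sketch of how
# a further language could be added with one helper (the model path in the
# usage line is hypothetical, not a real checkpoint):
def load_language(model_path, blank_token='[PAD]', word_separator='|'):
    model = Wav2Vec2ForCTC.from_pretrained(model_path).to(device)
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    labels_dict = get_processor_labels(processor, word_separator)
    return {'model': model, 'processor': processor,
            'labels_dict': labels_dict,
            'inverse_dict': {v: k for k, v in labels_dict.items()},
            'all_labels': tuple(labels_dict.keys()),
            'blank_id': labels_dict[blank_token],
            'model_blank_token': blank_token,
            'model_word_separator': word_separator}
# e.g. d["Swedish"] = load_language("some-org/wav2vec2-swedish")  # hypothetical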
|
|
def f2s(fr):
    # Convert a frame index to seconds: wav2vec2 emits one frame per 20 ms,
    # i.e. 50 frames per second.
    return fr / 50
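# e.g. f2s(125) == 2.5, i.e. frame 125 falls 2.5 seconds into the audio.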
|
|
def get_frame_probs(wav_path, lang):
    # Run the acoustic model over the file and return per-frame
    # log-probabilities over the label vocabulary.
    wav = readwav(wav_path)
    with torch.inference_mode():
        input_values = d[lang]['processor'](wav, sampling_rate=16000).input_values[0]
        input_values = torch.tensor(input_values, device=device).unsqueeze(0)
        emits = d[lang]['model'](input_values).logits
        emits = torch.log_softmax(emits, dim=-1)
    return emits[0].cpu()
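# The returned tensor has shape (num_frames, vocab_size); with the 20 ms frame
# stride, an n-second recording yields roughly 50 * n frames.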
|
|
def get_trellis(emission, tokens, blank_id):
    # Build the alignment trellis: trellis[t, j] is the score of the best path
    # that has consumed the first j tokens after t frames.
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Staying on blank before the first token accumulates blank scores.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Stay on the current token (emit blank at frame t)
            trellis[t, 1:] + emission[t, blank_id],
            # Advance to the next token (emit it at frame t)
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis
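# Shape-only sketch with toy values (not real model output):
#   emission = torch.log_softmax(torch.randn(100, 40), dim=-1)  # 100 frames, 40 labels
#   tokens = [5, 12, 5]  # hypothetical label ids for a 3-character transcript
#   get_trellis(emission, tokens, blank_id=39).shape  # -> (101, 4)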
|
|
@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


@dataclass
class Segment:
    label: str
    start: int  # start frame
    end: int    # end frame (exclusive)
    score: float

    @property
    def mfaform(self):
        # Render as an MFA-style CSV fragment: start seconds, end seconds, label.
        return f"{f2s(self.start)},{f2s(self.end)},{self.label}"

    @property
    def length(self):
        return self.end - self.start
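# e.g. Segment('a', 100, 110, 0.93).mfaform == '2.0,2.2,a'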
|
|
def backtrack(trellis, emission, tokens, blank_id):
    # Walk back through the trellis from the most likely end point, deciding
    # at each frame whether the path stayed on the current token (blank was
    # emitted) or advanced from the previous token.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # Score when staying: blank emitted at frame t-1.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score when changing: token j-1 emitted at frame t-1.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # Frame-level posterior of whichever label was emitted.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
        path.append(Point(j - 1, t - 1, prob))

        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]
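# The returned path is time-ordered, one Point per frame used by the best
# alignment, e.g. [Point(token_index=0, time_index=3, score=0.98), ...]
# (values illustrative).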
|
|
def merge_repeats(path, transcript):
    # Collapse consecutive path points that belong to the same token into one
    # Segment spanning their frames, averaging the per-frame scores.
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments
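# e.g. three consecutive Points on token 2 covering frames 17-19 collapse
# into Segment(transcript[2], 17, 20, <mean of their scores>).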
|
|
def merge_words(segments, separator):
    # Group character segments into word segments, splitting on the word
    # separator; word scores are length-weighted averages of character scores.
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
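# Worked example of the length-weighted score: character segments 'h'
# (length 2, score 0.9) and 'i' (length 4, score 0.6) merge into word 'hi'
# with score (2 * 0.9 + 4 * 0.6) / (2 + 4) = 0.7.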
|
|
def readwav(wav_path):
    # Load audio as float32, downmix stereo to mono, and resample to 16 kHz.
    wav, sr = sf.read(wav_path, dtype=np.float32)
    if len(wav.shape) == 2:
        wav = wav.mean(1)
    if sr != 16000:
        wlen = int(wav.shape[0] / sr * 16000)
        wav = signal.resample(wav, wlen)
    return wav
|
|
def mfalike(chars, wds, wsep):
    # Format word and phone segments as MFA-style CSV, dropping word-separator
    # segments from the phone tier.
    hed = ['Begin,End,Label,Type,Speaker\n']
    wlines = [f'{w.mfaform},words,000\n' for w in wds]
    slines = [f'{ch.mfaform},phones,000\n' for ch in chars if ch.label != wsep]
    return ''.join(hed + wlines + slines)
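# The returned string is MFA-style CSV, e.g. (values illustrative):
#   Begin,End,Label,Type,Speaker
#   0.4,0.9,halló,words,000
#   0.4,0.52,h,phones,000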
|
|
def prep_transcript(xcp, lang):
    # Normalise the transcript to the model's label set: lowercase, collapse
    # runs of spaces, and replace spaces with the model's word separator.
    xcp = xcp.lower()
    while '  ' in xcp:
        xcp = xcp.replace('  ', ' ')
    xcp = xcp.replace(' ', d[lang]['model_word_separator'])
    label_ids = [d[lang]['labels_dict'][c] for c in xcp]
    # Pad the id sequence with blanks and the string with word separators so
    # leading and trailing silence has somewhere to align.
    label_ids = [d[lang]['blank_id']] + label_ids + [d[lang]['blank_id']]
    xcp = f"{d[lang]['model_word_separator']}{xcp}{d[lang]['model_word_separator']}"
    return xcp, label_ids
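# e.g. prep_transcript("Halló heimur", "Icelandic") returns
# ('|halló|heimur|', [blank_id, id('|'), id('h'), ..., id('|'), blank_id]),
# with ids taken from the Icelandic labels_dict.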
|
|
def langsalign(wav_path, transcript_string, lang):
    # End-to-end forced alignment: normalise the transcript, get frame
    # log-probabilities, build and backtrack the trellis, then merge the path
    # into character and word segments and format them as CSV.
    norm_txt, rec_label_ids = prep_transcript(transcript_string, lang)
    emit = get_frame_probs(wav_path, lang)
    trellis = get_trellis(emit, rec_label_ids, d[lang]['blank_id'])
    path = backtrack(trellis, emit, rec_label_ids, d[lang]['blank_id'])
    segments = merge_repeats(path, norm_txt)
    words = merge_words(segments, d[lang]['model_word_separator'])
    return mfalike(segments, words, d[lang]['model_word_separator'])
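
# A minimal usage sketch; the wav path, transcript, and output filename below
# are hypothetical placeholders.
if __name__ == "__main__":
    csv_text = langsalign("example.wav", "halló heimur", "Icelandic")
    with open("example_alignment.csv", "w", encoding="utf-8") as f:
        f.write(csv_text)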
|
|
|
|